aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
blob: 10908f0388d53fb47dd17ce278bc39c760a3c418 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
// Copyright (c) 2017, the R8 project authors. Please see the AUTHORS file
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.

package com.android.tools.r8.ir.desugar;

import com.android.tools.r8.dex.Constants;
import com.android.tools.r8.graph.DexApplication.Builder;
import com.android.tools.r8.graph.DexCallSite;
import com.android.tools.r8.graph.DexEncodedMethod;
import com.android.tools.r8.graph.DexItemFactory;
import com.android.tools.r8.graph.DexMethod;
import com.android.tools.r8.graph.DexProgramClass;
import com.android.tools.r8.graph.DexProto;
import com.android.tools.r8.graph.DexString;
import com.android.tools.r8.graph.DexType;
import com.android.tools.r8.ir.code.BasicBlock;
import com.android.tools.r8.ir.code.IRCode;
import com.android.tools.r8.ir.code.Instruction;
import com.android.tools.r8.ir.code.InstructionListIterator;
import com.android.tools.r8.ir.code.InvokeCustom;
import com.android.tools.r8.ir.code.InvokeDirect;
import com.android.tools.r8.ir.code.MemberType;
import com.android.tools.r8.ir.code.MoveType;
import com.android.tools.r8.ir.code.NewInstance;
import com.android.tools.r8.ir.code.StaticGet;
import com.android.tools.r8.ir.code.Value;
import com.android.tools.r8.ir.conversion.IRConverter;
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;

/**
 * Lambda desugaring rewriter.
 *
 * <p>Finds {@code invoke-custom} instructions whose call site can be matched to a
 * {@link LambdaDescriptor}, maps every match to a {@link LambdaClass} (one class per
 * distinct lambda class type), and rewrites the instruction so that it produces an
 * instance of that class instead of going through the bootstrap method.
 */
public class LambdaRewriter {
  private static final String METAFACTORY_TYPE_DESCR = "Ljava/lang/invoke/LambdaMetafactory;";
  private static final String CALLSITE_TYPE_DESCR = "Ljava/lang/invoke/CallSite;";
  private static final String LOOKUP_TYPE_DESCR = "Ljava/lang/invoke/MethodHandles$Lookup;";
  private static final String METHODTYPE_TYPE_DESCR = "Ljava/lang/invoke/MethodType;";
  private static final String METHODHANDLE_TYPE_DESCR = "Ljava/lang/invoke/MethodHandle;";
  private static final String OBJECT_ARRAY_TYPE_DESCR = "[Ljava/lang/Object;";
  private static final String SERIALIZABLE_TYPE_DESCR = "Ljava/io/Serializable;";
  private static final String SERIALIZED_LAMBDA_TYPE_DESCR = "Ljava/lang/invoke/SerializedLambda;";

  private static final String METAFACTORY_METHOD_NAME = "metafactory";
  private static final String METAFACTORY_ALT_METHOD_NAME = "altMetafactory";
  private static final String DESERIALIZE_LAMBDA_METHOD_NAME = "$deserializeLambda$";

  static final String LAMBDA_CLASS_NAME_PREFIX = "-$$Lambda$";
  static final String EXPECTED_LAMBDA_METHOD_PREFIX = "lambda$";
  static final DexType[] EMPTY_TYPE_ARRAY = new DexType[0];
  static final String LAMBDA_INSTANCE_FIELD_NAME = "INSTANCE";

  // The converter driving this rewriter; also used to optimize synthesized classes.
  final IRConverter converter;
  // Item factory of the application being converted.
  final DexItemFactory factory;

  // References to LambdaMetafactory.metafactory(...) and LambdaMetafactory.altMetafactory(...).
  final DexMethod metafactoryMethod;
  final DexMethod metafactoryAltMethod;

  // Reference to Object's no-argument instance initializer.
  final DexMethod objectInitMethod;

  // java.io.Serializable.
  final DexType serializableType;

  // Canonical names of the instance initializer, the class initializer and the
  // singleton field of stateless lambda classes.
  final DexString constructorName;
  final DexString classConstructorName;
  final DexString instanceFieldName;

  // Name and proto (SerializedLambda -> Object) of the $deserializeLambda$ method.
  final DexString deserializeLambdaMethodName;
  final DexProto deserializeLambdaMethodProto;

  // Cache from call site to the descriptor inferred for it, so that a call site
  // already seen does not have to be matched again. Failed matches are cached too
  // (as LambdaDescriptor.MATCH_FAILED). One call site may end up being mapped to
  // one or several lambda classes.
  //
  // NOTE: every concurrent access must synchronize on `knownCallSites`.
  private final Map<DexCallSite, LambdaDescriptor> knownCallSites = new IdentityHashMap<>();

  // Lambda class type -> lambda class. Because the type uniquely identifies the
  // lambda class, this map canonicalizes lambda classes.
  //
  // NOTE: every concurrent access must synchronize on `knownLambdaClasses`.
  private final Map<DexType, LambdaClass> knownLambdaClasses = new IdentityHashMap<>();

  /** Returns true when the name of {@code clazz} begins with the lambda class prefix. */
  public static boolean hasLambdaClassPrefix(DexType clazz) {
    String typeName = clazz.getName();
    return typeName.startsWith(LAMBDA_CLASS_NAME_PREFIX);
  }

  /**
   * Creates a rewriter bound to {@code converter} and pre-creates the dex items
   * (types, protos, methods, strings) that lambda matching and synthesis refer to.
   */
  public LambdaRewriter(IRConverter converter) {
    assert converter != null;
    DexItemFactory items = converter.application.dexItemFactory;
    this.converter = converter;
    this.factory = items;

    DexType metafactoryType = items.createType(METAFACTORY_TYPE_DESCR);
    DexType callSiteType = items.createType(CALLSITE_TYPE_DESCR);
    DexType lookupType = items.createType(LOOKUP_TYPE_DESCR);
    DexType methodTypeType = items.createType(METHODTYPE_TYPE_DESCR);
    DexType methodHandleType = items.createType(METHODHANDLE_TYPE_DESCR);
    DexType objectArrayType = items.createType(OBJECT_ARRAY_TYPE_DESCR);

    // CallSite metafactory(Lookup, String, MethodType, MethodType, MethodHandle, MethodType)
    DexType[] metafactoryParams = {
        lookupType, items.stringType, methodTypeType,
        methodTypeType, methodHandleType, methodTypeType
    };
    DexProto metafactoryProto = items.createProto(callSiteType, metafactoryParams);
    DexString metafactoryName = items.createString(METAFACTORY_METHOD_NAME);
    this.metafactoryMethod =
        items.createMethod(metafactoryType, metafactoryProto, metafactoryName);

    // CallSite altMetafactory(Lookup, String, MethodType, Object[])
    DexType[] altParams = {lookupType, items.stringType, methodTypeType, objectArrayType};
    DexProto altProto = items.createProto(callSiteType, altParams);
    DexString altName = items.createString(METAFACTORY_ALT_METHOD_NAME);
    this.metafactoryAltMethod = items.createMethod(metafactoryType, altProto, altName);

    this.constructorName = items.createString(Constants.INSTANCE_INITIALIZER_NAME);
    DexProto noArgVoidProto = items.createProto(items.voidType, EMPTY_TYPE_ARRAY);
    this.objectInitMethod = items.createMethod(items.objectType, noArgVoidProto, constructorName);
    this.classConstructorName = items.createString(Constants.CLASS_INITIALIZER_NAME);
    this.instanceFieldName = items.createString(LAMBDA_INSTANCE_FIELD_NAME);
    this.serializableType = items.createType(SERIALIZABLE_TYPE_DESCR);

    this.deserializeLambdaMethodName = items.createString(DESERIALIZE_LAMBDA_METHOD_NAME);
    DexType serializedLambdaType = items.createType(SERIALIZED_LAMBDA_TYPE_DESCR);
    this.deserializeLambdaMethodProto =
        items.createProto(items.objectType, new DexType[] {serializedLambdaType});
  }

  /**
   * Scans {@code code} for invoke-custom instructions that can be desugared as lambdas
   * or method references and rewrites each of them in place.
   *
   * <p>May be called concurrently for different methods; the shared caches are
   * accessed under their own locks.
   */
  public void desugarLambdas(DexEncodedMethod encodedMethod, IRCode code) {
    DexType context = encodedMethod.method.holder;
    ListIterator<BasicBlock> blockIterator = code.listIterator();
    while (blockIterator.hasNext()) {
      InstructionListIterator it = blockIterator.next().listIterator();
      while (it.hasNext()) {
        Instruction current = it.next();
        if (!current.isInvokeCustom()) {
          continue;
        }
        DexCallSite site = current.asInvokeCustom().getCallSite();
        LambdaDescriptor descriptor = inferLambdaDescriptor(site);
        if (descriptor == LambdaDescriptor.MATCH_FAILED) {
          continue;
        }
        LambdaClass target = getOrCreateLambdaClass(descriptor, context);
        assert target != null;
        // patchInstruction is required to leave both `it` and `blockIterator`
        // in a state from which this scan can simply continue.
        patchInstruction(target, code, blockIterator, it);
      }
    }
  }

  /**
   * Removes the {@code $deserializeLambda$} method (proto: SerializedLambda -> Object)
   * from the direct methods of each class, if present. Only the first match per class
   * is removed, on the assumption that a class has at most one such method.
   */
  public void removeLambdaDeserializationMethods(Iterable<DexProgramClass> classes) {
    for (DexProgramClass programClass : classes) {
      DexEncodedMethod[] methods = programClass.directMethods;
      if (methods == null) {
        continue;
      }
      int position = indexOfDeserializeLambdaMethod(methods);
      if (position < 0) {
        continue;
      }
      DexEncodedMethod found = methods[position];
      assert found.accessFlags.isStatic();
      assert found.accessFlags.isPrivate();
      assert found.accessFlags.isSynthetic();

      int total = methods.length;
      DexEncodedMethod[] remaining = new DexEncodedMethod[total - 1];
      System.arraycopy(methods, 0, remaining, 0, position);
      System.arraycopy(methods, position + 1, remaining, position, total - position - 1);
      programClass.directMethods = remaining;
    }
  }

  // Index of the first $deserializeLambda$ method in `methods`, or -1 if there is none.
  private int indexOfDeserializeLambdaMethod(DexEncodedMethod[] methods) {
    for (int index = 0; index < methods.length; index++) {
      DexMethod candidate = methods[index].method;
      if (candidate.name == deserializeLambdaMethodName
          && candidate.proto == deserializeLambdaMethodProto) {
        return index;
      }
    }
    return -1;
  }

  /**
   * Asks the target of every known lambda class to make itself accessible from the
   * generated code. Per the original design this may relax method access or create
   * an accessor method.
   */
  public void adjustAccessibility() {
    // NOTE(review): iterates without locking `knownLambdaClasses`; presumably only
    // called once concurrent desugaring has finished — confirm with callers.
    for (LambdaClass known : knownLambdaClasses.values()) {
      known.target.ensureAccessibility();
    }
  }

  /**
   * Synthesizes a program class for every known lambda class, runs the converter's
   * optimization on it, and registers it with {@code builder}, flagged for the main
   * dex list when any creating context was in the main dex list.
   */
  public void synthesizeLambdaClasses(Builder builder) {
    // NOTE(review): same unlocked iteration as adjustAccessibility().
    for (LambdaClass known : knownLambdaClasses.values()) {
      DexProgramClass generated = known.synthesizeLambdaClass();
      converter.optimizeSynthesizedClass(generated);
      builder.addSynthesizedClass(generated, known.addToMainDexList.get());
    }
  }

  // Infers the lambda descriptor for `callSite`, consulting and filling the
  // call-site cache. Returns the descriptor or LambdaDescriptor.MATCH_FAILED.
  //
  // The cache is checked once before inference and again (atomically) on insert,
  // so inference runs outside the lock. When two threads race on the same call
  // site one result is discarded, which is cheaper than serializing the whole
  // method.
  private LambdaDescriptor inferLambdaDescriptor(DexCallSite callSite) {
    LambdaDescriptor cached = getKnown(knownCallSites, callSite);
    if (cached != null) {
      return cached;
    }
    return putIfAbsent(knownCallSites, callSite, LambdaDescriptor.infer(this, callSite));
  }

  // Whether `type` is listed in the application's main dex list.
  private boolean isInMainDexList(DexType type) {
    return converter.application.mainDexList.contains(type);
  }

  // Returns the canonical lambda class for `descriptor` as used from `accessedFrom`,
  // creating it on first use. If `accessedFrom` is in the main dex list, the lambda
  // class is flagged to be added to it as well.
  private LambdaClass getOrCreateLambdaClass(LambdaDescriptor descriptor, DexType accessedFrom) {
    DexType classType = LambdaClass.createLambdaClassType(this, accessedFrom, descriptor);
    // Look up first so the common case needs only a short synchronized section.
    LambdaClass canonical = getKnown(knownLambdaClasses, classType);
    if (canonical == null) {
      LambdaClass candidate = new LambdaClass(this, accessedFrom, classType, descriptor);
      canonical = putIfAbsent(knownLambdaClasses, classType, candidate);
    }
    if (isInMainDexList(accessedFrom)) {
      canonical.addToMainDexList.set(true);
    }
    return canonical;
  }

  // Reads `key` from `map` while holding the map's monitor.
  private <K, V> V getKnown(Map<K, V> map, K key) {
    synchronized (map) {
      return map.get(key);
    }
  }

  // Under the map's monitor: stores `value` unless a non-null value is already
  // mapped to `key`; returns whichever value ends up mapped.
  private <K, V> V putIfAbsent(Map<K, V> map, K key, V value) {
    synchronized (map) {
      V previous = map.putIfAbsent(key, value);
      return previous != null ? previous : value;
    }
  }

  // Replaces the invoke-custom just returned by `instructions.next()` with code
  // that yields an instance of `lambdaClass`.
  private void patchInstruction(LambdaClass lambdaClass, IRCode code,
      ListIterator<BasicBlock> blocks, InstructionListIterator instructions) {
    assert lambdaClass != null;
    assert instructions != null;
    assert instructions.peekPrevious().isInvokeCustom();

    // Step back onto the invoke-custom instruction.
    InvokeCustom invoke = instructions.previous().asInvokeCustom();

    // Keep the invoke-custom's out-value so all of its users now see the lambda
    // instance; if the out-value was optimized away, make a fresh one.
    Value existing = invoke.outValue();
    Value instance = existing == null ? code.createValue(MoveType.OBJECT) : existing;

    if (lambdaClass.isStateless()) {
      // Stateless lambda: read the singleton INSTANCE field instead. One throwing
      // instruction replaces another, so catch handlers need no adjustment.
      instructions.replaceCurrentInstruction(
          new StaticGet(MemberType.OBJECT, instance, lambdaClass.instanceField));
    } else {
      replaceWithNewInstance(lambdaClass, code, blocks, instructions, invoke, instance);
    }
  }

  // Stateful lambda: captured values must reach the constructor, so a new object is
  // always allocated:
  //
  //    before:
  //      Invoke-Custom rResult <- { rArg0, rArg1, ... }; call site: ...
  //
  //    after:
  //      NewInstance   rResult <-  LambdaClass
  //      Invoke-Direct { rResult, rArg0, rArg1, ... }; method: void LambdaClass.<init>(...)
  //
  // If the block has catch handlers, it is split between the two instructions and the
  // handlers are copied to the new block.
  private void replaceWithNewInstance(LambdaClass lambdaClass, IRCode code,
      ListIterator<BasicBlock> blocks, InstructionListIterator instructions,
      InvokeCustom invoke, Value instance) {
    NewInstance allocation = new NewInstance(lambdaClass.type, instance);
    instructions.replaceCurrentInstruction(allocation);

    List<Value> initArgs = new ArrayList<>();
    initArgs.add(instance);
    initArgs.addAll(invoke.arguments()); // Captured values, if any.
    InvokeDirect init = new InvokeDirect(
        lambdaClass.constructor, null /* no return value */, initArgs);
    instructions.add(init);

    if (init.getBlock().hasCatchHandlers()) {
      // Back up so the iterator sits between NewInstance and the constructor call,
      // split there, and give the new block the same catch handlers.
      instructions.previous();
      assert instructions.peekNext().isInvokeDirect();
      BasicBlock allocationBlock = allocation.getBlock();
      BasicBlock initBlock = instructions.split(code, blocks);
      assert !instructions.hasNext();
      initBlock.copyCatchHandlers(code, blocks, allocationBlock);
    }
  }
}