diff options
Diffstat (limited to 'agent/src/main/java/com/code_intelligence/jazzer/instrumentor/HookMethodVisitor.kt')
-rw-r--r-- | agent/src/main/java/com/code_intelligence/jazzer/instrumentor/HookMethodVisitor.kt | 386 |
1 files changed, 386 insertions, 0 deletions
diff --git a/agent/src/main/java/com/code_intelligence/jazzer/instrumentor/HookMethodVisitor.kt b/agent/src/main/java/com/code_intelligence/jazzer/instrumentor/HookMethodVisitor.kt new file mode 100644 index 00000000..7c23c703 --- /dev/null +++ b/agent/src/main/java/com/code_intelligence/jazzer/instrumentor/HookMethodVisitor.kt @@ -0,0 +1,386 @@ +// Copyright 2021 Code Intelligence GmbH +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.code_intelligence.jazzer.instrumentor + +import com.code_intelligence.jazzer.api.HookType +import com.code_intelligence.jazzer.third_party.objectweb.asm.Handle +import com.code_intelligence.jazzer.third_party.objectweb.asm.MethodVisitor +import com.code_intelligence.jazzer.third_party.objectweb.asm.Opcodes +import com.code_intelligence.jazzer.third_party.objectweb.asm.Type +import com.code_intelligence.jazzer.third_party.objectweb.asm.commons.LocalVariablesSorter + +internal fun makeHookMethodVisitor( + access: Int, + descriptor: String?, + methodVisitor: MethodVisitor?, + hooks: Iterable<Hook>, + java6Mode: Boolean, + random: DeterministicRandom, +): MethodVisitor { + return HookMethodVisitor(access, descriptor, methodVisitor, hooks, java6Mode, random).lvs +} + +private class HookMethodVisitor( + access: Int, + descriptor: String?, + methodVisitor: MethodVisitor?, + hooks: Iterable<Hook>, + private val java6Mode: Boolean, + private val random: DeterministicRandom, +) : MethodVisitor(Instrumentor.ASM_API_VERSION, methodVisitor) { + + val lvs = object : LocalVariablesSorter(Instrumentor.ASM_API_VERSION, access, descriptor, this) { + override fun updateNewLocals(newLocals: Array<Any>) { + // The local variables involved in calling hooks do not need to outlive the current + // basic block and should thus not appear in stack map frames. By requesting the + // LocalVariableSorter to fill their entries in stack map frames with TOP, they will + // be treated like an unused local variable slot. + newLocals.fill(Opcodes.TOP) + } + } + + private val hooks = hooks.associateBy { hook -> + var hookKey = "${hook.hookType}#${hook.targetInternalClassName}#${hook.targetMethodName}" + if (hook.targetMethodDescriptor != null) + hookKey += "#${hook.targetMethodDescriptor}" + hookKey + } + + override fun visitMethodInsn( + opcode: Int, + owner: String, + methodName: String, + methodDescriptor: String, + isInterface: Boolean, + ) { + if (!isMethodInvocationOp(opcode)) { + mv.visitMethodInsn(opcode, owner, methodName, methodDescriptor, isInterface) + return + } + handleMethodInsn(HookType.BEFORE, opcode, owner, methodName, methodDescriptor, isInterface) + } + + /** + * Emits the bytecode for a method call instruction for the next applicable hook type in order (BEFORE, REPLACE, + * AFTER). Since the instrumented code is indistinguishable from an uninstrumented call instruction, it can be + * safely nested. Combining REPLACE hooks with other hooks is however not supported as these hooks already subsume + * the functionality of BEFORE and AFTER hooks. + */ + private fun visitNextHookTypeOrCall( + hookType: HookType, + appliedHook: Boolean, + opcode: Int, + owner: String, + methodName: String, + methodDescriptor: String, + isInterface: Boolean, + ) = when (hookType) { + HookType.BEFORE -> { + val nextHookType = if (appliedHook) { + // After a BEFORE hook has been applied, we can safely apply an AFTER hook by replacing the actual + // call instruction with the full bytecode injected for the AFTER hook. + HookType.AFTER + } else { + // If no BEFORE hook is registered, look for a REPLACE hook next. + HookType.REPLACE + } + handleMethodInsn(nextHookType, opcode, owner, methodName, methodDescriptor, isInterface) + } + HookType.REPLACE -> { + // REPLACE hooks can't (and don't need to) be mixed with other hooks. We only cycle through them if we + // couldn't find a matching REPLACE hook, in which case we try an AFTER hook next. + require(!appliedHook) + handleMethodInsn(HookType.AFTER, opcode, owner, methodName, methodDescriptor, isInterface) + } + // An AFTER hook is always the last in the chain. Whether a hook has been applied or not, always emit the + // actual call instruction. + HookType.AFTER -> mv.visitMethodInsn(opcode, owner, methodName, methodDescriptor, isInterface) + } + + fun handleMethodInsn( + hookType: HookType, + opcode: Int, + owner: String, + methodName: String, + methodDescriptor: String, + isInterface: Boolean, + ) { + val hook = findMatchingHook(hookType, owner, methodName, methodDescriptor) + if (hook == null) { + visitNextHookTypeOrCall(hookType, false, opcode, owner, methodName, methodDescriptor, isInterface) + return + } + + // The hookId is used to identify a call site. + val hookId = random.nextInt() + + val paramDescriptors = extractParameterTypeDescriptors(methodDescriptor) + val localObjArr = storeMethodArguments(paramDescriptors) + // If the method we're hooking is not static there is now a reference to + // the object the method was invoked on at the top of the stack. + // If the method is static, that object is missing. We make up for it by pushing a null ref. + if (opcode == Opcodes.INVOKESTATIC) { + mv.visitInsn(Opcodes.ACONST_NULL) + } + + // Save the owner object to a new local variable + val ownerDescriptor = "L$owner;" + val localOwnerObj = lvs.newLocal(Type.getType(ownerDescriptor)) + mv.visitVarInsn(Opcodes.ASTORE, localOwnerObj) // consume objectref + // We now removed all values for the original method call from the operand stack + // and saved them to local variables. + + // Start to build the arguments for the hook method. + if (methodName == "<init>") { + // Special case for constructors: + // We cannot create a MethodHandle for a constructor, so we push null instead. + mv.visitInsn(Opcodes.ACONST_NULL) // push nullref + // Only pass the this object if it has been initialized by the time the hook is invoked. + if (hook.hookType == HookType.AFTER) { + mv.visitVarInsn(Opcodes.ALOAD, localOwnerObj) + } else { + mv.visitInsn(Opcodes.ACONST_NULL) // push nullref + } + } else { + // Push a MethodHandle representing the hooked method. + val handleOpcode = when (opcode) { + Opcodes.INVOKEVIRTUAL -> Opcodes.H_INVOKEVIRTUAL + Opcodes.INVOKEINTERFACE -> Opcodes.H_INVOKEINTERFACE + Opcodes.INVOKESTATIC -> Opcodes.H_INVOKESTATIC + Opcodes.INVOKESPECIAL -> Opcodes.H_INVOKESPECIAL + else -> -1 + } + if (java6Mode) { + // MethodHandle constants (type 15) are not supported in Java 6 class files (major version 50). + mv.visitInsn(Opcodes.ACONST_NULL) // push nullref + } else { + mv.visitLdcInsn( + Handle( + handleOpcode, + owner, + methodName, + methodDescriptor, + isInterface + ) + ) // push MethodHandle + } + // Stack layout: ... | MethodHandle (objectref) + // Push the owner object again + mv.visitVarInsn(Opcodes.ALOAD, localOwnerObj) + } + // Stack layout: ... | MethodHandle (objectref) | owner (objectref) + // Push a reference to our object array with the saved arguments + mv.visitVarInsn(Opcodes.ALOAD, localObjArr) + // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref) + // Push the hook id + mv.visitLdcInsn(hookId) + // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref) | hookId (int) + // How we proceed depends on the type of hook we want to implement + when (hook.hookType) { + HookType.BEFORE -> { + // Call the hook method + mv.visitMethodInsn( + Opcodes.INVOKESTATIC, + hook.hookInternalClassName, + hook.hookMethodName, + hook.hookMethodDescriptor, + false + ) + // Stack layout: ... + // Push the values for the original method call onto the stack again + if (opcode != Opcodes.INVOKESTATIC) { + mv.visitVarInsn(Opcodes.ALOAD, localOwnerObj) // push owner object + } + loadMethodArguments(paramDescriptors, localObjArr) // push all method arguments + // Stack layout: ... | [owner (objectref)] | arg1 (primitive/objectref) | arg2 (primitive/objectref) | ... + // Call the original method or the next hook in order. + visitNextHookTypeOrCall(hookType, true, opcode, owner, methodName, methodDescriptor, isInterface) + } + HookType.REPLACE -> { + // Call the hook method + mv.visitMethodInsn( + Opcodes.INVOKESTATIC, + hook.hookInternalClassName, + hook.hookMethodName, + hook.hookMethodDescriptor, + false + ) + // Stack layout: ... | [return value (primitive/objectref)] + // Check if we need to process the return value + val returnTypeDescriptor = extractReturnTypeDescriptor(methodDescriptor) + if (returnTypeDescriptor != "V") { + val hookMethodReturnType = extractReturnTypeDescriptor(hook.hookMethodDescriptor) + // if the hook method's return type is primitive we don't need to unwrap or cast it + if (!isPrimitiveType(hookMethodReturnType)) { + // Check if the returned object type is different than the one that should be returned + // If a primitive should be returned we check it's wrapper type + val expectedType = getWrapperTypeDescriptor(returnTypeDescriptor) + if (expectedType != hookMethodReturnType) { + // Cast object + mv.visitTypeInsn(Opcodes.CHECKCAST, extractInternalClassName(expectedType)) + } + // Check if we need to unwrap the returned object + unwrapTypeIfPrimitive(returnTypeDescriptor) + } + } + } + HookType.AFTER -> { + // Push the values for the original method call again onto the stack + if (opcode != Opcodes.INVOKESTATIC) { + mv.visitVarInsn(Opcodes.ALOAD, localOwnerObj) // push owner object + } + loadMethodArguments(paramDescriptors, localObjArr) // push all method arguments + // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref) | hookId (int) + // | [owner (objectref)] | arg1 (primitive/objectref) | arg2 (primitive/objectref) | ... + // Call the original method or the next hook in order + visitNextHookTypeOrCall(hookType, true, opcode, owner, methodName, methodDescriptor, isInterface) + val returnTypeDescriptor = extractReturnTypeDescriptor(methodDescriptor) + if (returnTypeDescriptor == "V") { + // If the method didn't return anything, we push a nullref as placeholder + mv.visitInsn(Opcodes.ACONST_NULL) // push nullref + } + // Wrap return value if it is a primitive type + wrapTypeIfPrimitive(returnTypeDescriptor) + // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref) | hookId (int) + // | return value (objectref) + // Store the result value in a local variable (but keep it on the stack) + val localReturnObj = lvs.newLocal(Type.getType(getWrapperTypeDescriptor(returnTypeDescriptor))) + mv.visitVarInsn(Opcodes.ASTORE, localReturnObj) // consume objectref + mv.visitVarInsn(Opcodes.ALOAD, localReturnObj) // push objectref + // Call the hook method + mv.visitMethodInsn( + Opcodes.INVOKESTATIC, + hook.hookInternalClassName, + hook.hookMethodName, + hook.hookMethodDescriptor, + false + ) + // Stack layout: ... + if (returnTypeDescriptor != "V") { + // Push the return value again + mv.visitVarInsn(Opcodes.ALOAD, localReturnObj) // push objectref + // Unwrap it, if it was a primitive value + unwrapTypeIfPrimitive(returnTypeDescriptor) + // Stack layout: ... | return value (primitive/objectref) + } + } + } + } + + private fun isMethodInvocationOp(opcode: Int) = opcode in listOf( + Opcodes.INVOKEVIRTUAL, + Opcodes.INVOKEINTERFACE, + Opcodes.INVOKESTATIC, + Opcodes.INVOKESPECIAL + ) + + private fun findMatchingHook(hookType: HookType, owner: String, name: String, descriptor: String): Hook? { + val withoutDescriptorKey = "$hookType#$owner#$name" + val withDescriptorKey = "$withoutDescriptorKey#$descriptor" + return hooks[withDescriptorKey] ?: hooks[withoutDescriptorKey] + } + + // Stores all arguments for a method call in a local object array. + // paramDescriptors: The type descriptors for all method arguments + private fun storeMethodArguments(paramDescriptors: List<String>): Int { + // Allocate a new Object[] for the methods parameters. + mv.visitIntInsn(Opcodes.SIPUSH, paramDescriptors.size) + mv.visitTypeInsn(Opcodes.ANEWARRAY, "java/lang/Object") + val localObjArr = lvs.newLocal(Type.getType("[Ljava/lang/Object;")) + mv.visitVarInsn(Opcodes.ASTORE, localObjArr) + + // Loop over all arguments in reverse order (because the last argument is on top). + for ((argIdx, argDescriptor) in paramDescriptors.withIndex().reversed()) { + // If the argument is a primitive type, wrap it in it's wrapper class + wrapTypeIfPrimitive(argDescriptor) + // Store the argument in our object array, for that we need to shape the stack first. + // Stack layout: ... | method argument (objectref) + mv.visitVarInsn(Opcodes.ALOAD, localObjArr) + // Stack layout: ... | method argument (objectref) | object array (arrayref) + mv.visitInsn(Opcodes.SWAP) + // Stack layout: ... | object array (arrayref) | method argument (objectref) + mv.visitIntInsn(Opcodes.SIPUSH, argIdx) + // Stack layout: ... | object array (arrayref) | method argument (objectref) | argument index (int) + mv.visitInsn(Opcodes.SWAP) + // Stack layout: ... | object array (arrayref) | argument index (int) | method argument (objectref) + mv.visitInsn(Opcodes.AASTORE) // consume all three: arrayref, index, value + // Stack layout: ... + // Continue with the remaining method arguments + } + + // Return a reference to the array with the parameters. + return localObjArr + } + + // Loads all arguments for a method call from a local object array. + // argTypeSigs: The type signatures for all method arguments + // localObjArr: Index of a local variable containing an object array where the arguments will be loaded from + private fun loadMethodArguments(paramDescriptors: List<String>, localObjArr: Int) { + // Loop over all arguments + for ((argIdx, argDescriptor) in paramDescriptors.withIndex()) { + // Push a reference to the object array on the stack + mv.visitVarInsn(Opcodes.ALOAD, localObjArr) + // Stack layout: ... | object array (arrayref) + // Push the index of the current argument on the stack + mv.visitIntInsn(Opcodes.SIPUSH, argIdx) + // Stack layout: ... | object array (arrayref) | argument index (int) + // Load the argument from the array + mv.visitInsn(Opcodes.AALOAD) + // Stack layout: ... | method argument (objectref) + // Cast object to it's original type (or it's wrapper object) + val wrapperTypeDescriptor = getWrapperTypeDescriptor(argDescriptor) + mv.visitTypeInsn(Opcodes.CHECKCAST, extractInternalClassName(wrapperTypeDescriptor)) + // If the argument is a supposed to be a primitive type, unwrap the wrapped type + unwrapTypeIfPrimitive(argDescriptor) + // Stack layout: ... | method argument (primitive/objectref) + // Continue with the remaining method arguments + } + } + + // Removes a primitive value from the top of the operand stack + // and pushes it enclosed in it's wrapper type (e.g. removes int, pushes Integer). + // This is done by calling .valueOf(...) on the wrapper class. + private fun wrapTypeIfPrimitive(unwrappedTypeDescriptor: String) { + if (!isPrimitiveType(unwrappedTypeDescriptor) || unwrappedTypeDescriptor == "V") return + val wrapperTypeDescriptor = getWrapperTypeDescriptor(unwrappedTypeDescriptor) + val wrapperType = extractInternalClassName(wrapperTypeDescriptor) + val valueOfDescriptor = "($unwrappedTypeDescriptor)$wrapperTypeDescriptor" + mv.visitMethodInsn(Opcodes.INVOKESTATIC, wrapperType, "valueOf", valueOfDescriptor, false) + } + + // Removes a wrapper object around a given primitive type from the top of the operand stack + // and pushes the primitive value it contains (e.g. removes Integer, pushes int). + // This is done by calling .intValue(...) / .charValue(...) / ... on the wrapper object. + private fun unwrapTypeIfPrimitive(primitiveTypeDescriptor: String) { + val (methodName, wrappedTypeDescriptor) = when (primitiveTypeDescriptor) { + "B" -> Pair("byteValue", "java/lang/Byte") + "C" -> Pair("charValue", "java/lang/Character") + "D" -> Pair("doubleValue", "java/lang/Double") + "F" -> Pair("floatValue", "java/lang/Float") + "I" -> Pair("intValue", "java/lang/Integer") + "J" -> Pair("longValue", "java/lang/Long") + "S" -> Pair("shortValue", "java/lang/Short") + "Z" -> Pair("booleanValue", "java/lang/Boolean") + else -> return + } + mv.visitMethodInsn( + Opcodes.INVOKEVIRTUAL, + wrappedTypeDescriptor, + methodName, + "()$primitiveTypeDescriptor", + false + ) + } +} |