diff options
Diffstat (limited to 'src/main/java/com/code_intelligence/jazzer/instrumentor')
13 files changed, 1876 insertions, 0 deletions
# diff --git a/src/main/java/com/code_intelligence/jazzer/instrumentor/BUILD.bazel b/src/main/java/com/code_intelligence/jazzer/instrumentor/BUILD.bazel new file mode 100644 index 00000000..bbb449a4 --- /dev/null +++ b/src/main/java/com/code_intelligence/jazzer/instrumentor/BUILD.bazel @@ -0,0 +1,41 @@
load("@io_bazel_rules_kotlin//kotlin:jvm.bzl", "kt_jvm_library")
load("//bazel:kotlin.bzl", "ktlint")

# Kotlin/JVM library target bundling the sources of the instrumentor package.
kt_jvm_library(
    name = "instrumentor",
    # Kotlin sources plus one Java source (StaticMethodStrategy.java).
    srcs = [
        "ClassInstrumentor.kt",
        "CoverageRecorder.kt",
        "DescriptorUtils.kt",
        "DeterministicRandom.kt",
        "EdgeCoverageInstrumentor.kt",
        "Hook.kt",
        "HookInstrumentor.kt",
        "HookMethodVisitor.kt",
        "Hooks.kt",
        "Instrumentor.kt",
        "StaticMethodStrategy.java",
        "TraceDataFlowInstrumentor.kt",
    ],
    # Only these packages may depend on this target.
    visibility = [
        "//src/jmh/java/com/code_intelligence/jazzer/instrumentor:__pkg__",
        "//src/main/java/com/code_intelligence/jazzer/agent:__pkg__",
        "//src/main/java/com/code_intelligence/jazzer/driver:__pkg__",
        "//src/test/java/com/code_intelligence/jazzer/instrumentor:__pkg__",
    ],
    # Project-internal targets first, then external repositories.
    deps = [
        "//src/main/java/com/code_intelligence/jazzer/api:hooks",
        "//src/main/java/com/code_intelligence/jazzer/runtime:jazzer_bootstrap_compile_only",
        "//src/main/java/com/code_intelligence/jazzer/utils",
        "//src/main/java/com/code_intelligence/jazzer/utils:class_name_globber",
        "//src/main/java/com/code_intelligence/jazzer/utils:log",
        "@com_github_classgraph_classgraph//:classgraph",
        "@com_github_jetbrains_kotlin//:kotlin-reflect",
        "@jazzer_jacoco//:jacoco_internal",
        "@org_ow2_asm_asm//jar",
        "@org_ow2_asm_asm_commons//jar",
        "@org_ow2_asm_asm_tree//jar",
    ],
)

# Invokes the ktlint macro loaded from //bazel:kotlin.bzl for this package.
ktlint()
# diff --git a/src/main/java/com/code_intelligence/jazzer/instrumentor/ClassInstrumentor.kt b/src/main/java/com/code_intelligence/jazzer/instrumentor/ClassInstrumentor.kt new file mode 100644 index 00000000..a93e29c7 --- /dev/null +++ b/src/main/java/com/code_intelligence/jazzer/instrumentor/ClassInstrumentor.kt @@ -0,0 +1,55 @@ +//
// Copyright 2021 Code Intelligence GmbH
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.code_intelligence.jazzer.instrumentor

import com.code_intelligence.jazzer.runtime.CoverageMap

/**
 * Returns the major version stored in the header of a class file.
 *
 * The value is the unsigned big-endian 16-bit integer formed by the bytes at offsets 6 and 7 of
 * [classfileBuffer].
 *
 * @throws IllegalArgumentException if [classfileBuffer] has fewer than 8 bytes and thus cannot
 * contain the version field.
 */
fun extractClassFileMajorVersion(classfileBuffer: ByteArray): Int {
    // Fail with a descriptive message instead of an opaque ArrayIndexOutOfBoundsException when
    // handed a truncated buffer.
    require(classfileBuffer.size >= 8) {
        "Class file buffer is too short to contain a major version (${classfileBuffer.size} bytes)"
    }
    // Mask with 0xff since Kotlin bytes are signed.
    val highByte = classfileBuffer[6].toInt() and 0xff
    val lowByte = classfileBuffer[7].toInt() and 0xff
    return (highByte shl 8) or lowByte
}

/**
 * Applies instrumentation passes to the bytecode of a single class.
 *
 * Every pass replaces [instrumentedBytecode] with its own output, so passes applied later operate
 * on the result of the earlier ones.
 *
 * @param internalClassName name of the class, handed to each instrumentor's `instrument` call
 * @param bytecode the original bytecode, used as the initial value of [instrumentedBytecode]
 */
class ClassInstrumentor(private val internalClassName: String, bytecode: ByteArray) {

    /** The bytecode produced by the passes applied so far. */
    var instrumentedBytecode = bytecode
        private set

    /**
     * Adds edge coverage instrumentation using [defaultEdgeCoverageStrategy] and
     * [defaultCoverageMap], handing out edge IDs starting at [initialEdgeId].
     *
     * @return the number of edge IDs consumed while instrumenting the class
     */
    fun coverage(initialEdgeId: Int): Int {
        val edgeCoverageInstrumentor = EdgeCoverageInstrumentor(
            defaultEdgeCoverageStrategy,
            defaultCoverageMap,
            initialEdgeId,
        )
        instrumentedBytecode = edgeCoverageInstrumentor.instrument(internalClassName, instrumentedBytecode)
        return edgeCoverageInstrumentor.numEdges
    }

    /** Runs a [TraceDataFlowInstrumentor] configured with [instrumentations] over the class. */
    fun traceDataFlow(instrumentations: Set<InstrumentationType>) {
        instrumentedBytecode =
            TraceDataFlowInstrumentor(instrumentations).instrument(internalClassName, instrumentedBytecode)
    }

    /**
     * Runs a [HookInstrumentor] for [hooks] over the class.
     *
     * `java6Mode` is enabled when the class file's major version is below 51.
     */
    fun hooks(hooks: Iterable<Hook>, classWithHooksEnabledField: String?) {
        instrumentedBytecode = HookInstrumentor(
            hooks,
            java6Mode = extractClassFileMajorVersion(instrumentedBytecode) < 51,
            classWithHooksEnabledField = classWithHooksEnabledField,
        ).instrument(internalClassName, instrumentedBytecode)
    }

    companion object {
        /** Edge coverage strategy used by [coverage]. */
        val defaultEdgeCoverageStrategy = StaticMethodStrategy()

        /** Coverage map class used by [coverage]. */
        val defaultCoverageMap = CoverageMap::class.java
    }
}
// diff --git a/src/main/java/com/code_intelligence/jazzer/instrumentor/CoverageRecorder.kt b/src/main/java/com/code_intelligence/jazzer/instrumentor/CoverageRecorder.kt new file mode 100644 index 00000000..56fb5725 --- /dev/null +++ b/src/main/java/com/code_intelligence/jazzer/instrumentor/CoverageRecorder.kt @@ -0,0 +1,252 @@ +// Copyright 2021 Code Intelligence GmbH +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
package com.code_intelligence.jazzer.instrumentor

import com.code_intelligence.jazzer.runtime.CoverageMap
import com.code_intelligence.jazzer.third_party.org.jacoco.core.analysis.CoverageBuilder
import com.code_intelligence.jazzer.third_party.org.jacoco.core.data.ExecutionData
import com.code_intelligence.jazzer.third_party.org.jacoco.core.data.ExecutionDataStore
import com.code_intelligence.jazzer.third_party.org.jacoco.core.data.ExecutionDataWriter
import com.code_intelligence.jazzer.third_party.org.jacoco.core.data.SessionInfo
import com.code_intelligence.jazzer.third_party.org.jacoco.core.internal.data.CRC64
import com.code_intelligence.jazzer.utils.ClassNameGlobber
import io.github.classgraph.ClassGraph
import java.io.File
import java.io.FileOutputStream
import java.io.OutputStream
import java.time.Instant
import java.util.UUID

/**
 * Bookkeeping for one instrumented class.
 *
 * [initialEdgeId] is the first edge ID assigned to the class and [nextEdgeId] is the first ID after
 * its range (it is stored as `firstId + numIds`), so the class owns the half-open range
 * `[initialEdgeId, nextEdgeId)`. [classId] is computed with `CRC64.classId` from [bytecode].
 */
private data class InstrumentedClassInfo(
    val classId: Long,
    val initialEdgeId: Int,
    val nextEdgeId: Int,
    val bytecode: ByteArray,
)

/**
 * Collects the information needed to turn covered edge IDs into human-readable and JaCoCo
 * coverage reports.
 */
object CoverageRecorder {
    /** Decides which classes found on the classpath are included in [analyzeCoverage]. */
    var classNameGlobber = ClassNameGlobber(emptyList(), emptyList())

    // Keyed by internal class name (as passed to recordInstrumentedClass).
    private val instrumentedClassInfo = mutableMapOf<String, InstrumentedClassInfo>()

    // Set when the first class is recorded; null means nothing has been instrumented yet.
    private var startTimestamp: Instant? = null

    // Edge IDs accumulated via updateCoveredIdsWithCoverageMap, merged into every JaCoCo analysis.
    private val additionalCoverage = mutableSetOf<Int>()

    /**
     * Registers an instrumented class that owns the [numIds] edge IDs starting at [firstId].
     *
     * The first call also records the current time as the start of the coverage session.
     */
    fun recordInstrumentedClass(internalClassName: String, bytecode: ByteArray, firstId: Int, numIds: Int) {
        if (startTimestamp == null) {
            startTimestamp = Instant.now()
        }
        instrumentedClassInfo[internalClassName] = InstrumentedClassInfo(
            CRC64.classId(bytecode),
            firstId,
            firstId + numIds,
            bytecode,
        )
    }

    /**
     * Manually records coverage IDs based on the current state of [CoverageMap].
     * Should be called after static initializers have run.
     */
    @JvmStatic
    fun updateCoveredIdsWithCoverageMap() {
        additionalCoverage.addAll(CoverageMap.getCoveredIds())
    }

    /**
     * [dumpCoverageReport] dumps a human-readable coverage report of files using any [coveredIds] to [dumpFileName].
     */
    @JvmStatic
    @JvmOverloads
    fun dumpCoverageReport(dumpFileName: String, coveredIds: IntArray = CoverageMap.getEverCoveredIds()) {
        File(dumpFileName).bufferedWriter().use { writer ->
            writer.write(computeFileCoverage(coveredIds))
        }
    }

    /**
     * Renders the textual report: branch coverage, line coverage, incompletely covered lines and
     * missed lines, each listed per source file.
     *
     * Returns "No classes were instrumented" whenever [analyzeCoverage] returns `null`; note that
     * [analyzeCoverage] also returns `null` when the analysis throws.
     */
    private fun computeFileCoverage(coveredIds: IntArray): String {
        fun Double.format(digits: Int) = "%.${digits}f".format(this)
        val coverage = analyzeCoverage(coveredIds.toSet()) ?: return "No classes were instrumented"
        return coverage.sourceFiles.joinToString(
            "\n",
            prefix = "Branch coverage:\n",
            postfix = "\n\n",
        ) { fileCoverage ->
            val counter = fileCoverage.branchCounter
            val percentage = 100 * counter.coveredRatio
            "${fileCoverage.name}: ${counter.coveredCount}/${counter.totalCount} (${percentage.format(2)}%)"
        } + coverage.sourceFiles.joinToString(
            "\n",
            prefix = "Line coverage:\n",
            postfix = "\n\n",
        ) { fileCoverage ->
            val counter = fileCoverage.lineCounter
            val percentage = 100 * counter.coveredRatio
            "${fileCoverage.name}: ${counter.coveredCount}/${counter.totalCount} (${percentage.format(2)}%)"
        } + coverage.sourceFiles.joinToString(
            "\n",
            prefix = "Incompletely covered lines:\n",
            postfix = "\n\n",
        ) { fileCoverage ->
            // A line is incompletely covered if some, but not all, of its instructions were hit.
            "${fileCoverage.name}: " + (fileCoverage.firstLine..fileCoverage.lastLine).filter {
                val instructions = fileCoverage.getLine(it).instructionCounter
                instructions.coveredCount in 1 until instructions.totalCount
            }.toString()
        } + coverage.sourceFiles.joinToString(
            "\n",
            prefix = "Missed lines:\n",
        ) { fileCoverage ->
            // A line is missed if it has instructions but none of them were hit.
            "${fileCoverage.name}: " + (fileCoverage.firstLine..fileCoverage.lastLine).filter {
                val instructions = fileCoverage.getLine(it).instructionCounter
                instructions.coveredCount == 0 && instructions.totalCount > 0
            }.toString()
        }
    }

    /**
     * [dumpJacocoCoverage] dumps the JaCoCo coverage of files using any [coveredIds] to [dumpFileName].
     * JaCoCo only exports coverage for files containing at least one coverage data point. The dump
     * can be used by the JaCoCo report command to create reports also including not covered files.
     */
    @JvmStatic
    @JvmOverloads
    fun dumpJacocoCoverage(dumpFileName: String, coveredIds: IntArray = CoverageMap.getEverCoveredIds()) {
        FileOutputStream(dumpFileName).use { outStream ->
            dumpJacocoCoverage(outStream, coveredIds)
        }
    }

    /**
     * [dumpJacocoCoverage] dumps the JaCoCo coverage of files using any [coveredIds] to [outStream].
     *
     * Writes nothing if no class has been recorded yet. The stream is not closed here.
     */
    @JvmStatic
    fun dumpJacocoCoverage(outStream: OutputStream, coveredIds: IntArray) {
        // Return if no class has been instrumented.
        val startTimestamp = startTimestamp ?: return

        // Update the list of covered IDs with the coverage information for the current run.
        updateCoveredIdsWithCoverageMap()

        val dumpTimestamp = Instant.now()
        val outWriter = ExecutionDataWriter(outStream)
        outWriter.visitSessionInfo(
            SessionInfo(UUID.randomUUID().toString(), startTimestamp.epochSecond, dumpTimestamp.epochSecond),
        )
        analyzeJacocoCoverage(coveredIds.toSet()).accept(outWriter)
    }

    /**
     * Build up a JaCoCo [ExecutionDataStore] based on [coveredIds] containing the internally gathered coverage information.
     */
    private fun analyzeJacocoCoverage(coveredIds: Set<Int>): ExecutionDataStore {
        val executionDataStore = ExecutionDataStore()
        // Set union, so the sorted array contains no duplicates.
        val sortedCoveredIds = (additionalCoverage + coveredIds).sorted().toIntArray()
        for ((internalClassName, info) in instrumentedClassInfo) {
            // Determine the subarray of coverage IDs in sortedCoveredIds that contains the IDs generated while
            // instrumenting the current class. Since the ID array is sorted, use binary search.
            var coveredIdsStart = sortedCoveredIds.binarySearch(info.initialEdgeId)
            if (coveredIdsStart < 0) {
                // Not found: a negative result encodes the insertion point as -(insertionPoint + 1).
                coveredIdsStart = -(coveredIdsStart + 1)
            }
            var coveredIdsEnd = sortedCoveredIds.binarySearch(info.nextEdgeId)
            if (coveredIdsEnd < 0) {
                coveredIdsEnd = -(coveredIdsEnd + 1)
            }
            if (coveredIdsStart == coveredIdsEnd) {
                // No coverage data for the class.
                continue
            }
            check(coveredIdsStart in 0 until coveredIdsEnd && coveredIdsEnd <= sortedCoveredIds.size) {
                "Invalid range [$coveredIdsStart, $coveredIdsEnd) with coveredIds.size=${sortedCoveredIds.size}"
            }
            // Generate a probes array for the current class only, i.e., mapping info.initialEdgeId to 0.
            val probes = BooleanArray(info.nextEdgeId - info.initialEdgeId)
            (coveredIdsStart until coveredIdsEnd).asSequence()
                .map {
                    val globalEdgeId = sortedCoveredIds[it]
                    globalEdgeId - info.initialEdgeId
                }
                .forEach { classLocalEdgeId ->
                    probes[classLocalEdgeId] = true
                }
            executionDataStore.visitClassExecution(ExecutionData(info.classId, internalClassName, probes))
        }
        return executionDataStore
    }

    /**
     * Create a [CoverageBuilder] containing all classes matching the include/exclude pattern and their coverage statistics.
     *
     * Returns `null` if any [Exception] is thrown during the analysis; the stack trace is printed
     * in that case.
     */
    fun analyzeCoverage(coveredIds: Set<Int>): CoverageBuilder? {
        return try {
            val coverage = CoverageBuilder()
            // Classes that were never instrumented are analyzed first, against an empty store.
            analyzeAllUncoveredClasses(coverage)
            val executionDataStore = analyzeJacocoCoverage(coveredIds)
            for ((internalClassName, info) in instrumentedClassInfo) {
                // A fresh instrumentor per class; the initial edge ID of 0 is only a placeholder here.
                EdgeCoverageInstrumentor(ClassInstrumentor.defaultEdgeCoverageStrategy, ClassInstrumentor.defaultCoverageMap, 0)
                    .analyze(
                        executionDataStore,
                        coverage,
                        info.bytecode,
                        internalClassName,
                    )
            }
            coverage
        } catch (e: Exception) {
            e.printStackTrace()
            null
        }
    }

    /**
     * Traverses the entire classpath and analyzes all uncovered classes that match the include/exclude pattern.
     * The returned [CoverageBuilder] will report coverage information for *all* classes on the classpath, not just
     * those that were loaded while the fuzzer ran.
     */
    private fun analyzeAllUncoveredClasses(coverage: CoverageBuilder): CoverageBuilder {
        // ClassGraph reports dotted class names, whereas instrumentedClassInfo is keyed by internal names.
        val coveredClassNames = instrumentedClassInfo
            .keys
            .asSequence()
            .map { it.replace('/', '.') }
            .toSet()
        ClassGraph()
            .enableClassInfo()
            .ignoreClassVisibility()
            .rejectPackages(
                // Always exclude Jazzer-internal packages (including ClassGraph itself) from coverage reports. Classes
                // from the Java standard library are never traversed.
                "com.code_intelligence.jazzer.*",
                "jaz",
            )
            .scan().use { result ->
                // ExecutionDataStore is used to look up existing coverage during analysis of the class files,
                // no entries are added during that. Passing in an empty store is fine for uncovered files.
                val emptyExecutionDataStore = ExecutionDataStore()
                result.allClasses
                    .asSequence()
                    .filter { classInfo -> classNameGlobber.includes(classInfo.name) }
                    .filterNot { classInfo -> classInfo.name in coveredClassNames }
                    .forEach { classInfo ->
                        classInfo.resource.use { resource ->
                            EdgeCoverageInstrumentor(ClassInstrumentor.defaultEdgeCoverageStrategy, ClassInstrumentor.defaultCoverageMap, 0).analyze(
                                emptyExecutionDataStore,
                                coverage,
                                resource.load(),
                                classInfo.name.replace('.', '/'),
                            )
                        }
                    }
            }
        return coverage
    }
}
// diff --git a/src/main/java/com/code_intelligence/jazzer/instrumentor/DescriptorUtils.kt b/src/main/java/com/code_intelligence/jazzer/instrumentor/DescriptorUtils.kt new file mode 100644 index 00000000..9d02c04f --- /dev/null +++ b/src/main/java/com/code_intelligence/jazzer/instrumentor/DescriptorUtils.kt @@ -0,0 +1,105 @@ +// Copyright 2021 Code Intelligence GmbH +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.code_intelligence.jazzer.instrumentor

import org.objectweb.asm.Type
import java.lang.reflect.Constructor
import java.lang.reflect.Executable
import java.lang.reflect.Method

// The single-character descriptors treated as primitive types (this includes 'V' for void).
private const val PRIMITIVE_TYPE_DESCRIPTOR_CHARS = "BCDFIJSVZ"

// Maps every primitive descriptor to the descriptor of its java.lang wrapper class.
private val WRAPPER_TYPE_DESCRIPTORS = mapOf(
    "B" to "Ljava/lang/Byte;",
    "C" to "Ljava/lang/Character;",
    "D" to "Ljava/lang/Double;",
    "F" to "Ljava/lang/Float;",
    "I" to "Ljava/lang/Integer;",
    "J" to "Ljava/lang/Long;",
    "S" to "Ljava/lang/Short;",
    "V" to "Ljava/lang/Void;",
    "Z" to "Ljava/lang/Boolean;",
)

/** The type descriptor of this class as computed by ASM's [Type.getDescriptor]. */
val Class<*>.descriptor: String
    get() = Type.getDescriptor(this)

/** The method or constructor descriptor of this executable as computed by ASM's [Type]. */
val Executable.descriptor: String
    get() = when (this) {
        is Method -> Type.getMethodDescriptor(this)
        else -> Type.getConstructorDescriptor(this as Constructor<*>?)
    }

/** Whether [typeDescriptor] is exactly one of the single-character primitive descriptors. */
internal fun isPrimitiveType(typeDescriptor: String): Boolean =
    typeDescriptor.length == 1 && typeDescriptor[0] in PRIMITIVE_TYPE_DESCRIPTOR_CHARS

private fun isPrimitiveType(typeDescriptor: Char) = typeDescriptor in PRIMITIVE_TYPE_DESCRIPTOR_CHARS

/**
 * Returns the descriptor of the wrapper class for a primitive [typeDescriptor]; any other
 * descriptor is returned unchanged.
 */
internal fun getWrapperTypeDescriptor(typeDescriptor: String): String =
    WRAPPER_TYPE_DESCRIPTORS[typeDescriptor] ?: typeDescriptor

// Removes the 'L' and ';' prefix/suffix from signatures to get the full class name.
// Note that array signatures '[Ljava/lang/String;' already have the correct form.
internal fun extractInternalClassName(typeDescriptor: String): String =
    typeDescriptor.removeSurrounding("L", ";")

/**
 * Splits the parameter part of [methodDescriptor] (the text between '(' and the first ')') into
 * the individual parameter type descriptors, in order.
 *
 * @throws IllegalArgumentException if the descriptor does not start with '(', contains no ')' or
 * contains a malformed parameter type
 */
internal fun extractParameterTypeDescriptors(methodDescriptor: String): List<String> {
    require(methodDescriptor.startsWith('(')) { "Method descriptor must start with '('" }
    val closingParenPos = methodDescriptor.indexOf(')')
    // Index 0 holds '(' at this point, so a present ')' is always at position 1 or later.
    require(closingParenPos >= 1) { "Method descriptor must contain ')'" }
    val descriptors = mutableListOf<String>()
    var unparsed = methodDescriptor.substring(1, closingParenPos)
    while (unparsed.isNotEmpty()) {
        val head = extractNextTypeDescriptor(unparsed)
        descriptors += head
        // head is always a prefix of unparsed, so cutting by length strips exactly that prefix.
        unparsed = unparsed.substring(head.length)
    }
    return descriptors
}

/**
 * Returns the type descriptor that follows the first ')' of [methodDescriptor].
 *
 * @throws IllegalArgumentException if the descriptor does not start with '(', contains no ')' or
 * has a malformed return type
 */
internal fun extractReturnTypeDescriptor(methodDescriptor: String): String {
    require(methodDescriptor.startsWith('(')) { "Method descriptor must start with '('" }
    val closingParenPos = methodDescriptor.indexOf(')')
    require(closingParenPos >= 0) { "Method descriptor must contain ')'" }
    return extractNextTypeDescriptor(methodDescriptor.substring(closingParenPos + 1))
}

// Returns the type descriptor that forms the prefix of input; anything following it is ignored.
private fun extractNextTypeDescriptor(input: String): String {
    require(input.isNotEmpty()) { "Type descriptor must not be empty" }
    // Skip over arbitrarily many '[' to support multi-dimensional arrays.
    val elementTypePos = input.indexOfFirst { it != '[' }
    require(elementTypePos >= 0) { "Array descriptor must contain type" }
    val elementTypeChar = input[elementTypePos]

    // Primitive type: the descriptor ends right at the type character.
    if (isPrimitiveType(elementTypeChar)) {
        return input.substring(0, elementTypePos + 1)
    }
    // Invalid type: neither primitive nor an object type.
    if (elementTypeChar != 'L') {
        throw IllegalArgumentException("Invalid type: $elementTypeChar")
    }
    // Object type: the descriptor runs up to and including the first ';'.
    val semicolonPos = input.indexOf(';')
    require(semicolonPos > 0) { "Class type indicated by L must end with ;" }
    return input.substring(0, semicolonPos + 1)
}
// diff --git a/src/main/java/com/code_intelligence/jazzer/instrumentor/DeterministicRandom.kt b/src/main/java/com/code_intelligence/jazzer/instrumentor/DeterministicRandom.kt new file mode 100644 index 00000000..d4210dc4 --- /dev/null +++ b/src/main/java/com/code_intelligence/jazzer/instrumentor/DeterministicRandom.kt @@ -0,0 +1,35 @@ +// Copyright 2021 Code Intelligence GmbH +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.code_intelligence.jazzer.instrumentor + +import java.security.MessageDigest +import java.security.SecureRandom + +// This RNG is resistant to collisions (even under XOR) but fully deterministic.
/**
 * Random number source whose output is determined solely by the given [contexts].
 *
 * The contexts are fed, in order, into a SHA-256 digest and the resulting hash is used to seed a
 * "SHA1PRNG" [SecureRandom] before any value is drawn from it.
 */
internal class DeterministicRandom(vararg contexts: String) {
    private val random: SecureRandom

    init {
        val generator = SecureRandom.getInstance("SHA1PRNG")
        val hasher = MessageDigest.getInstance("SHA-256")
        contexts.forEach { context -> hasher.update(context.toByteArray()) }
        // Seed explicitly before the first draw so the generator does not pick its own seed.
        generator.setSeed(hasher.digest())
        random = generator
    }

    /** Returns the next value as produced by [SecureRandom.nextInt] with the given [bound]. */
    fun nextInt(bound: Int) = random.nextInt(bound)

    /** Returns the next value as produced by [SecureRandom.nextInt]. */
    fun nextInt() = random.nextInt()
}
// diff --git a/src/main/java/com/code_intelligence/jazzer/instrumentor/EdgeCoverageInstrumentor.kt b/src/main/java/com/code_intelligence/jazzer/instrumentor/EdgeCoverageInstrumentor.kt new file mode 100644 index 00000000..975f3987 --- /dev/null +++ b/src/main/java/com/code_intelligence/jazzer/instrumentor/EdgeCoverageInstrumentor.kt @@ -0,0 +1,187 @@ +// Copyright 2021 Code Intelligence GmbH +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
package com.code_intelligence.jazzer.instrumentor

import com.code_intelligence.jazzer.third_party.org.jacoco.core.analysis.Analyzer
import com.code_intelligence.jazzer.third_party.org.jacoco.core.analysis.ICoverageVisitor
import com.code_intelligence.jazzer.third_party.org.jacoco.core.data.ExecutionDataStore
import com.code_intelligence.jazzer.third_party.org.jacoco.core.internal.flow.ClassProbesAdapter
import com.code_intelligence.jazzer.third_party.org.jacoco.core.internal.flow.ClassProbesVisitor
import com.code_intelligence.jazzer.third_party.org.jacoco.core.internal.flow.IClassProbesAdapterFactory
import com.code_intelligence.jazzer.third_party.org.jacoco.core.internal.instr.ClassInstrumenter
import com.code_intelligence.jazzer.third_party.org.jacoco.core.internal.instr.IProbeArrayStrategy
import com.code_intelligence.jazzer.third_party.org.jacoco.core.internal.instr.IProbeInserterFactory
import com.code_intelligence.jazzer.third_party.org.jacoco.core.internal.instr.InstrSupport
import com.code_intelligence.jazzer.third_party.org.jacoco.core.internal.instr.ProbeInserter
import org.objectweb.asm.ClassReader
import org.objectweb.asm.ClassVisitor
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.MethodVisitor
import java.lang.invoke.MethodHandle
import java.lang.invoke.MethodHandles.publicLookup
import java.lang.invoke.MethodType.methodType
import kotlin.math.max

/**
 * A particular way to instrument bytecode for edge coverage using a coverage map class available to
 * hold the collected coverage data at runtime.
 */
interface EdgeCoverageStrategy {

    /**
     * Inject bytecode instrumentation on a control flow edge with ID [edgeId], with access to the
     * local variable [variable] that is populated at the beginning of each method by the
     * instrumentation injected in [loadLocalVariable].
     */
    fun instrumentControlFlowEdge(
        mv: MethodVisitor,
        edgeId: Int,
        variable: Int,
        coverageMapInternalClassName: String,
    )

    /**
     * The maximal number of stack elements used by [instrumentControlFlowEdge].
     */
    val instrumentControlFlowEdgeStackSize: Int

    /**
     * The type of the local variable used by the instrumentation in the format used by
     * [MethodVisitor.visitFrame]'s `local` parameter, or `null` if the instrumentation does not use
     * one.
     * @see https://asm.ow2.io/javadoc/org/objectweb/asm/MethodVisitor.html#visitFrame(int,int,java.lang.Object%5B%5D,int,java.lang.Object%5B%5D)
     */
    val localVariableType: Any?

    /**
     * Inject bytecode that loads the coverage counters of the coverage map class described by
     * [coverageMapInternalClassName] into the local variable [variable].
     */
    fun loadLocalVariable(mv: MethodVisitor, variable: Int, coverageMapInternalClassName: String)

    /**
     * The maximal number of stack elements used by [loadLocalVariable].
     */
    val loadLocalVariableStackSize: Int
}

// An instance of EdgeCoverageInstrumentor should only be used to instrument a single class as it
// internally tracks the edge IDs, which have to be globally unique.
class EdgeCoverageInstrumentor(
    private val strategy: EdgeCoverageStrategy,
    /**
     * The class must have the following public static member
     *  - method enlargeIfNeeded(int nextEdgeId): Called before a new edge ID is emitted.
     */
    coverageMapClass: Class<*>,
    private val initialEdgeId: Int,
) : Instrumentor {
    // The ID that the next call to nextEdgeId() hands out.
    private var nextEdgeId = initialEdgeId

    private val coverageMapInternalClassName = coverageMapClass.name.replace('.', '/')

    // Handle to the static `void enlargeIfNeeded(int)` method of the coverage map class; the lookup
    // happens at construction time.
    private val enlargeIfNeeded: MethodHandle =
        publicLookup().findStatic(
            coverageMapClass,
            "enlargeIfNeeded",
            methodType(
                Void::class.javaPrimitiveType,
                Int::class.javaPrimitiveType,
            ),
        )

    /**
     * Returns a copy of [bytecode] with edge coverage probes inserted according to [strategy].
     *
     * [internalClassName] is not used by this implementation.
     */
    override fun instrument(internalClassName: String, bytecode: ByteArray): ByteArray {
        val reader = InstrSupport.classReaderFor(bytecode)
        val writer = ClassWriter(reader, 0)
        val version = InstrSupport.getMajorVersion(reader)
        val visitor = EdgeCoverageClassProbesAdapter(
            ClassInstrumenter(edgeCoverageProbeArrayStrategy, edgeCoverageProbeInserterFactory, writer),
            InstrSupport.needsFrames(version),
        )
        reader.accept(visitor, ClassReader.EXPAND_FRAMES)
        return writer.toByteArray()
    }

    /**
     * Runs a JaCoCo [Analyzer] over [bytecode], using the probe adapter of this instrumentor, and
     * reports the result to [coverageVisitor].
     */
    fun analyze(executionData: ExecutionDataStore, coverageVisitor: ICoverageVisitor, bytecode: ByteArray, internalClassName: String) {
        Analyzer(executionData, coverageVisitor, edgeCoverageClassProbesAdapterFactory).run {
            analyzeClass(bytecode, internalClassName)
        }
    }

    /** The number of edge IDs handed out by this instance so far. */
    val numEdges
        get() = nextEdgeId - initialEdgeId

    // Calls enlargeIfNeeded on the coverage map class with the ID about to be handed out, then
    // returns that ID and advances the counter.
    private fun nextEdgeId(): Int {
        enlargeIfNeeded.invokeExact(nextEdgeId)
        return nextEdgeId++
    }

    /**
     * A [ProbeInserter] that injects bytecode instrumentation at every control flow edge and
     * modifies the stack size and number of local variables accordingly.
     */
    private inner class EdgeCoverageProbeInserter(
        access: Int,
        name: String,
        desc: String,
        mv: MethodVisitor,
        arrayStrategy: IProbeArrayStrategy,
    ) : ProbeInserter(access, name, desc, mv, arrayStrategy) {
        override fun insertProbe(id: Int) {
            strategy.instrumentControlFlowEdge(mv, id, variable, coverageMapInternalClassName)
        }

        override fun visitMaxs(maxStack: Int, maxLocals: Int) {
            val newMaxStack = max(maxStack + strategy.instrumentControlFlowEdgeStackSize, strategy.loadLocalVariableStackSize)
            // Reserve one extra local only if the strategy actually uses a local variable.
            val newMaxLocals = maxLocals + if (strategy.localVariableType != null) 1 else 0
            mv.visitMaxs(newMaxStack, newMaxLocals)
        }

        override fun getLocalVariableType() = strategy.localVariableType
    }

    private val edgeCoverageProbeInserterFactory =
        IProbeInserterFactory { access, name, desc, mv, arrayStrategy ->
            EdgeCoverageProbeInserter(access, name, desc, mv, arrayStrategy)
        }

    // Probe adapter that draws its probe IDs from nextEdgeId() of the enclosing instrumentor.
    private inner class EdgeCoverageClassProbesAdapter(private val cpv: ClassProbesVisitor, trackFrames: Boolean) :
        ClassProbesAdapter(cpv, trackFrames) {
        override fun nextId(): Int = nextEdgeId()

        override fun visitEnd() {
            cpv.visitTotalProbeCount(numEdges)
            // Avoid calling super.visitEnd() as that invokes cpv.visitTotalProbeCount with an
            // incorrect value of `count`.
            cpv.visitEnd()
        }
    }

    private val edgeCoverageClassProbesAdapterFactory = IClassProbesAdapterFactory { probesVisitor, trackFrames ->
        EdgeCoverageClassProbesAdapter(probesVisitor, trackFrames)
    }

    // Delegates loading of the local variable to the strategy and adds no members to the class.
    private val edgeCoverageProbeArrayStrategy = object : IProbeArrayStrategy {
        override fun storeInstance(mv: MethodVisitor, clinit: Boolean, variable: Int): Int {
            strategy.loadLocalVariable(mv, variable, coverageMapInternalClassName)
            return strategy.loadLocalVariableStackSize
        }

        override fun addMembers(cv: ClassVisitor, probeCount: Int) {}
    }
}

/** Emits an instruction pushing the constant [value] via JaCoCo's `InstrSupport.push`. */
fun MethodVisitor.push(value: Int) {
    InstrSupport.push(this, value)
}
// diff --git a/src/main/java/com/code_intelligence/jazzer/instrumentor/Hook.kt b/src/main/java/com/code_intelligence/jazzer/instrumentor/Hook.kt new file mode 100644 index 00000000..077ab10e --- /dev/null +++ b/src/main/java/com/code_intelligence/jazzer/instrumentor/Hook.kt @@ -0,0 +1,132 @@ +// Copyright 2021 Code Intelligence GmbH +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
@file:Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")

package com.code_intelligence.jazzer.instrumentor

import com.code_intelligence.jazzer.api.HookType
import com.code_intelligence.jazzer.api.MethodHook
import java.lang.invoke.MethodHandle
import java.lang.reflect.Method
import java.lang.reflect.Modifier

/**
 * Describes a validated hook method together with the target method it applies to.
 *
 * Instances can only be obtained through [createAndVerifyHook], which checks the signature of the
 * hook method against the requirements of its [hookType].
 */
class Hook private constructor(
    private val targetClassName: String,
    val hookType: HookType,
    val targetMethodName: String,
    val targetMethodDescriptor: String?,
    val additionalClassesToHook: List<String>,
    val targetInternalClassName: String,
    private val targetReturnTypeDescriptor: String?,
    private val targetWrappedReturnTypeDescriptor: String?,
    private val hookClassName: String,
    val hookInternalClassName: String,
    val hookMethodName: String,
    val hookMethodDescriptor: String,
) {

    override fun toString(): String =
        "$hookType $targetClassName.$targetMethodName: $hookClassName.$hookMethodName $additionalClassesToHook"

    companion object {
        /**
         * Builds a [Hook] for [hookMethod] from the annotation data in [hookData], targeting the
         * class named [className].
         *
         * @throws IllegalArgumentException if the hook method's signature does not meet the
         * requirements of its hook type
         */
        fun createAndVerifyHook(hookMethod: Method, hookData: MethodHook, className: String): Hook {
            val hook = createHook(hookMethod, hookData, className)
            verify(hookMethod, hook)
            return hook
        }

        // Assembles the Hook without validating the hook method's signature.
        private fun createHook(hookMethod: Method, annotation: MethodHook, targetClassName: String): Hook {
            // A blank descriptor in the annotation means that no descriptor was specified.
            val methodDescriptor = annotation.targetMethodDescriptor.takeIf { it.isNotBlank() }
            val returnTypeDescriptor = methodDescriptor?.let { extractReturnTypeDescriptor(it) }
            val declaringClassName: String = hookMethod.declaringClass.name
            return Hook(
                targetClassName = targetClassName,
                hookType = annotation.type,
                targetMethodName = annotation.targetMethod,
                targetMethodDescriptor = methodDescriptor,
                additionalClassesToHook = annotation.additionalClassesToHook.asList(),
                targetInternalClassName = targetClassName.replace('.', '/'),
                targetReturnTypeDescriptor = returnTypeDescriptor,
                targetWrappedReturnTypeDescriptor = returnTypeDescriptor?.let { getWrapperTypeDescriptor(it) },
                hookClassName = declaringClassName,
                hookInternalClassName = declaringClassName.replace('.', '/'),
                hookMethodName = hookMethod.name,
                hookMethodDescriptor = hookMethod.descriptor,
            )
        }

        // Runs all signature checks in a fixed order; the first violated check throws.
        private fun verify(hookMethod: Method, potentialHook: Hook) {
            verifyModifiers(hookMethod, potentialHook)
            verifyParameterCount(hookMethod, potentialHook)
            val parameterTypes = hookMethod.parameterTypes
            verifyLeadingParameterTypes(parameterTypes, potentialHook)
            verifyReturnType(hookMethod, potentialHook)
            if (potentialHook.hookType == HookType.AFTER) {
                verifyAfterHookResultParameter(parameterTypes[4], potentialHook)
            }
        }

        // Verify the hook method's modifiers (public static).
        private fun verifyModifiers(hookMethod: Method, hook: Hook) {
            val modifiers = hookMethod.modifiers
            require(Modifier.isPublic(modifiers)) { "$hook: hook method must be public" }
            require(Modifier.isStatic(modifiers)) { "$hook: hook method must be static" }
        }

        // Verify the hook method's parameter count.
        private fun verifyParameterCount(hookMethod: Method, hook: Hook) {
            val numParameters = hookMethod.parameters.size
            when (hook.hookType) {
                HookType.BEFORE, HookType.REPLACE ->
                    require(numParameters == 4) { "$hook: incorrect number of parameters (expected 4)" }
                HookType.AFTER ->
                    require(numParameters == 5) { "$hook: incorrect number of parameters (expected 5)" }
            }
        }

        // Verify the types of the four parameters shared by all hook types.
        private fun verifyLeadingParameterTypes(parameterTypes: Array<Class<*>>, hook: Hook) {
            val (handleType, receiverType, argumentsType, hookIdType) = parameterTypes
            require(handleType == MethodHandle::class.java) { "$hook: first parameter must have type MethodHandle" }
            require(receiverType == Object::class.java || receiverType.name == hook.targetClassName) {
                "$hook: second parameter must have type Object or ${hook.targetClassName}"
            }
            require(argumentsType == Array<Object>::class.java) { "$hook: third parameter must have type Object[]" }
            require(hookIdType == Int::class.javaPrimitiveType) { "$hook: fourth parameter must have type int" }
        }

        // Verify the hook method's return type if possible.
        private fun verifyReturnType(hookMethod: Method, hook: Hook) {
            val returnType = hookMethod.returnType
            when (hook.hookType) {
                HookType.BEFORE, HookType.AFTER -> require(returnType == Void.TYPE) {
                    "$hook: return type must be void"
                }
                HookType.REPLACE -> {
                    // Without a target descriptor the expected return type is unknown.
                    val targetReturnType = hook.targetReturnTypeDescriptor ?: return
                    when {
                        hook.targetMethodName == "<init>" ->
                            require(returnType.name == hook.targetClassName) {
                                "$hook: return type must be ${hook.targetClassName} to match target constructor"
                            }
                        targetReturnType == "V" ->
                            require(returnType.descriptor == "V") { "$hook: return type must be void" }
                        else -> {
                            val acceptedDescriptors = listOf(
                                java.lang.Object::class.java.descriptor,
                                targetReturnType,
                                hook.targetWrappedReturnTypeDescriptor,
                            )
                            require(returnType.descriptor in acceptedDescriptors) {
                                "$hook: return type must have type Object or match the descriptors $targetReturnType or ${hook.targetWrappedReturnTypeDescriptor}"
                            }
                        }
                    }
                }
            }
        }

        // AfterMethodHook only: Verify the type of the last parameter if known. Even if not
        // known, it must not be a primitive value.
        private fun verifyAfterHookResultParameter(resultParameterType: Class<*>, hook: Hook) {
            if (hook.targetReturnTypeDescriptor == null) {
                require(!resultParameterType.isPrimitive) {
                    "$hook: fifth parameter must not be a primitive type, use a boxed type instead"
                }
                return
            }
            val matchesTarget = resultParameterType == java.lang.Object::class.java ||
                resultParameterType.descriptor == hook.targetWrappedReturnTypeDescriptor
            require(matchesTarget) {
                "$hook: fifth parameter must have type Object or match the descriptor ${hook.targetWrappedReturnTypeDescriptor}"
            }
        }
    }
}
// diff --git a/src/main/java/com/code_intelligence/jazzer/instrumentor/HookInstrumentor.kt b/src/main/java/com/code_intelligence/jazzer/instrumentor/HookInstrumentor.kt new file mode 100644 index 00000000..3c0d97c9 --- /dev/null +++ b/src/main/java/com/code_intelligence/jazzer/instrumentor/HookInstrumentor.kt @@ -0,0 +1,63 @@ +// Copyright 2021 Code Intelligence GmbH +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
package com.code_intelligence.jazzer.instrumentor

import org.objectweb.asm.ClassReader
import org.objectweb.asm.ClassVisitor
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.MethodVisitor

/**
 * Instrumentor that rewrites calls to hooked methods in a class.
 *
 * @param hooks the hooks to apply
 * @param java6Mode forwarded unchanged to [makeHookMethodVisitor]
 * @param classWithHooksEnabledField forwarded unchanged to [makeHookMethodVisitor]; may be null
 */
internal class HookInstrumentor(
    private val hooks: Iterable<Hook>,
    private val java6Mode: Boolean,
    private val classWithHooksEnabledField: String?,
) : Instrumentor {

    /**
     * Returns the bytecode of the given class with a hook method visitor applied to every method
     * accepted by [shouldInstrument]; all other methods are copied unchanged.
     */
    override fun instrument(internalClassName: String, bytecode: ByteArray): ByteArray {
        val reader = ClassReader(bytecode)
        val writer = ClassWriter(reader, ClassWriter.COMPUTE_MAXS)
        // The random source is seeded per class and only needed for the duration of this call, so
        // it is kept local instead of in a mutable field shared between invocations.
        val random = DeterministicRandom("hook", reader.className)
        val interceptor = object : ClassVisitor(Instrumentor.ASM_API_VERSION, writer) {
            override fun visitMethod(
                access: Int,
                name: String?,
                descriptor: String?,
                signature: String?,
                exceptions: Array<String>?,
            ): MethodVisitor? {
                val mv = cv.visitMethod(access, name, descriptor, signature, exceptions) ?: return null
                if (!shouldInstrument(access)) {
                    return mv
                }
                return makeHookMethodVisitor(
                    internalClassName,
                    access,
                    name,
                    descriptor,
                    mv,
                    hooks,
                    java6Mode,
                    random,
                    classWithHooksEnabledField,
                )
            }
        }
        reader.accept(interceptor, ClassReader.EXPAND_FRAMES)
        return writer.toByteArray()
    }
}
package com.code_intelligence.jazzer.instrumentor

import com.code_intelligence.jazzer.api.HookType
import org.objectweb.asm.Handle
import org.objectweb.asm.Label
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes
import org.objectweb.asm.Type
import org.objectweb.asm.commons.AnalyzerAdapter
import org.objectweb.asm.commons.LocalVariablesSorter
import java.util.concurrent.atomic.AtomicBoolean

/**
 * Creates the visitor that applies [hooks] to the method invocations of a single method.
 *
 * The returned visitor is the [LocalVariablesSorter] that delegates to the actual hook visitor,
 * so that locals allocated for hook calls are accounted for.
 */
internal fun makeHookMethodVisitor(
    owner: String,
    access: Int,
    name: String?,
    descriptor: String?,
    methodVisitor: MethodVisitor?,
    hooks: Iterable<Hook>,
    java6Mode: Boolean,
    random: DeterministicRandom,
    classWithHooksEnabledField: String?,
): MethodVisitor {
    val hookVisitor = HookMethodVisitor(
        owner,
        access,
        name,
        descriptor,
        methodVisitor,
        hooks,
        java6Mode,
        random,
        classWithHooksEnabledField,
    )
    return hookVisitor.lvs
}

private class HookMethodVisitor(
    owner: String,
    access: Int,
    val name: String?,
    descriptor: String?,
    methodVisitor: MethodVisitor?,
    hooks: Iterable<Hook>,
    private val java6Mode: Boolean,
    private val random: DeterministicRandom,
    private val classWithHooksEnabledField: String?,
) : MethodVisitor(
    Instrumentor.ASM_API_VERSION,
    // The conditional hook logic inserts a conditional jump around every hooked call and thus has
    // to emit stack map frames. AnalyzerAdapter tracks the frame at every instruction and is
    // therefore placed *behind* this visitor: that way the frame can be queried both before and
    // after the original call has been emitted. With the usual order (AnalyzerAdapter delegating
    // to this visitor) only the frame before the call would be available in visitMethodInsn.
    //
    // Class files from before Java 7 neither contain nor require stack map frames, so no
    // AnalyzerAdapter is needed in java6Mode.
    if (classWithHooksEnabledField != null && !java6Mode) {
        AnalyzerAdapter(
            owner,
            access,
            name,
            descriptor,
            methodVisitor,
        )
    } else {
        methodVisitor
    },
) {

    val lvs = object : LocalVariablesSorter(Instrumentor.ASM_API_VERSION, access, descriptor, this) {
        override fun updateNewLocals(newLocals: Array<Any>) {
            // Locals used for calling hooks never have to outlive the current basic block and
            // should therefore not show up in stack map frames. Filling their entries with TOP
            // makes the LocalVariablesSorter treat them like unused local variable slots.
            newLocals.fill(Opcodes.TOP)
        }
    }

    // Hooks indexed by "<type>#<internal class name>#<method name>[#<descriptor>]".
    private val hooks = hooks.groupBy { hook ->
        val baseKey = "${hook.hookType}#${hook.targetInternalClassName}#${hook.targetMethodName}"
        if (hook.targetMethodDescriptor != null) {
            "$baseKey#${hook.targetMethodDescriptor}"
        } else {
            baseKey
        }
    }

    override fun visitMethodInsn(
        opcode: Int,
        owner: String,
        methodName: String,
        methodDescriptor: String,
        isInterface: Boolean,
    ) {
        if (isMethodInvocationOp(opcode)) {
            handleMethodInsn(opcode, owner, methodName, methodDescriptor, isInterface)
        } else {
            mv.visitMethodInsn(opcode, owner, methodName, methodDescriptor, isInterface)
        }
    }

    // Converts a frame description from the form used by the JVM and AnalyzerAdapter, in which
    // every LONG and DOUBLE entry is followed by an extra TOP entry, into the form expected by
    // visitFrame, which does not contain that extra entry.
    private fun dropImplicitTop(stack: Collection<Any>?): Array<Any>? {
        if (stack == null) {
            return null
        }
        val kept = ArrayList<Any>(stack.size)
        var predecessor: Any? = null
        for (entry in stack) {
            val isImplicitTop =
                entry == Opcodes.TOP && (predecessor == Opcodes.DOUBLE || predecessor == Opcodes.LONG)
            if (!isImplicitTop) {
                kept.add(entry)
            }
            predecessor = entry
        }
        return kept.toTypedArray()
    }

    // Snapshots the current (locals, stack) frame of the given AnalyzerAdapter, or returns null
    // if there is none.
    private fun storeFrame(aa: AnalyzerAdapter?): Pair<Array<Any>?, Array<Any>?>? {
        if (aa == null) {
            return null
        }
        return Pair(dropImplicitTop(aa.locals), dropImplicitTop(aa.stack))
    }

    // Emits a full stack map frame for a snapshot taken with storeFrame; does nothing for null.
    private fun emitFrame(frame: Pair<Array<Any>?, Array<Any>?>?) {
        val (locals, stack) = frame ?: return
        mv.visitFrame(
            Opcodes.F_NEW,
            locals?.size ?: 0,
            locals,
            stack?.size ?: 0,
            stack,
        )
    }

    // Maps an invocation opcode to the corresponding method handle kind (-1 if there is none).
    private fun methodHandleTag(opcode: Int): Int = when (opcode) {
        Opcodes.INVOKEVIRTUAL -> Opcodes.H_INVOKEVIRTUAL
        Opcodes.INVOKEINTERFACE -> Opcodes.H_INVOKEINTERFACE
        Opcodes.INVOKESTATIC -> Opcodes.H_INVOKESTATIC
        Opcodes.INVOKESPECIAL -> Opcodes.H_INVOKESPECIAL
        else -> -1
    }

    /**
     * Emits the given method invocation, wrapped in calls to all hooks registered for the invoked
     * method. If no hook matches, the invocation is emitted unchanged.
     */
    fun handleMethodInsn(
        opcode: Int,
        owner: String,
        methodName: String,
        methodDescriptor: String,
        isInterface: Boolean,
    ) {
        fun emitOriginalCall() {
            mv.visitMethodInsn(opcode, owner, methodName, methodDescriptor, isInterface)
        }

        val matchingHooks = findMatchingHooks(owner, methodName, methodDescriptor)
        if (matchingHooks.isEmpty()) {
            emitOriginalCall()
            return
        }

        val skipHooksLabel = Label()
        val applyHooksLabel = Label()
        val useConditionalHooks = classWithHooksEnabledField != null
        var postCallFrame: Pair<Array<Any>?, Array<Any>?>? = null
        if (useConditionalHooks) {
            val preCallFrame = storeFrame(mv as? AnalyzerAdapter)
            // Unless the hooksEnabled field is set, perform the plain call and jump over all hook
            // invocations.
            mv.visitFieldInsn(
                Opcodes.GETSTATIC,
                classWithHooksEnabledField,
                "hooksEnabled",
                "Z",
            )
            mv.visitJumpInsn(Opcodes.IFNE, applyHooksLabel)
            emitOriginalCall()
            postCallFrame = storeFrame(mv as? AnalyzerAdapter)
            mv.visitJumpInsn(Opcodes.GOTO, skipHooksLabel)
            // This label follows an unconditional jump and is itself a jump target, so it needs a
            // stack map frame.
            mv.visitLabel(applyHooksLabel)
            emitFrame(preCallFrame)
            // None of the instructions emitted below carries a frame of its own, hence no NOP is
            // required here to avoid two frames for the same instruction.
        }

        val paramDescriptors = extractParameterTypeDescriptors(methodDescriptor)
        val localObjArr = storeMethodArguments(paramDescriptors)
        // With the arguments gone, the receiver of a non-static call is on top of the stack.
        // Static calls have no receiver, so a null reference takes its place.
        if (opcode == Opcodes.INVOKESTATIC) {
            mv.visitInsn(Opcodes.ACONST_NULL)
        }

        // Move the receiver (or its null placeholder) into a fresh local.
        val localOwnerObj = lvs.newLocal(Type.getType("L$owner;"))
        mv.visitVarInsn(Opcodes.ASTORE, localOwnerObj) // consume objectref
        // At this point every operand of the original call has been moved from the operand stack
        // into local variables.

        val returnTypeDescriptor = extractReturnTypeDescriptor(methodDescriptor)
        // Local that holds the (boxed) return value for AFTER hooks.
        val localReturnObj = lvs.newLocal(Type.getType(getWrapperTypeDescriptor(returnTypeDescriptor)))

        // Pushes receiver and arguments back onto the stack and performs the original call.
        fun emitReloadedOriginalCall() {
            if (opcode != Opcodes.INVOKESTATIC) {
                mv.visitVarInsn(Opcodes.ALOAD, localOwnerObj) // push owner object
            }
            loadMethodArguments(paramDescriptors, localObjArr) // push all method arguments
            // Stack layout: ... | [owner (objectref)] | arg1 (primitive/objectref) | arg2 (primitive/objectref) | ...
            emitOriginalCall()
        }

        fun emitHookCall(hook: Hook) {
            mv.visitMethodInsn(
                Opcodes.INVOKESTATIC,
                hook.hookInternalClassName,
                hook.hookMethodName,
                hook.hookMethodDescriptor,
                false,
            )
        }

        for ((index, hook) in matchingHooks.withIndex()) {
            // The hookId is used to identify a call site.
            val hookId = random.nextInt()

            // First two hook arguments: the method handle and the receiver.
            if (methodName == "<init>") {
                // A constructor is invoked on an uninitialized object that is still on the stack.
                // A REPLACE hook supplies the object itself, so the uninitialized one is dropped.
                if (hook.hookType == HookType.REPLACE) {
                    mv.visitInsn(Opcodes.POP)
                }
                // There is no MethodHandle for a constructor, null is passed instead.
                mv.visitInsn(Opcodes.ACONST_NULL) // push nullref
                // The object is only passed to the hook if it has been initialized by the time the
                // hook runs, i.e. for AFTER hooks.
                if (hook.hookType == HookType.AFTER) {
                    mv.visitVarInsn(Opcodes.ALOAD, localOwnerObj)
                } else {
                    mv.visitInsn(Opcodes.ACONST_NULL) // push nullref
                }
            } else {
                if (java6Mode) {
                    // MethodHandle constants (type 15) are not supported in Java 6 class files (major version 50).
                    mv.visitInsn(Opcodes.ACONST_NULL) // push nullref
                } else {
                    mv.visitLdcInsn(
                        Handle(
                            methodHandleTag(opcode),
                            owner,
                            methodName,
                            methodDescriptor,
                            isInterface,
                        ),
                    ) // push MethodHandle
                }
                // Stack layout: ... | MethodHandle (objectref)
                mv.visitVarInsn(Opcodes.ALOAD, localOwnerObj)
            }
            // Stack layout: ... | MethodHandle (objectref) | owner (objectref)
            mv.visitVarInsn(Opcodes.ALOAD, localObjArr)
            // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref)
            mv.visitLdcInsn(hookId)
            // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref) | hookId (int)

            when (hook.hookType) {
                HookType.BEFORE -> {
                    emitHookCall(hook)
                    // Stack layout: ...
                    // If this was the last hook, there is no AFTER hook that would perform the
                    // original call, so it is performed here.
                    if (index == matchingHooks.lastIndex) {
                        emitReloadedOriginalCall()
                    }
                }

                HookType.REPLACE -> {
                    emitHookCall(hook)
                    // Stack layout: ... | [return value (primitive/objectref)]
                    if (returnTypeDescriptor != "V") {
                        val hookMethodReturnType = extractReturnTypeDescriptor(hook.hookMethodDescriptor)
                        // A primitive returned by the hook needs neither a cast nor unboxing.
                        if (!isPrimitiveType(hookMethodReturnType)) {
                            // For primitive target return types the hook result is compared
                            // against the corresponding wrapper type.
                            val expectedType = getWrapperTypeDescriptor(returnTypeDescriptor)
                            if (expectedType != hookMethodReturnType) {
                                mv.visitTypeInsn(Opcodes.CHECKCAST, extractInternalClassName(expectedType))
                            }
                            unwrapTypeIfPrimitive(returnTypeDescriptor)
                        }
                    }
                }

                HookType.AFTER -> {
                    // The original call is performed once, right before the first AFTER hook.
                    val isFirstAfterHook = index == 0 || matchingHooks[index - 1].hookType != HookType.AFTER
                    if (isFirstAfterHook) {
                        // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref) | hookId (int)
                        emitReloadedOriginalCall()
                        if (returnTypeDescriptor == "V") {
                            // A void method leaves nothing behind; use a null reference as placeholder.
                            mv.visitInsn(Opcodes.ACONST_NULL) // push nullref
                        }
                        wrapTypeIfPrimitive(returnTypeDescriptor)
                        mv.visitVarInsn(Opcodes.ASTORE, localReturnObj) // consume objectref
                    }
                    // Fifth hook argument: the (boxed) return value, reloaded from its local.
                    mv.visitVarInsn(Opcodes.ALOAD, localReturnObj) // push objectref
                    // Stack layout: ... | MethodHandle (objectref) | owner (objectref) | object array (arrayref) | hookId (int)
                    //                   | return value (objectref)
                    emitHookCall(hook)
                    // Stack layout: ...
                    // After the last hook, restore the return value of a non-void method.
                    if (index == matchingHooks.size - 1 && returnTypeDescriptor != "V") {
                        mv.visitVarInsn(Opcodes.ALOAD, localReturnObj) // push objectref
                        unwrapTypeIfPrimitive(returnTypeDescriptor)
                        // Stack layout: ... | return value (primitive/objectref)
                    }
                }
            }
        }

        if (useConditionalHooks) {
            // Jump target, thus needs a stack map frame.
            mv.visitLabel(skipHooksLabel)
            emitFrame(postCallFrame)
            // The following visitor calls are not under our control, but two frames must never be
            // emitted for the same instruction. The NOP guarantees that.
            mv.visitInsn(Opcodes.NOP)
        }
    }

    private fun isMethodInvocationOp(opcode: Int) = when (opcode) {
        Opcodes.INVOKEVIRTUAL,
        Opcodes.INVOKEINTERFACE,
        Opcodes.INVOKESTATIC,
        Opcodes.INVOKESPECIAL,
        -> true
        else -> false
    }

    // Collects the hooks registered for the given method, both those registered for this exact
    // descriptor and those registered without a descriptor.
    private fun findMatchingHooks(owner: String, name: String, descriptor: String): List<Hook> {
        val candidates = HookType.values().flatMap { hookType ->
            val withoutDescriptorKey = "$hookType#$owner#$name"
            val withDescriptorKey = "$withoutDescriptorKey#$descriptor"
            hooks[withDescriptorKey].orEmpty() + hooks[withoutDescriptorKey].orEmpty()
        }.sortedBy { it.hookType }
        val result = candidates
        val replaceHookCount = result.count { it.hookType == HookType.REPLACE }
        check(
            replaceHookCount == 0 ||
                (replaceHookCount == 1 && result.size == 1),
        ) {
            "For a given method, You can either have a single REPLACE hook or BEFORE/AFTER hooks. Found:\n $result"
        }

        return result
            .filterNot { isReplaceHookInJava6mode(it) }
            .sortedByDescending { it.toString() }
    }

    // Returns true for REPLACE hooks in java6Mode and prints a warning the first time this
    // happens.
    private fun isReplaceHookInJava6mode(hook: Hook): Boolean {
        if (!(java6Mode && hook.hookType == HookType.REPLACE)) {
            return false
        }
        if (showUnsupportedHookWarning.getAndSet(false)) {
            println(
                """WARN: Some hooks could not be applied to class files built for Java 7 or lower.
                  |WARN: Ensure that the fuzz target and its dependencies are compiled with
                  |WARN: -target 8 or higher to identify as many bugs as possible.
                """.trimMargin(),
            )
        }
        return true
    }

    // Moves all arguments of a method call from the operand stack into a new Object[] held in a
    // local variable and returns the index of that local.
    // paramDescriptors: The type descriptors for all method arguments
    private fun storeMethodArguments(paramDescriptors: List<String>): Int {
        // Allocate the Object[] that receives the arguments.
        mv.visitIntInsn(Opcodes.SIPUSH, paramDescriptors.size)
        mv.visitTypeInsn(Opcodes.ANEWARRAY, "java/lang/Object")
        val argsLocal = lvs.newLocal(Type.getType("[Ljava/lang/Object;"))
        mv.visitVarInsn(Opcodes.ASTORE, argsLocal)

        // The last argument is on top of the stack, so the arguments are consumed back to front.
        for (argIdx in paramDescriptors.indices.reversed()) {
            // Box primitive arguments.
            wrapTypeIfPrimitive(paramDescriptors[argIdx])
            // Stack layout: ... | method argument (objectref)
            mv.visitVarInsn(Opcodes.ALOAD, argsLocal)
            // Stack layout: ... | method argument (objectref) | object array (arrayref)
            mv.visitInsn(Opcodes.SWAP)
            // Stack layout: ... | object array (arrayref) | method argument (objectref)
            mv.visitIntInsn(Opcodes.SIPUSH, argIdx)
            // Stack layout: ... | object array (arrayref) | method argument (objectref) | argument index (int)
            mv.visitInsn(Opcodes.SWAP)
            // Stack layout: ... | object array (arrayref) | argument index (int) | method argument (objectref)
            mv.visitInsn(Opcodes.AASTORE) // consume all three: arrayref, index, value
            // Stack layout: ...
        }

        return argsLocal
    }

    // Pushes all arguments of a method call from a local object array back onto the operand stack.
    // paramDescriptors: The type descriptors for all method arguments
    // localObjArr: Index of a local variable containing an object array where the arguments will be loaded from
    private fun loadMethodArguments(paramDescriptors: List<String>, localObjArr: Int) {
        paramDescriptors.forEachIndexed { argIdx, argDescriptor ->
            mv.visitVarInsn(Opcodes.ALOAD, localObjArr)
            // Stack layout: ... | object array (arrayref)
            mv.visitIntInsn(Opcodes.SIPUSH, argIdx)
            // Stack layout: ... | object array (arrayref) | argument index (int)
            mv.visitInsn(Opcodes.AALOAD)
            // Stack layout: ... | method argument (objectref)
            // Cast back to the argument's original type (or its wrapper type) ...
            val boxedDescriptor = getWrapperTypeDescriptor(argDescriptor)
            mv.visitTypeInsn(Opcodes.CHECKCAST, extractInternalClassName(boxedDescriptor))
            // ... and unbox it if the method expects a primitive.
            unwrapTypeIfPrimitive(argDescriptor)
            // Stack layout: ... | method argument (primitive/objectref)
        }
    }

    // Replaces a primitive value on top of the operand stack with its boxed counterpart (e.g.
    // int -> Integer) by calling valueOf(...) on the wrapper class. Does nothing for
    // non-primitive types and for void.
    private fun wrapTypeIfPrimitive(unwrappedTypeDescriptor: String) {
        val needsBoxing = isPrimitiveType(unwrappedTypeDescriptor) && unwrappedTypeDescriptor != "V"
        if (!needsBoxing) return
        val wrapperTypeDescriptor = getWrapperTypeDescriptor(unwrappedTypeDescriptor)
        mv.visitMethodInsn(
            Opcodes.INVOKESTATIC,
            extractInternalClassName(wrapperTypeDescriptor),
            "valueOf",
            "($unwrappedTypeDescriptor)$wrapperTypeDescriptor",
            false,
        )
    }

    // Replaces a wrapper object on top of the operand stack with the primitive value it contains
    // (e.g. Integer -> int) by calling intValue() / charValue() / ... on it. Does nothing if the
    // given descriptor is not one of the eight primitive value types.
    private fun unwrapTypeIfPrimitive(primitiveTypeDescriptor: String) {
        val (methodName, wrappedTypeDescriptor) = UNBOXING_METHODS[primitiveTypeDescriptor] ?: return
        mv.visitMethodInsn(
            Opcodes.INVOKEVIRTUAL,
            wrappedTypeDescriptor,
            methodName,
            "()$primitiveTypeDescriptor",
            false,
        )
    }

    companion object {
        // Ensures that the warning about unsupported hooks is printed at most once.
        private val showUnsupportedHookWarning = AtomicBoolean(true)

        // Primitive descriptor -> (unboxing method name, internal name of the wrapper class).
        private val UNBOXING_METHODS: Map<String, Pair<String, String>> = mapOf(
            "B" to Pair("byteValue", "java/lang/Byte"),
            "C" to Pair("charValue", "java/lang/Character"),
            "D" to Pair("doubleValue", "java/lang/Double"),
            "F" to Pair("floatValue", "java/lang/Float"),
            "I" to Pair("intValue", "java/lang/Integer"),
            "J" to Pair("longValue", "java/lang/Long"),
            "S" to Pair("shortValue", "java/lang/Short"),
            "Z" to Pair("booleanValue", "java/lang/Boolean"),
        )
    }
}
package com.code_intelligence.jazzer.instrumentor

import com.code_intelligence.jazzer.api.MethodHook
import com.code_intelligence.jazzer.api.MethodHooks
import com.code_intelligence.jazzer.utils.ClassNameGlobber
import com.code_intelligence.jazzer.utils.Log
import io.github.classgraph.ClassGraph
import io.github.classgraph.ScanResult
import java.lang.instrument.Instrumentation
import java.lang.reflect.Method
import java.util.jar.JarFile

/**
 * The hooks loaded from a set of hook classes.
 *
 * @property hooks all hooks found in the hook classes
 * @property hookClasses the classes the hooks were loaded from
 * @property additionalHookClassNameGlobber globber built from the hooks' additional classes to
 *   hook and the excluded hook class names
 */
data class Hooks(
    val hooks: List<Hook>,
    val hookClasses: Set<Class<*>>,
    val additionalHookClassNameGlobber: ClassNameGlobber,
) {

    companion object {

        /**
         * Appends every jar file that contains one of the given hook classes to the bootstrap
         * class loader search of [instrumentation]. Hook classes that cannot be found as a
         * resource or that are not loaded from a jar are ignored.
         */
        fun appendHooksToBootstrapClassLoaderSearch(instrumentation: Instrumentation, hookClassNames: Set<String>) {
            hookClassNames
                .mapNotNull { findContainingJarPath(it) }
                .toSet()
                .map { JarFile(it) }
                .forEach { instrumentation.appendToBootstrapClassLoaderSearch(it) }
        }

        // Returns the file system path of the jar containing the given class, or null if the
        // class is not available as a resource or is not contained in a jar.
        private fun findContainingJarPath(hookClassName: String): String? {
            val hookClassFilePath = "/${hookClassName.replace('.', '/')}.class"
            val hookClassFile = Companion::class.java.getResource(hookClassFilePath) ?: return null
            if ("jar" != hookClassFile.protocol) {
                return null
            }
            // hookClassFile.file looks as follows:
            // file:/tmp/ExampleFuzzerHooks_deploy.jar!/com/example/ExampleFuzzerHooks.class
            val jarLocation = hookClassFile.file.takeWhile { it != '!' }
            // URL.getFile() does not decode percent-escapes, so a jar in a directory such as
            // "my dir" shows up as "my%20dir" here. Going through URI/File yields the real path.
            // If the location is not a well-formed file URI, fall back to stripping the scheme.
            return try {
                java.io.File(java.net.URI(jarLocation)).path
            } catch (e: java.net.URISyntaxException) {
                jarLocation.removePrefix("file:")
            } catch (e: IllegalArgumentException) {
                jarLocation.removePrefix("file:")
            }
        }

        /**
         * Loads the hooks of each given set of hook classes and returns one [Hooks] instance per
         * set, in the order given.
         *
         * @param excludeHookClassNames passed as the second argument to the [ClassNameGlobber] of
         *   every returned [Hooks]
         */
        fun loadHooks(excludeHookClassNames: List<String>, vararg hookClassNames: Set<String>): List<Hooks> {
            return ClassGraph()
                .enableClassInfo()
                .enableSystemJarsAndModules()
                .rejectPackages("jaz.*", "com.code_intelligence.jazzer.*")
                .scan()
                .use { scanResult ->
                    // Capture scanResult in HooksLoader field to not pass it through
                    // all internal hook loading methods.
                    val loader = HooksLoader(scanResult, excludeHookClassNames)
                    hookClassNames.map(loader::load)
                }
        }

        private class HooksLoader(private val scanResult: ScanResult, val excludeHookClassNames: List<String>) {

            // Loads all hooks defined in the given hook classes.
            fun load(hookClassNames: Set<String>): Hooks {
                val hooksWithHookClasses = hookClassNames.flatMap(::loadHooks)
                val hooks = hooksWithHookClasses.map { it.first }
                val hookClasses = hooksWithHookClasses.map { it.second }.toSet()
                val additionalHookClassNameGlobber = ClassNameGlobber(
                    hooks.flatMap(Hook::additionalClassesToHook),
                    excludeHookClassNames,
                )
                return Hooks(hooks, hookClasses, additionalHookClassNameGlobber)
            }

            // Loads the hooks of a single hook class given by name. A class that cannot be found
            // is logged and contributes no hooks.
            private fun loadHooks(hookClassName: String): List<Pair<Hook, Class<*>>> {
                return try {
                    // We let the static initializers of hook classes execute so that hooks can run
                    // code before the fuzz target class has been loaded (e.g., register themselves
                    // for the onFuzzTargetReady callback).
                    val hookClass =
                        Class.forName(hookClassName, true, Companion::class.java.classLoader)
                    loadHooks(hookClass).also {
                        Log.info("Loaded ${it.size} hooks from $hookClassName")
                    }.map {
                        it to hookClass
                    }
                } catch (e: ClassNotFoundException) {
                    Log.warn("Failed to load hooks from $hookClassName", e)
                    emptyList()
                }
            }

            // Collects the hooks declared via @MethodHook / @MethodHooks on the methods of the
            // given class. Methods are visited in descriptor order.
            private fun loadHooks(hookClass: Class<*>): List<Hook> {
                val hooks = mutableListOf<Hook>()
                for (method in hookClass.methods.sortedBy { it.descriptor }) {
                    method.getAnnotation(MethodHook::class.java)?.let {
                        hooks.addAll(verifyAndGetHooks(method, it))
                    }
                    method.getAnnotation(MethodHooks::class.java)?.let {
                        it.value.forEach { hookAnnotation ->
                            hooks.addAll(verifyAndGetHooks(method, hookAnnotation))
                        }
                    }
                }
                return hooks
            }

            // Creates one verified hook per class the annotation's target class expands to.
            private fun verifyAndGetHooks(hookMethod: Method, hookData: MethodHook): List<Hook> {
                return lookupClassesToHook(hookData.targetClassName)
                    .map { className ->
                        Hook.createAndVerifyHook(hookMethod, hookData, className)
                    }
            }

            // Returns the sorted names of the target class itself plus, for interfaces, all
            // implementing classes and, for abstract classes, all subclasses known to the scan.
            private fun lookupClassesToHook(annotationTargetClassName: String): List<String> {
                // Allowing arbitrary exterior whitespace in the target class name allows for an easy workaround
                // for mangled hooks due to shading applied to hooks.
                val targetClassName = annotationTargetClassName.trim()
                val targetClassInfo = scanResult.getClassInfo(targetClassName) ?: return listOf(targetClassName)
                val additionalTargetClasses = when {
                    targetClassInfo.isInterface -> scanResult.getClassesImplementing(targetClassName)
                    targetClassInfo.isAbstract -> scanResult.getSubclasses(targetClassName)
                    else -> emptyList()
                }
                return (listOf(targetClassName) + additionalTargetClasses.map { it.name }).sorted()
            }
        }
    }
}
package com.code_intelligence.jazzer.instrumentor

import org.objectweb.asm.Opcodes
import org.objectweb.asm.tree.MethodNode

/** The kinds of instrumentation that can be selected. */
enum class InstrumentationType {
    CMP,
    COV,
    DIV,
    GEP,
    INDIR,
    NATIVE,
}

/** Common interface of all bytecode instrumentors in this package. */
internal interface Instrumentor {
    /** Returns the instrumented version of the given class file [bytecode]. */
    fun instrument(internalClassName: String, bytecode: ByteArray): ByteArray

    /** Returns true unless the access flags mark the method as abstract or native. */
    fun shouldInstrument(access: Int): Boolean {
        val flagsWithoutCode = Opcodes.ACC_ABSTRACT or Opcodes.ACC_NATIVE
        return (access and flagsWithoutCode) == 0
    }

    /**
     * Returns true if the method's access flags allow instrumentation and the method contains at
     * least one instruction.
     */
    fun shouldInstrument(method: MethodNode): Boolean {
        if (!shouldInstrument(method.access)) {
            return false
        }
        return method.instructions.size() > 0
    }

    companion object {
        const val ASM_API_VERSION = Opcodes.ASM9
    }
}
+ +package com.code_intelligence.jazzer.instrumentor; + +import com.code_intelligence.jazzer.third_party.org.jacoco.core.internal.instr.InstrSupport; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; + +public class StaticMethodStrategy implements EdgeCoverageStrategy { + @Override + public void instrumentControlFlowEdge( + MethodVisitor mv, int edgeId, int variable, String coverageMapInternalClassName) { + InstrSupport.push(mv, edgeId); + mv.visitMethodInsn( + Opcodes.INVOKESTATIC, coverageMapInternalClassName, "recordCoverage", "(I)V", false); + } + + @Override + public int getInstrumentControlFlowEdgeStackSize() { + return 1; + } + + @Override + public Object getLocalVariableType() { + return null; + } + + @Override + public void loadLocalVariable( + MethodVisitor mv, int variable, String coverageMapInternalClassName) {} + + @Override + public int getLoadLocalVariableStackSize() { + return 0; + } +} diff --git a/src/main/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentor.kt b/src/main/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentor.kt new file mode 100644 index 00000000..a46e72df --- /dev/null +++ b/src/main/java/com/code_intelligence/jazzer/instrumentor/TraceDataFlowInstrumentor.kt @@ -0,0 +1,268 @@ +// Copyright 2021 Code Intelligence GmbH +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
package com.code_intelligence.jazzer.instrumentor

import org.objectweb.asm.ClassReader
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.Opcodes
import org.objectweb.asm.tree.AbstractInsnNode
import org.objectweb.asm.tree.ClassNode
import org.objectweb.asm.tree.InsnList
import org.objectweb.asm.tree.InsnNode
import org.objectweb.asm.tree.IntInsnNode
import org.objectweb.asm.tree.LdcInsnNode
import org.objectweb.asm.tree.LookupSwitchInsnNode
import org.objectweb.asm.tree.MethodInsnNode
import org.objectweb.asm.tree.MethodNode
import org.objectweb.asm.tree.TableSwitchInsnNode

/**
 * Inserts calls to static callback methods in front of compare, switch, division and indexed load
 * instructions so that their operands are reported at runtime.
 *
 * @param types the instrumentation types to apply; only [InstrumentationType.CMP],
 *   [InstrumentationType.DIV] and [InstrumentationType.GEP] are consulted by this class
 * @param callbackInternalClassName internal name of the class providing the static `trace*` callbacks
 */
internal class TraceDataFlowInstrumentor(
    private val types: Set<InstrumentationType>,
    private val callbackInternalClassName: String = "com/code_intelligence/jazzer/runtime/TraceDataFlowNativeCallbacks",
) : Instrumentor {

    // Re-created for every instrumented class, seeded with the class name.
    private lateinit var random: DeterministicRandom

    /**
     * Parses [bytecode], instruments every method accepted by [shouldInstrument] and returns the
     * rewritten class with recomputed maximum stack sizes.
     */
    override fun instrument(internalClassName: String, bytecode: ByteArray): ByteArray {
        val node = ClassNode()
        val reader = ClassReader(bytecode)
        reader.accept(node, 0)
        random = DeterministicRandom("trace", node.name)
        for (method in node.methods) {
            if (shouldInstrument(method)) {
                addDataFlowInstrumentation(method)
            }
        }

        val writer = ClassWriter(ClassWriter.COMPUTE_MAXS)
        node.accept(writer)
        return writer.toByteArray()
    }

    /**
     * Walks a snapshot of the instructions of [method] and inserts the callback sequence matching
     * each relevant opcode, provided the corresponding [InstrumentationType] is enabled.
     */
    private fun addDataFlowInstrumentation(method: MethodNode) {
        // Iterate over a copy (toArray) since the instruction list is modified while looping.
        loop@ for (inst in method.instructions.toArray()) {
            when (inst.opcode) {
                Opcodes.LCMP -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    // The wrapper performs the comparison itself, so LCMP is replaced rather than preceded.
                    method.instructions.insertBefore(inst, longCmpInstrumentation())
                    method.instructions.remove(inst)
                }
                Opcodes.IF_ICMPEQ, Opcodes.IF_ICMPNE,
                Opcodes.IF_ICMPLT, Opcodes.IF_ICMPLE,
                Opcodes.IF_ICMPGT, Opcodes.IF_ICMPGE,
                -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    method.instructions.insertBefore(inst, intCmpInstrumentation())
                }
                Opcodes.IFEQ, Opcodes.IFNE,
                Opcodes.IFLT, Opcodes.IFLE,
                Opcodes.IFGT, Opcodes.IFGE,
                -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    // The IF* opcodes are often used to branch based on the result of a compare
                    // instruction for a type other than int. The operands of this compare will
                    // already be reported via the instrumentation above (for non-floating point
                    // numbers) and the follow-up compare does not provide a good signal as all
                    // operands will be in {-1, 0, 1}. Skip instrumentation for it.
                    if (inst.previous?.opcode in FLOATING_POINT_CMP_OPCODES ||
                        (inst.previous as? MethodInsnNode)?.name == "traceCmpLongWrapper"
                    ) {
                        continue@loop
                    }
                    method.instructions.insertBefore(inst, ifInstrumentation())
                }
                Opcodes.LOOKUPSWITCH, Opcodes.TABLESWITCH -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    // Mimic the exclusion logic for small label values in libFuzzer:
                    // https://github.com/llvm-mirror/compiler-rt/blob/69445f095c22aac2388f939bedebf224a6efcdaf/lib/fuzzer/FuzzerTracePC.cpp#L520
                    // Case values are reported to libFuzzer via an array of unsigned long values and thus need to be
                    // sorted by unsigned value.
                    val caseValues = when (inst) {
                        is LookupSwitchInsnNode -> {
                            if (inst.keys.isEmpty() || (0 <= inst.keys.first() && inst.keys.last() < 256)) {
                                continue@loop
                            }
                            inst.keys
                        }
                        is TableSwitchInsnNode -> {
                            if (0 <= inst.min && inst.max < 256) {
                                continue@loop
                            }
                            (inst.min..inst.max).filter { caseValue ->
                                val index = caseValue - inst.min
                                // Filter out "gap cases".
                                inst.labels[index].label != inst.dflt.label
                            }
                        }
                        // Not reached.
                        else -> continue@loop
                    }.sortedBy { it.toUInt() }.map { it.toLong() }.toLongArray()
                    method.instructions.insertBefore(inst, switchInstrumentation(caseValues))
                }
                Opcodes.IDIV -> {
                    if (InstrumentationType.DIV !in types) continue@loop
                    method.instructions.insertBefore(inst, intDivInstrumentation())
                }
                Opcodes.LDIV -> {
                    if (InstrumentationType.DIV !in types) continue@loop
                    method.instructions.insertBefore(inst, longDivInstrumentation())
                }
                Opcodes.AALOAD, Opcodes.BALOAD,
                Opcodes.CALOAD, Opcodes.DALOAD,
                Opcodes.FALOAD, Opcodes.IALOAD,
                Opcodes.LALOAD, Opcodes.SALOAD,
                -> {
                    if (InstrumentationType.GEP !in types) continue@loop
                    // Only indices pushed by a constant push instruction right before the load are reported.
                    if (!isConstantIntegerPushInsn(inst.previous)) continue@loop
                    method.instructions.insertBefore(inst, gepLoadInstrumentation())
                }
                Opcodes.INVOKEINTERFACE, Opcodes.INVOKESPECIAL, Opcodes.INVOKESTATIC, Opcodes.INVOKEVIRTUAL -> {
                    if (InstrumentationType.GEP !in types) continue@loop
                    if (!isGepLoadMethodInsn(inst as MethodInsnNode)) continue@loop
                    if (!isConstantIntegerPushInsn(inst.previous)) continue@loop
                    method.instructions.insertBefore(inst, gepLoadInstrumentation())
                }
            }
        }
    }

    /**
     * Appends a constant obtained from `random.nextInt(512)` that is passed to the callbacks as their
     * trailing int ("fake PC") argument.
     */
    private fun InsnList.pushFakePc() {
        add(LdcInsnNode(random.nextInt(512)))
    }

    /** Replacement for LCMP: calls `traceCmpLongWrapper` with the two longs and the fake PC. */
    private fun longCmpInstrumentation() = InsnList().apply {
        pushFakePc()
        // traceCmpLongWrapper returns the result of the comparison as duplicating two longs on the stack
        // is not possible without local variables.
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceCmpLongWrapper", "(JJI)I", false))
    }

    /** Duplicates both int operands of an IF_ICMP* instruction and reports them via `traceCmpInt`. */
    private fun intCmpInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP2))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceCmpInt", "(III)V", false))
    }

    /** Reports the operand of an IF* instruction as a comparison against 0 via `traceConstCmpInt`. */
    private fun ifInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP))
        // All if* instructions are compares to the constant 0.
        add(InsnNode(Opcodes.ICONST_0))
        // Put the constant first so that the call receives (0, value, pc).
        add(InsnNode(Opcodes.SWAP))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceConstCmpInt", "(III)V", false))
    }

    /** Duplicates the int on top of the stack (the divisor of IDIV) and reports it via `traceDivInt`. */
    private fun intDivInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceDivInt", "(II)V", false))
    }

    /** Duplicates the long on top of the stack (the divisor of LDIV) and reports it via `traceDivLong`. */
    private fun longDivInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP2))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceDivLong", "(JI)V", false))
    }

    /**
     * Builds the call to `traceSwitch` for a switch instruction.
     *
     * @param caseValues the case values to report, in the order they should appear in the array
     */
    private fun switchInstrumentation(caseValues: LongArray) = InsnList().apply {
        // duplicate {lookup,table}switch key for use as first function argument
        add(InsnNode(Opcodes.DUP))
        add(InsnNode(Opcodes.I2L))
        // Set up array with switch case values. The format libfuzzer expects is created here directly, i.e., the first
        // two entries are the number of cases and the bit size of values (always 32).
        add(IntInsnNode(Opcodes.SIPUSH, caseValues.size + 2))
        add(IntInsnNode(Opcodes.NEWARRAY, Opcodes.T_LONG))
        // Store number of cases
        add(InsnNode(Opcodes.DUP))
        add(IntInsnNode(Opcodes.SIPUSH, 0))
        add(LdcInsnNode(caseValues.size.toLong()))
        add(InsnNode(Opcodes.LASTORE))
        // Store bit size of keys
        add(InsnNode(Opcodes.DUP))
        add(IntInsnNode(Opcodes.SIPUSH, 1))
        add(LdcInsnNode(32.toLong()))
        add(InsnNode(Opcodes.LASTORE))
        // Store {lookup,table}switch case values
        for ((i, caseValue) in caseValues.withIndex()) {
            add(InsnNode(Opcodes.DUP))
            add(IntInsnNode(Opcodes.SIPUSH, 2 + i))
            add(LdcInsnNode(caseValue))
            add(InsnNode(Opcodes.LASTORE))
        }
        pushFakePc()
        // call the native callback function
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceSwitch", "(J[JI)V", false))
    }

    /**
     * Returns true if [node] represents an instruction that possibly pushes a valid, non-zero, constant array index
     * onto the stack.
     */
    private fun isConstantIntegerPushInsn(node: AbstractInsnNode?) = node?.opcode in CONSTANT_INTEGER_PUSH_OPCODES

    /**
     * Returns true if [node] represents a call to a method that performs an indexed lookup into an array-like
     * structure.
     */
    private fun isGepLoadMethodInsn(node: MethodInsnNode): Boolean {
        // Only methods taking exactly one int argument are candidates.
        if (!node.desc.startsWith("(I)")) return false
        val returnType = node.desc.removePrefix("(I)")
        return MethodInfo(node.owner, node.name, returnType) in GEP_LOAD_METHODS
    }

    /** Duplicates the int index on top of the stack, widens it to long and reports it via `traceGep`. */
    private fun gepLoadInstrumentation() = InsnList().apply {
        // Duplicate the index and convert to long.
        add(InsnNode(Opcodes.DUP))
        add(InsnNode(Opcodes.I2L))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceGep", "(JI)V", false))
    }

    companion object {
        // Floating-point compare opcodes. An IF* instruction directly following one of these only
        // sees a value in {-1, 0, 1} and is therefore not instrumented.
        private val FLOATING_POINT_CMP_OPCODES = setOf(
            Opcodes.DCMPG,
            Opcodes.DCMPL,
            Opcodes.FCMPG,
            Opcodes.FCMPL,
        )

        // Low constants (0, 1) are omitted as they create a lot of noise.
        val CONSTANT_INTEGER_PUSH_OPCODES = listOf(
            Opcodes.BIPUSH,
            Opcodes.SIPUSH,
            Opcodes.LDC,
            Opcodes.ICONST_2,
            Opcodes.ICONST_3,
            Opcodes.ICONST_4,
            Opcodes.ICONST_5,
        )

        /** Identifies a single-int-argument method by owner, name and return type descriptor. */
        data class MethodInfo(val internalClassName: String, val name: String, val returnType: String)

        // Methods taking a single int index that are treated like array loads.
        val GEP_LOAD_METHODS = setOf(
            MethodInfo("java/util/AbstractList", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/ArrayList", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/List", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/Stack", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/Vector", "get", "Ljava/lang/Object;"),
            MethodInfo("java/lang/CharSequence", "charAt", "C"),
            MethodInfo("java/lang/String", "charAt", "C"),
            MethodInfo("java/lang/StringBuffer", "charAt", "C"),
            MethodInfo("java/lang/StringBuilder", "charAt", "C"),
            MethodInfo("java/lang/String", "codePointAt", "I"),
            MethodInfo("java/lang/String", "codePointBefore", "I"),
            MethodInfo("java/nio/ByteBuffer", "get", "B"),
            MethodInfo("java/nio/ByteBuffer", "getChar", "C"),
            MethodInfo("java/nio/ByteBuffer", "getDouble", "D"),
            MethodInfo("java/nio/ByteBuffer", "getFloat", "F"),
            MethodInfo("java/nio/ByteBuffer", "getInt", "I"),
            MethodInfo("java/nio/ByteBuffer", "getLong", "J"),
            MethodInfo("java/nio/ByteBuffer", "getShort", "S"),
        )
    }
}