/* american fuzzy lop++ - LLVM CmpLog instrumentation -------------------------------------------------- Written by Andrea Fioraldi Copyright 2015, 2016 Google Inc. All rights reserved. Copyright 2019-2020 AFLplusplus Project. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0 */ #include #include #include #include #include #include #include #include "llvm/Config/llvm-config.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Pass.h" #include "llvm/Analysis/ValueTracking.h" #if LLVM_VERSION_MAJOR > 3 || \ (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4) #include "llvm/IR/Verifier.h" #include "llvm/IR/DebugInfo.h" #else #include "llvm/Analysis/Verifier.h" #include "llvm/DebugInfo.h" #define nullptr 0 #endif #include #include "afl-llvm-common.h" using namespace llvm; namespace { class CmpLogInstructions : public ModulePass { public: static char ID; CmpLogInstructions() : ModulePass(ID) { initInstrumentList(); } bool runOnModule(Module &M) override; #if LLVM_VERSION_MAJOR < 4 const char *getPassName() const override { #else StringRef getPassName() const override { #endif return "cmplog instructions"; } private: bool hookInstrs(Module &M); }; } // namespace char CmpLogInstructions::ID = 0; template Iterator Unique(Iterator first, Iterator last) { while (first != last) { Iterator next(first); last = std::remove(++next, last, *first); first = next; } return last; } bool CmpLogInstructions::hookInstrs(Module &M) { std::vector icomps; std::vector switches; LLVMContext & C = M.getContext(); Type * VoidTy = Type::getVoidTy(C); IntegerType *Int8Ty = IntegerType::getInt8Ty(C); IntegerType *Int16Ty = IntegerType::getInt16Ty(C); IntegerType *Int32Ty = IntegerType::getInt32Ty(C); IntegerType *Int64Ty = IntegerType::getInt64Ty(C); IntegerType *Int128Ty = IntegerType::getInt128Ty(C); #if LLVM_VERSION_MAJOR < 9 Constant * #else FunctionCallee #endif c1 = M.getOrInsertFunction("__cmplog_ins_hook1", VoidTy, Int8Ty, Int8Ty, Int8Ty #if LLVM_VERSION_MAJOR < 5 , NULL #endif ); #if LLVM_VERSION_MAJOR < 9 Function *cmplogHookIns1 = cast(c1); #else FunctionCallee cmplogHookIns1 = c1; #endif #if LLVM_VERSION_MAJOR < 9 Constant * #else FunctionCallee #endif c2 = M.getOrInsertFunction("__cmplog_ins_hook2", VoidTy, Int16Ty, Int16Ty, Int8Ty #if LLVM_VERSION_MAJOR < 5 , NULL #endif ); #if LLVM_VERSION_MAJOR < 9 Function *cmplogHookIns2 = cast(c2); #else FunctionCallee cmplogHookIns2 = c2; #endif #if LLVM_VERSION_MAJOR < 9 Constant * #else FunctionCallee #endif c4 = M.getOrInsertFunction("__cmplog_ins_hook4", VoidTy, Int32Ty, Int32Ty, Int8Ty #if LLVM_VERSION_MAJOR < 5 , NULL #endif ); #if LLVM_VERSION_MAJOR < 9 Function *cmplogHookIns4 = cast(c4); #else FunctionCallee cmplogHookIns4 = c4; #endif #if LLVM_VERSION_MAJOR < 9 Constant * #else FunctionCallee #endif c8 = M.getOrInsertFunction("__cmplog_ins_hook8", VoidTy, Int64Ty, Int64Ty, Int8Ty #if LLVM_VERSION_MAJOR < 5 , NULL #endif ); #if LLVM_VERSION_MAJOR < 9 Function *cmplogHookIns8 = cast(c8); #else FunctionCallee cmplogHookIns8 = c8; #endif #if LLVM_VERSION_MAJOR < 9 Constant * #else FunctionCallee #endif c16 = M.getOrInsertFunction("__cmplog_ins_hook16", VoidTy, Int128Ty, Int128Ty, Int8Ty #if LLVM_VERSION_MAJOR < 5 , NULL #endif ); #if LLVM_VERSION_MAJOR < 9 Function *cmplogHookIns16 = cast(c16); #else FunctionCallee cmplogHookIns16 = c16; #endif #if LLVM_VERSION_MAJOR < 9 Constant * #else FunctionCallee #endif cN = M.getOrInsertFunction("__cmplog_ins_hookN", VoidTy, Int128Ty, Int128Ty, Int8Ty, Int8Ty #if LLVM_VERSION_MAJOR < 5 , NULL #endif ); #if LLVM_VERSION_MAJOR < 9 Function *cmplogHookInsN = cast(cN); #else FunctionCallee cmplogHookInsN = cN; #endif /* iterate over all functions, bbs and instruction and add suitable calls */ for (auto &F : M) { if (!isInInstrumentList(&F)) continue; for (auto &BB : F) { for (auto &IN : BB) { CmpInst *selectcmpInst = nullptr; if ((selectcmpInst = dyn_cast(&IN))) { icomps.push_back(selectcmpInst); } SwitchInst *switchInst = nullptr; if ((switchInst = dyn_cast(BB.getTerminator()))) { if (switchInst->getNumCases() > 1) { switches.push_back(switchInst); } } } } } // unique the collected switches switches.erase(Unique(switches.begin(), switches.end()), switches.end()); // Instrument switch values for cmplog if (switches.size()) { if (!be_quiet) errs() << "Hooking " << switches.size() << " switch instructions\n"; for (auto &SI : switches) { Value * Val = SI->getCondition(); unsigned int max_size = Val->getType()->getIntegerBitWidth(), cast_size; unsigned char do_cast = 0; if (!SI->getNumCases() || max_size <= 8) { // if (!be_quiet) errs() << "skip trivial switch..\n"; continue; } IRBuilder<> IRB(SI->getParent()); IRB.SetInsertPoint(SI); if (max_size % 8) { max_size = (((max_size / 8) + 1) * 8); do_cast = 1; } if (max_size > 128) { if (!be_quiet) { fprintf(stderr, "Cannot handle this switch bit size: %u (truncating)\n", max_size); } max_size = 128; do_cast = 1; } // do we need to cast? switch (max_size) { case 8: case 16: case 32: case 64: case 128: cast_size = max_size; break; default: cast_size = 128; do_cast = 1; } Value *CompareTo = Val; if (do_cast) { ConstantInt *cint = dyn_cast(Val); if (cint) { uint64_t val = cint->getZExtValue(); // fprintf(stderr, "ConstantInt: %lu\n", val); switch (cast_size) { case 8: CompareTo = ConstantInt::get(Int8Ty, val); break; case 16: CompareTo = ConstantInt::get(Int16Ty, val); break; case 32: CompareTo = ConstantInt::get(Int32Ty, val); break; case 64: CompareTo = ConstantInt::get(Int64Ty, val); break; case 128: CompareTo = ConstantInt::get(Int128Ty, val); break; } } else { CompareTo = IRB.CreateBitCast(Val, IntegerType::get(C, cast_size)); } } for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i) { #if LLVM_VERSION_MAJOR < 5 ConstantInt *cint = i.getCaseValue(); #else ConstantInt *cint = i->getCaseValue(); #endif if (cint) { std::vector args; args.push_back(CompareTo); Value *new_param = cint; if (do_cast) { uint64_t val = cint->getZExtValue(); // fprintf(stderr, "ConstantInt: %lu\n", val); switch (cast_size) { case 8: new_param = ConstantInt::get(Int8Ty, val); break; case 16: new_param = ConstantInt::get(Int16Ty, val); break; case 32: new_param = ConstantInt::get(Int32Ty, val); break; case 64: new_param = ConstantInt::get(Int64Ty, val); break; case 128: new_param = ConstantInt::get(Int128Ty, val); break; } } if (new_param) { args.push_back(new_param); ConstantInt *attribute = ConstantInt::get(Int8Ty, 1); args.push_back(attribute); if (cast_size != max_size) { ConstantInt *bitsize = ConstantInt::get(Int8Ty, (max_size / 8) - 1); args.push_back(bitsize); } switch (cast_size) { case 8: IRB.CreateCall(cmplogHookIns1, args); break; case 16: IRB.CreateCall(cmplogHookIns2, args); break; case 32: IRB.CreateCall(cmplogHookIns4, args); break; case 64: IRB.CreateCall(cmplogHookIns8, args); break; case 128: #ifdef WORD_SIZE_64 if (max_size == 128) { IRB.CreateCall(cmplogHookIns16, args); } else { IRB.CreateCall(cmplogHookInsN, args); } #endif break; default: break; } } } } } } if (icomps.size()) { // if (!be_quiet) errs() << "Hooking " << icomps.size() << // " cmp instructions\n"; for (auto &selectcmpInst : icomps) { IRBuilder<> IRB(selectcmpInst->getParent()); IRB.SetInsertPoint(selectcmpInst); Value *op0 = selectcmpInst->getOperand(0); Value *op1 = selectcmpInst->getOperand(1); IntegerType * intTyOp0 = NULL; IntegerType * intTyOp1 = NULL; unsigned max_size = 0, cast_size = 0; unsigned char attr = 0, do_cast = 0; std::vector args; CmpInst *cmpInst = dyn_cast(selectcmpInst); if (!cmpInst) { continue; } switch (cmpInst->getPredicate()) { case CmpInst::ICMP_NE: case CmpInst::FCMP_UNE: case CmpInst::FCMP_ONE: break; case CmpInst::ICMP_EQ: case CmpInst::FCMP_UEQ: case CmpInst::FCMP_OEQ: attr += 1; break; case CmpInst::ICMP_UGT: case CmpInst::ICMP_SGT: case CmpInst::FCMP_OGT: case CmpInst::FCMP_UGT: attr += 2; break; case CmpInst::ICMP_UGE: case CmpInst::ICMP_SGE: case CmpInst::FCMP_OGE: case CmpInst::FCMP_UGE: attr += 3; break; case CmpInst::ICMP_ULT: case CmpInst::ICMP_SLT: case CmpInst::FCMP_OLT: case CmpInst::FCMP_ULT: attr += 4; break; case CmpInst::ICMP_ULE: case CmpInst::ICMP_SLE: case CmpInst::FCMP_OLE: case CmpInst::FCMP_ULE: attr += 5; break; default: break; } if (selectcmpInst->getOpcode() == Instruction::FCmp) { auto ty0 = op0->getType(); if (ty0->isHalfTy() #if LLVM_VERSION_MAJOR >= 11 || ty0->isBFloatTy() #endif ) max_size = 16; else if (ty0->isFloatTy()) max_size = 32; else if (ty0->isDoubleTy()) max_size = 64; else if (ty0->isX86_FP80Ty()) max_size = 80; else if (ty0->isFP128Ty() || ty0->isPPC_FP128Ty()) max_size = 128; attr += 8; do_cast = 1; } else { intTyOp0 = dyn_cast(op0->getType()); intTyOp1 = dyn_cast(op1->getType()); if (intTyOp0 && intTyOp1) { max_size = intTyOp0->getBitWidth() > intTyOp1->getBitWidth() ? intTyOp0->getBitWidth() : intTyOp1->getBitWidth(); } } if (!max_size) { continue; } // _ExtInt() with non-8th values if (max_size % 8) { max_size = (((max_size / 8) + 1) * 8); do_cast = 1; } if (max_size > 128) { if (!be_quiet) { fprintf(stderr, "Cannot handle this compare bit size: %u (truncating)\n", max_size); } max_size = 128; do_cast = 1; } // do we need to cast? switch (max_size) { case 8: case 16: case 32: case 64: case 128: cast_size = max_size; break; default: cast_size = 128; do_cast = 1; } if (do_cast) { // F*cking LLVM optimized out any kind of bitcasts of ConstantInt values // creating illegal calls. WTF. So we have to work around this. ConstantInt *cint = dyn_cast(op0); if (cint) { uint64_t val = cint->getZExtValue(); // fprintf(stderr, "ConstantInt: %lu\n", val); ConstantInt *new_param = NULL; switch (cast_size) { case 8: new_param = ConstantInt::get(Int8Ty, val); break; case 16: new_param = ConstantInt::get(Int16Ty, val); break; case 32: new_param = ConstantInt::get(Int32Ty, val); break; case 64: new_param = ConstantInt::get(Int64Ty, val); break; case 128: new_param = ConstantInt::get(Int128Ty, val); break; } if (!new_param) { continue; } args.push_back(new_param); } else { Value *V0 = IRB.CreateBitCast(op0, IntegerType::get(C, cast_size)); args.push_back(V0); } cint = dyn_cast(op1); if (cint) { uint64_t val = cint->getZExtValue(); ConstantInt *new_param = NULL; switch (cast_size) { case 8: new_param = ConstantInt::get(Int8Ty, val); break; case 16: new_param = ConstantInt::get(Int16Ty, val); break; case 32: new_param = ConstantInt::get(Int32Ty, val); break; case 64: new_param = ConstantInt::get(Int64Ty, val); break; case 128: new_param = ConstantInt::get(Int128Ty, val); break; } if (!new_param) { continue; } args.push_back(new_param); } else { Value *V1 = IRB.CreateBitCast(op1, IntegerType::get(C, cast_size)); args.push_back(V1); } } else { args.push_back(op0); args.push_back(op1); } ConstantInt *attribute = ConstantInt::get(Int8Ty, attr); args.push_back(attribute); if (cast_size != max_size) { ConstantInt *bitsize = ConstantInt::get(Int8Ty, (max_size / 8) - 1); args.push_back(bitsize); } // fprintf(stderr, "_ExtInt(%u) castTo %u with attr %u didcast %u\n", // max_size, cast_size, attr, do_cast); switch (cast_size) { case 8: IRB.CreateCall(cmplogHookIns1, args); break; case 16: IRB.CreateCall(cmplogHookIns2, args); break; case 32: IRB.CreateCall(cmplogHookIns4, args); break; case 64: IRB.CreateCall(cmplogHookIns8, args); break; case 128: if (max_size == 128) { IRB.CreateCall(cmplogHookIns16, args); } else { IRB.CreateCall(cmplogHookInsN, args); } break; } } } if (switches.size() || icomps.size()) return true; else return false; } bool CmpLogInstructions::runOnModule(Module &M) { if (getenv("AFL_QUIET") == NULL) printf("Running cmplog-instructions-pass by andreafioraldi@gmail.com\n"); else be_quiet = 1; hookInstrs(M); verifyModule(M); return true; } static void registerCmpLogInstructionsPass(const PassManagerBuilder &, legacy::PassManagerBase &PM) { auto p = new CmpLogInstructions(); PM.add(p); } static RegisterStandardPasses RegisterCmpLogInstructionsPass( PassManagerBuilder::EP_OptimizerLast, registerCmpLogInstructionsPass); static RegisterStandardPasses RegisterCmpLogInstructionsPass0( PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmpLogInstructionsPass); #if LLVM_VERSION_MAJOR >= 11 static RegisterStandardPasses RegisterCmpLogInstructionsPassLTO( PassManagerBuilder::EP_FullLinkTimeOptimizationLast, registerCmpLogInstructionsPass); #endif