aboutsummaryrefslogtreecommitdiff
path: root/instrumentation
diff options
context:
space:
mode:
authorvan Hauser <vh@thc.org>2021-01-22 15:58:12 +0100
committervan Hauser <vh@thc.org>2021-01-22 15:58:12 +0100
commitbaf1ac2e69145d9e411f13df8ab4e7060f2f2685 (patch)
tree2a9cdd0ba596c9794417641e2cdad2499162dee1 /instrumentation
parent46010a87049bdf32ef23b08d25c186aba3aae442 (diff)
downloadAFLplusplus-baf1ac2e69145d9e411f13df8ab4e7060f2f2685.tar.gz
basic cmplog std::string support
Diffstat (limited to 'instrumentation')
-rw-r--r--instrumentation/afl-compiler-rt.o.c4
-rw-r--r--instrumentation/cmplog-routines-pass.cc214
2 files changed, 207 insertions, 11 deletions
diff --git a/instrumentation/afl-compiler-rt.o.c b/instrumentation/afl-compiler-rt.o.c
index 322141ba..433a1d89 100644
--- a/instrumentation/afl-compiler-rt.o.c
+++ b/instrumentation/afl-compiler-rt.o.c
@@ -1549,10 +1549,10 @@ void __cmplog_rtn_hook(u8 *ptr1, u8 *ptr2) {
u32 i;
if (!area_is_mapped(ptr1, 32) || !area_is_mapped(ptr2, 32)) return;
fprintf(stderr, "rtn arg0=");
- for (i = 0; i < 8; i++)
+ for (i = 0; i < 24; i++)
fprintf(stderr, "%02x", ptr1[i]);
fprintf(stderr, " arg1=");
- for (i = 0; i < 8; i++)
+ for (i = 0; i < 24; i++)
fprintf(stderr, "%02x", ptr2[i]);
fprintf(stderr, "\n");
*/
diff --git a/instrumentation/cmplog-routines-pass.cc b/instrumentation/cmplog-routines-pass.cc
index 8adf42d5..9d6999ca 100644
--- a/instrumentation/cmplog-routines-pass.cc
+++ b/instrumentation/cmplog-routines-pass.cc
@@ -87,7 +87,7 @@ char CmpLogRoutines::ID = 0;
bool CmpLogRoutines::hookRtns(Module &M) {
- std::vector<CallInst *> calls;
+ std::vector<CallInst *> calls, llvmStdStd, llvmStdC, gccStdStd, gccStdC;
LLVMContext & C = M.getContext();
Type *VoidTy = Type::getVoidTy(C);
@@ -112,6 +112,78 @@ bool CmpLogRoutines::hookRtns(Module &M) {
FunctionCallee cmplogHookFn = c;
#endif
+#if LLVM_VERSION_MAJOR < 9
+ Constant *
+#else
+ FunctionCallee
+#endif
+ c1 = M.getOrInsertFunction("__cmplog_rtn_llvm_stdstring_stdstring",
+ VoidTy, i8PtrTy, i8PtrTy
+#if LLVM_VERSION_MAJOR < 5
+ ,
+ NULL
+#endif
+ );
+#if LLVM_VERSION_MAJOR < 9
+ Function *cmplogLlvmStdStd = cast<Function>(c1);
+#else
+ FunctionCallee cmplogLlvmStdStd = c1;
+#endif
+
+#if LLVM_VERSION_MAJOR < 9
+ Constant *
+#else
+ FunctionCallee
+#endif
+ c2 = M.getOrInsertFunction("__cmplog_rtn_llvm_stdstring_cstring", VoidTy,
+ i8PtrTy, i8PtrTy
+#if LLVM_VERSION_MAJOR < 5
+ ,
+ NULL
+#endif
+ );
+#if LLVM_VERSION_MAJOR < 9
+ Function *cmplogLlvmStdC = cast<Function>(c2);
+#else
+ FunctionCallee cmplogLlvmStdC = c2;
+#endif
+
+#if LLVM_VERSION_MAJOR < 9
+ Constant *
+#else
+ FunctionCallee
+#endif
+ c3 = M.getOrInsertFunction("__cmplog_rtn_gcc_stdstring_stdstring", VoidTy,
+ i8PtrTy, i8PtrTy
+#if LLVM_VERSION_MAJOR < 5
+ ,
+ NULL
+#endif
+ );
+#if LLVM_VERSION_MAJOR < 9
+ Function *cmplogGccStdStd = cast<Function>(c3);
+#else
+ FunctionCallee cmplogGccStdStd = c3;
+#endif
+
+#if LLVM_VERSION_MAJOR < 9
+ Constant *
+#else
+ FunctionCallee
+#endif
+ c4 = M.getOrInsertFunction("__cmplog_rtn_gcc_stdstring_cstring", VoidTy,
+ i8PtrTy, i8PtrTy
+#if LLVM_VERSION_MAJOR < 5
+ ,
+ NULL
+#endif
+ );
+#if LLVM_VERSION_MAJOR < 9
+ Function *cmplogGccStdC = cast<Function>(c4);
+#else
+ FunctionCallee cmplogGccStdC = c4;
+#endif
+
/* iterate over all functions, bbs and instruction and add suitable calls */
for (auto &F : M) {
@@ -131,19 +203,67 @@ bool CmpLogRoutines::hookRtns(Module &M) {
FunctionType *FT = Callee->getFunctionType();
- // _ZNKSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE7compareEmmPKcm
- // => libc++ => llvm => __cmplog_rtn_llvm_stdstring_cstring(u8 *stdstring1, u8 *stdstring2)
- // _ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE7compareEPKc
- // => libstdc++ => gcc => __cmplog_rtn_gcc_stdstring_cstring
-
bool isPtrRtn = FT->getNumParams() >= 2 &&
!FT->getReturnType()->isVoidTy() &&
FT->getParamType(0) == FT->getParamType(1) &&
FT->getParamType(0)->isPointerTy();
- if (!isPtrRtn) continue;
-
- calls.push_back(callInst);
+ bool isGccStdStringStdString =
+ Callee->getName().find("__is_charIT_EE7__value") !=
+ std::string::npos &&
+ Callee->getName().find(
+ "St7__cxx1112basic_stringIS2_St11char_traits") !=
+ std::string::npos &&
+ FT->getNumParams() >= 2 &&
+ FT->getParamType(0) == FT->getParamType(1) &&
+ FT->getParamType(0)->isPointerTy();
+
+ bool isGccStdStringCString =
+ Callee->getName().find(
+ "St7__cxx1112basic_stringIcSt11char_"
+ "traitsIcESaIcEE7compareEPK") != std::string::npos &&
+ FT->getNumParams() >= 2 && FT->getParamType(0)->isPointerTy() &&
+ FT->getParamType(1)->isPointerTy();
+
+ bool isLlvmStdStringStdString =
+ Callee->getName().find("_ZNSt3__1eqINS") != std::string::npos &&
+ Callee->getName().find("_12basic_stringIcNS_11char_traits") !=
+ std::string::npos &&
+ FT->getNumParams() >= 2 && FT->getParamType(0)->isPointerTy() &&
+ FT->getParamType(1)->isPointerTy();
+
+ bool isLlvmStdStringCString =
+ Callee->getName().find("ZNSt3__1eqIcNS") != std::string::npos &&
+ Callee->getName().find("_12basic_stringI") != std::string::npos &&
+ FT->getNumParams() >= 2 && FT->getParamType(0)->isPointerTy() &&
+ FT->getParamType(1)->isPointerTy();
+
+ /*
+ fprintf(stderr, "F:%s C:%s argc:%u\n",
+ F.getName().str().c_str(), Callee->getName().str().c_str(),
+ FT->getNumParams()); fprintf(stderr, "ptr0:%u ptr1:%u ptr2:%u\n",
+ FT->getParamType(0)->isPointerTy(),
+ FT->getParamType(1)->isPointerTy(),
+ FT->getNumParams() > 2 ? FT->getParamType(2)->isPointerTy()
+ : 22 );
+ */
+
+ if (isGccStdStringCString || isGccStdStringStdString ||
+ isLlvmStdStringStdString || isLlvmStdStringCString) {
+
+ isPtrRtn = false;
+
+ fprintf(stderr, "%u %u %u %u\n", isGccStdStringCString,
+ isGccStdStringStdString, isLlvmStdStringStdString,
+ isLlvmStdStringCString);
+
+ }
+
+ if (isPtrRtn) { calls.push_back(callInst); }
+ if (isGccStdStringStdString) { gccStdStd.push_back(callInst); }
+ if (isGccStdStringCString) { gccStdC.push_back(callInst); }
+ if (isLlvmStdStringStdString) { llvmStdStd.push_back(callInst); }
+ if (isLlvmStdStringCString) { llvmStdC.push_back(callInst); }
}
@@ -179,6 +299,82 @@ bool CmpLogRoutines::hookRtns(Module &M) {
}
+ for (auto &callInst : gccStdStd) {
+
+ Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1);
+
+ IRBuilder<> IRB(callInst->getParent());
+ IRB.SetInsertPoint(callInst);
+
+ std::vector<Value *> args;
+ Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy);
+ Value * v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy);
+ args.push_back(v1Pcasted);
+ args.push_back(v2Pcasted);
+
+ IRB.CreateCall(cmplogGccStdStd, args);
+
+ // errs() << callInst->getCalledFunction()->getName() << "\n";
+
+ }
+
+ for (auto &callInst : gccStdC) {
+
+ Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1);
+
+ IRBuilder<> IRB(callInst->getParent());
+ IRB.SetInsertPoint(callInst);
+
+ std::vector<Value *> args;
+ Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy);
+ Value * v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy);
+ args.push_back(v1Pcasted);
+ args.push_back(v2Pcasted);
+
+ IRB.CreateCall(cmplogGccStdC, args);
+
+ // errs() << callInst->getCalledFunction()->getName() << "\n";
+
+ }
+
+ for (auto &callInst : llvmStdStd) {
+
+ Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1);
+
+ IRBuilder<> IRB(callInst->getParent());
+ IRB.SetInsertPoint(callInst);
+
+ std::vector<Value *> args;
+ Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy);
+ Value * v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy);
+ args.push_back(v1Pcasted);
+ args.push_back(v2Pcasted);
+
+ IRB.CreateCall(cmplogLlvmStdStd, args);
+
+ // errs() << callInst->getCalledFunction()->getName() << "\n";
+
+ }
+
+ for (auto &callInst : llvmStdC) {
+
+ Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1);
+
+ IRBuilder<> IRB(callInst->getParent());
+ IRB.SetInsertPoint(callInst);
+
+ std::vector<Value *> args;
+ Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy);
+ Value * v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy);
+ args.push_back(v1Pcasted);
+ args.push_back(v2Pcasted);
+
+ IRB.CreateCall(cmplogLlvmStdC, args);
+
+ // errs() << callInst->getCalledFunction()->getName() << "\n";
+
+ }
+
return true;
}