|
@@ -0,0 +1,409 @@
|
|
|
+#include "llvm/Pass.h"
|
|
|
+#include "llvm/IR/Function.h"
|
|
|
+#include "llvm/IR/BasicBlock.h"
|
|
|
+#include "llvm/IR/Instructions.h"
|
|
|
+#include "llvm/IR/LegacyPassManager.h"
|
|
|
+#include "llvm/Transforms/IPO/PassManagerBuilder.h"
|
|
|
+#include "llvm/Support/raw_ostream.h"
|
|
|
+#include "llvm/Analysis/CFG.h"
|
|
|
+#include "LogSystem.h"
|
|
|
+#include <vector>
|
|
|
+#include <map>
|
|
|
+#include <set>
|
|
|
+#include <queue>
|
|
|
+#include <algorithm>
|
|
|
+#include <string>
|
|
|
+
|
|
|
+using namespace llvm;
|
|
|
+
|
|
|
+namespace
|
|
|
+{
|
|
|
+ struct CodeFusionPass : public FunctionPass
|
|
|
+ {
|
|
|
+ public:
|
|
|
+ static char ID;
|
|
|
+ CodeFusionPass() : FunctionPass(ID) {}
|
|
|
+
|
|
|
+ bool runOnFunction(Function &F) override
|
|
|
+ {
|
|
|
+ // 配置日志系统
|
|
|
+ auto& logger = logging::LogSystem::getInstance();
|
|
|
+ logger.setGlobalLevel(logging::LogLevel::DEBUG);
|
|
|
+ logger.setContextFunction(F.getName().str());
|
|
|
+ // 配置Function输出权限
|
|
|
+ logger.enableFunction("checkCriticalPoint");
|
|
|
+ // logger.enableFunction("dumpControlFlowGraph");
|
|
|
+ // logger.enableFunction("isTargetCode");
|
|
|
+ LOG_INFO("runOnFunction", "Starting analysis for function: {0}", F.getName().str());
|
|
|
+
|
|
|
+ if(isTargetCode(F)){
|
|
|
+ for (BasicBlock &BB : F) {
|
|
|
+ if(checkCriticalPoint(&BB)){
|
|
|
+ criticalPoints.insert(&BB);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 绘制控制流图
|
|
|
+ // dumpControlFlowGraph(F);
|
|
|
+
|
|
|
+ // 主要的混淆逻辑在这里实现
|
|
|
+ // analyzeFunction(F);
|
|
|
+ // selectFusionPoints(F);
|
|
|
+ // sliceTargetCode(F);
|
|
|
+ // fuseCode(F);
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ private:
|
|
|
+ std::set<BasicBlock*> criticalPoints;
|
|
|
+ std::vector<BasicBlock *> fusionPoints;
|
|
|
+ std::vector<std::vector<BasicBlock *>> targetSlices;
|
|
|
+
|
|
|
+ bool isTargetCode(Function &Func) {
|
|
|
+ LOG_INFO("isTargetCode", "Checking if function '{0}' is target code", Func.getName().str());
|
|
|
+ if (MDNode *MD = Func.getMetadata("project_source")) {
|
|
|
+ if (MDString *ProjectStr = dyn_cast<MDString>(MD->getOperand(0))) {
|
|
|
+ std::string projectName = ProjectStr->getString().str();
|
|
|
+ bool isTarget = (projectName == "Target");
|
|
|
+
|
|
|
+ LOG_INFO("isTargetCode", "Function '{0}' project: {1}, isTarget: {2}",
|
|
|
+ Func.getName().str(), projectName, isTarget ? "true" : "false");
|
|
|
+
|
|
|
+ return isTarget;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (!Func.isDeclaration()) {
|
|
|
+ LOG_WARNING("isTargetCode", "Function '{0}' has no project metadata, considering as non-target",
|
|
|
+ Func.getName().str());
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ std::string getNodeId(BasicBlock *BB){
|
|
|
+ std::string nodeId;
|
|
|
+ raw_string_ostream idOS(nodeId);
|
|
|
+ BB->printAsOperand(idOS, false);
|
|
|
+ idOS.flush();
|
|
|
+ return nodeId;
|
|
|
+ }
|
|
|
+
|
|
|
+ bool checkCriticalPoint(BasicBlock *BB) {
|
|
|
+ LOG_DEBUG("checkCriticalPoint", "Starting critical point check: {0}", getNodeId(BB));
|
|
|
+ if (!BB) {
|
|
|
+ LOG_ERROR("checkCriticalPoint", "Null basic block provided");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ Function *F = BB->getParent();
|
|
|
+ if (!F) {
|
|
|
+ LOG_ERROR("checkCriticalPoint", "Cannot get parent function for block");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ // 获取起始和终止基本块
|
|
|
+ BasicBlock *Entry = &F->getEntryBlock();
|
|
|
+ BasicBlock *Exit = nullptr;
|
|
|
+ for (BasicBlock &B : *F) {
|
|
|
+ if (isa<ReturnInst>(B.getTerminator()) ||
|
|
|
+ isa<UnreachableInst>(B.getTerminator())) {
|
|
|
+ Exit = &B;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (!Exit) {
|
|
|
+ LOG_WARNING("checkCriticalPoint", "No exit block found in function");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ // 存储已访问的节点
|
|
|
+ std::set<BasicBlock*> visitedWithoutCurrent;
|
|
|
+
|
|
|
+ // 不经过当前基本块的遍历函数
|
|
|
+ std::function<bool(BasicBlock*, BasicBlock*)> canReachExitWithout;
|
|
|
+ canReachExitWithout = [&canReachExitWithout, &visitedWithoutCurrent, BB]
|
|
|
+ (BasicBlock *Start, BasicBlock *Target) -> bool {
|
|
|
+ if (Start == BB) return false;
|
|
|
+ if (Start == Target) return true;
|
|
|
+
|
|
|
+ if (visitedWithoutCurrent.count(Start)) return false;
|
|
|
+ visitedWithoutCurrent.insert(Start);
|
|
|
+
|
|
|
+ for (BasicBlock *Succ : successors(Start)) {
|
|
|
+ if (canReachExitWithout(Succ, Target)) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+ };
|
|
|
+
|
|
|
+ // 输出块的基本信息
|
|
|
+ LOG_TRACE("checkCriticalPoint", "Analyzing block: {0}", getNodeId(BB));
|
|
|
+ LOG_TRACE("checkCriticalPoint", " Predecessors: {0}, Successors: {1}",
|
|
|
+ pred_size(BB), succ_size(BB));
|
|
|
+
|
|
|
+ // 列出前驱基本块
|
|
|
+ LOG_TRACE("checkCriticalPoint", "Predecessor blocks:");
|
|
|
+ for (BasicBlock *Pred : predecessors(BB)) {
|
|
|
+ LOG_TRACE("checkCriticalPoint", " {0}", getNodeId(Pred));
|
|
|
+ }
|
|
|
+
|
|
|
+ // 列出后继基本块
|
|
|
+ LOG_TRACE("checkCriticalPoint", "Successor blocks:");
|
|
|
+ for (BasicBlock *Succ : successors(BB)) {
|
|
|
+ LOG_TRACE("checkCriticalPoint", " {0}", getNodeId(Succ));
|
|
|
+ }
|
|
|
+
|
|
|
+ // 清空访问集合
|
|
|
+ visitedWithoutCurrent.clear();
|
|
|
+
|
|
|
+ // 判断从入口到出口是否必须经过当前基本块
|
|
|
+ bool isCritical = !canReachExitWithout(Entry, Exit);
|
|
|
+
|
|
|
+ LOG_DEBUG("checkCriticalPoint", "Block {0} critical status: {1}",
|
|
|
+ BB->getName().str(), isCritical ? "yes" : "no");
|
|
|
+
|
|
|
+ return isCritical;
|
|
|
+ }
|
|
|
+
|
|
|
+ void dumpControlFlowGraph(Function &F) {
|
|
|
+ LOG_INFO("dumpControlFlowGraph", "Generating control flow graph for function: {0}",
|
|
|
+ F.getName().str());
|
|
|
+
|
|
|
+ if (F.empty()) {
|
|
|
+ LOG_WARNING("dumpControlFlowGraph", "Function is empty!");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ LOG_INFO("dumpControlFlowGraph", "Starting Mermaid graph generation");
|
|
|
+ errs() << "```mermaid\n";
|
|
|
+ errs() << "graph TD\n";
|
|
|
+
|
|
|
+ // 为所有基本块创建节点
|
|
|
+ for (BasicBlock &BB : F) {
|
|
|
+ // 获取基本块的ID
|
|
|
+ std::string nodeId = getNodeId(&BB);
|
|
|
+
|
|
|
+ if (!nodeId.empty() && nodeId[0] == '%') {
|
|
|
+ nodeId = nodeId.substr(1);
|
|
|
+ }
|
|
|
+
|
|
|
+ // 清理节点ID中的特殊字符
|
|
|
+ std::replace(nodeId.begin(), nodeId.end(), '.', '_');
|
|
|
+ std::replace(nodeId.begin(), nodeId.end(), ' ', '_');
|
|
|
+ std::replace(nodeId.begin(), nodeId.end(), '%', '_');
|
|
|
+ std::replace(nodeId.begin(), nodeId.end(), '-', '_');
|
|
|
+
|
|
|
+ LOG_TRACE("dumpControlFlowGraph", "Processing block with ID: {0}", nodeId);
|
|
|
+
|
|
|
+ // 构建基本块内容字符串
|
|
|
+ std::string blockContent;
|
|
|
+ raw_string_ostream contentOS(blockContent);
|
|
|
+
|
|
|
+ // 添加基本块标签
|
|
|
+ contentOS << "Block " << nodeId << ":\\n";
|
|
|
+
|
|
|
+ // 添加基本块中的每条指令
|
|
|
+ for (Instruction &I : BB) {
|
|
|
+ std::string instStr;
|
|
|
+ raw_string_ostream instOS(instStr);
|
|
|
+ I.print(instOS);
|
|
|
+ instOS.flush();
|
|
|
+
|
|
|
+ // 清理指令字符串
|
|
|
+ std::replace(instStr.begin(), instStr.end(), '"', '\'');
|
|
|
+ std::replace(instStr.begin(), instStr.end(), '\n', ' ');
|
|
|
+
|
|
|
+ // 如果指令太长,截断它
|
|
|
+ if (instStr.length() > 50) {
|
|
|
+ LOG_TRACE("dumpControlFlowGraph", "Truncating long instruction in block {0}", nodeId);
|
|
|
+ instStr = instStr.substr(0, 47) + "...";
|
|
|
+ }
|
|
|
+
|
|
|
+ contentOS << instStr << "\\n";
|
|
|
+ }
|
|
|
+ contentOS.flush();
|
|
|
+
|
|
|
+ // 生成节点
|
|
|
+ if (criticalPoints.count(&BB)) {
|
|
|
+ LOG_TRACE("dumpControlFlowGraph", "Marking block {0} as critical", nodeId);
|
|
|
+ errs() << " " << nodeId << "[\"" << blockContent << "\"]:::critical\n";
|
|
|
+ } else {
|
|
|
+ errs() << " " << nodeId << "[\"" << blockContent << "\"]\n";
|
|
|
+ }
|
|
|
+
|
|
|
+ // 处理边
|
|
|
+ for (BasicBlock *Succ : successors(&BB)) {
|
|
|
+ std::string succId;
|
|
|
+ raw_string_ostream succOS(succId);
|
|
|
+ Succ->printAsOperand(succOS, false);
|
|
|
+ succOS.flush();
|
|
|
+
|
|
|
+ if (!succId.empty() && succId[0] == '%') {
|
|
|
+ succId = succId.substr(1);
|
|
|
+ }
|
|
|
+
|
|
|
+ std::replace(succId.begin(), succId.end(), '.', '_');
|
|
|
+ std::replace(succId.begin(), succId.end(), ' ', '_');
|
|
|
+ std::replace(succId.begin(), succId.end(), '%', '_');
|
|
|
+ std::replace(succId.begin(), succId.end(), '-', '_');
|
|
|
+
|
|
|
+ LOG_TRACE("dumpControlFlowGraph", "Processing edge from {0} to {1}", nodeId, succId);
|
|
|
+
|
|
|
+ // 检查分支类型并添加边
|
|
|
+ if (BranchInst *BI = dyn_cast<BranchInst>(BB.getTerminator())) {
|
|
|
+ if (BI->isConditional()) {
|
|
|
+ // 获取条件表达式
|
|
|
+ std::string condStr;
|
|
|
+ raw_string_ostream condOS(condStr);
|
|
|
+ BI->getCondition()->print(condOS);
|
|
|
+ condOS.flush();
|
|
|
+
|
|
|
+ bool isTrue = BI->getSuccessor(0) == Succ;
|
|
|
+ LOG_TRACE("dumpControlFlowGraph", "Adding conditional branch edge: {0}",
|
|
|
+ isTrue ? "true" : "false");
|
|
|
+ errs() << " " << nodeId << " -->|"
|
|
|
+ << (isTrue ? "true" : "false") << "| "
|
|
|
+ << succId << "\n";
|
|
|
+ } else {
|
|
|
+ errs() << " " << nodeId << " --> " << succId << "\n";
|
|
|
+ }
|
|
|
+ } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator())) {
|
|
|
+ // 处理switch指令
|
|
|
+ if (Succ == SI->getDefaultDest()) {
|
|
|
+ LOG_TRACE("dumpControlFlowGraph", "Adding switch default edge");
|
|
|
+ errs() << " " << nodeId << " -->|default| " << succId << "\n";
|
|
|
+ } else {
|
|
|
+ for (auto Case : SI->cases()) {
|
|
|
+ if (Case.getCaseSuccessor() == Succ) {
|
|
|
+ std::string caseStr;
|
|
|
+ raw_string_ostream caseOS(caseStr);
|
|
|
+ Case.getCaseValue()->print(caseOS);
|
|
|
+ LOG_TRACE("dumpControlFlowGraph", "Adding switch case edge: {0}",
|
|
|
+ caseStr);
|
|
|
+ errs() << " " << nodeId << " -->|case "
|
|
|
+ << caseStr << "| " << succId << "\n";
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ errs() << " " << nodeId << " --> " << succId << "\n";
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ errs() << " classDef critical fill:#f96,stroke:#333,stroke-width:4px\n";
|
|
|
+ errs() << "```\n";
|
|
|
+ LOG_INFO("dumpControlFlowGraph", "Completed graph generation");
|
|
|
+ }
|
|
|
+
|
|
|
+ void analyzeFunction(Function &F)
|
|
|
+ {
|
|
|
+ for (BasicBlock &BB : F)
|
|
|
+ {
|
|
|
+ for (Instruction &I : BB)
|
|
|
+ {
|
|
|
+ if (CallInst *CI = dyn_cast<CallInst>(&I))
|
|
|
+ {
|
|
|
+ // 分析函数调用
|
|
|
+ Function *CalledF = CI->getCalledFunction();
|
|
|
+ if (CalledF && CalledF->isDeclaration())
|
|
|
+ {
|
|
|
+ errs() << "External function call: " << CalledF->getName() << "\n";
|
|
|
+ // 这里可以添加区分变量
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // 可以添加更多的分析,如全局变量访问等
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ void selectFusionPoints(Function &F) {
|
|
|
+ errs() << "\n=== Starting Fusion Points Selection for function: "
|
|
|
+ << F.getName() << " ===\n";
|
|
|
+
|
|
|
+ fusionPoints.clear();
|
|
|
+ criticalPoints.clear();
|
|
|
+
|
|
|
+ std::set<BasicBlock*> visited;
|
|
|
+
|
|
|
+ // 首先收集所有可能的关键点
|
|
|
+ for (BasicBlock &BB : F) {
|
|
|
+ if (checkCriticalPoint(&BB)) {
|
|
|
+ criticalPoints.insert(&BB);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ errs() << "\nIdentified " << criticalPoints.size() << " potential critical points\n";
|
|
|
+
|
|
|
+ // 按照支配关系排序关键点
|
|
|
+ std::vector<BasicBlock*> orderedPoints;
|
|
|
+ BasicBlock *entryBlock = &F.getEntryBlock();
|
|
|
+ std::vector<BasicBlock*> workList;
|
|
|
+ workList.push_back(entryBlock);
|
|
|
+
|
|
|
+ while (!workList.empty()) {
|
|
|
+ BasicBlock *current = workList.back();
|
|
|
+ workList.pop_back();
|
|
|
+
|
|
|
+ if (visited.count(current))
|
|
|
+ continue;
|
|
|
+
|
|
|
+ visited.insert(current);
|
|
|
+
|
|
|
+ if (criticalPoints.count(current)) {
|
|
|
+ orderedPoints.push_back(current);
|
|
|
+ errs() << "Adding ordered critical point: " << current->getName() << "\n";
|
|
|
+ }
|
|
|
+
|
|
|
+ for (BasicBlock *succ : successors(current)) {
|
|
|
+ if (!visited.count(succ)) {
|
|
|
+ workList.push_back(succ);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fusionPoints = orderedPoints;
|
|
|
+
|
|
|
+ errs() << "\nSelected " << fusionPoints.size() << " fusion points in order:\n";
|
|
|
+ for (BasicBlock *BB : fusionPoints) {
|
|
|
+ errs() << " " << BB->getName() << "\n";
|
|
|
+ }
|
|
|
+
|
|
|
+ errs() << "\nValidating fusion points:\n";
|
|
|
+ for (size_t i = 0; i < fusionPoints.size(); ++i) {
|
|
|
+ BasicBlock *current = fusionPoints[i];
|
|
|
+ errs() << "Fusion point " << i << ": " << current->getName() << "\n";
|
|
|
+
|
|
|
+ errs() << "Instructions in this fusion point:\n";
|
|
|
+ for (Instruction &I : *current) {
|
|
|
+ errs() << " " << I << "\n";
|
|
|
+ }
|
|
|
+
|
|
|
+ if (i < fusionPoints.size() - 1) {
|
|
|
+ BasicBlock *next = fusionPoints[i + 1];
|
|
|
+ errs() << "Distance to next fusion point: "
|
|
|
+ << std::distance(current->getIterator(), next->getIterator())
|
|
|
+ << " blocks\n";
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ errs() << "=== Finished Fusion Points Selection ===\n\n";
|
|
|
+ }
|
|
|
+
|
|
|
+ void sliceTargetCode(Function &F)
|
|
|
+ {
|
|
|
+ errs() << "Slicing target code\n";
|
|
|
+ // 目标代码分片逻辑
|
|
|
+ }
|
|
|
+
|
|
|
+ void fuseCode(Function &F)
|
|
|
+ {
|
|
|
+ errs() << "Fusing code\n";
|
|
|
+ // 代码融合逻辑
|
|
|
+ }
|
|
|
+ };
|
|
|
+}
|
|
|
+
|
|
|
+char CodeFusionPass::ID = 0;
|
|
|
+static RegisterPass<CodeFusionPass> X("codefusion", "Code Fusion Pass");
|