SliceFusionPass.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. #include "llvm/Pass.h"
  2. #include "llvm/IR/Function.h"
  3. #include "llvm/IR/BasicBlock.h"
  4. #include "llvm/IR/Instructions.h"
  5. #include "llvm/IR/LegacyPassManager.h"
  6. #include "llvm/Transforms/IPO/PassManagerBuilder.h"
  7. #include "llvm/Support/raw_ostream.h"
  8. #include "llvm/Analysis/CFG.h"
  9. #include "LogSystem.h"
  10. #include <vector>
  11. #include <map>
  12. #include <set>
  13. #include <queue>
  14. #include <algorithm>
  15. #include <string>
  16. using namespace llvm;
  17. namespace
  18. {
  19. struct CodeFusionPass : public FunctionPass
  20. {
  21. public:
  // Pass identifier handed to the FunctionPass base class; its definition
  // lives outside the struct, after the anonymous namespace.
  static char ID;
  // Default constructor: forwards the pass identifier to FunctionPass.
  CodeFusionPass() : FunctionPass(ID) {}
  24. bool runOnFunction(Function &F) override
  25. {
  26. // 配置日志系统
  27. auto& logger = logging::LogSystem::getInstance();
  28. logger.setGlobalLevel(logging::LogLevel::DEBUG);
  29. logger.setContextFunction(F.getName().str());
  30. // 配置Function输出权限
  31. logger.enableFunction("checkCriticalPoint");
  32. // logger.enableFunction("dumpControlFlowGraph");
  33. // logger.enableFunction("isTargetCode");
  34. LOG_INFO("runOnFunction", "Starting analysis for function: {0}", F.getName().str());
  35. if(isTargetCode(F)){
  36. for (BasicBlock &BB : F) {
  37. if(checkCriticalPoint(&BB)){
  38. criticalPoints.insert(&BB);
  39. }
  40. }
  41. }
  42. // 绘制控制流图
  43. // dumpControlFlowGraph(F);
  44. // 主要的混淆逻辑在这里实现
  45. // analyzeFunction(F);
  46. // selectFusionPoints(F);
  47. // sliceTargetCode(F);
  48. // fuseCode(F);
  49. return true;
  50. }
  51. private:
  // Basic blocks for which checkCriticalPoint() returned true.
  std::set<BasicBlock*> criticalPoints;
  // Critical blocks in the order selectFusionPoints() visited them.
  std::vector<BasicBlock *> fusionPoints;
  // Not written by any method visible in this struct; presumably reserved
  // for the output of sliceTargetCode() -- TODO confirm.
  std::vector<std::vector<BasicBlock *>> targetSlices;
  55. bool isTargetCode(Function &Func) {
  56. LOG_INFO("isTargetCode", "Checking if function '{0}' is target code", Func.getName().str());
  57. if (MDNode *MD = Func.getMetadata("project_source")) {
  58. if (MDString *ProjectStr = dyn_cast<MDString>(MD->getOperand(0))) {
  59. std::string projectName = ProjectStr->getString().str();
  60. bool isTarget = (projectName == "Target");
  61. LOG_INFO("isTargetCode", "Function '{0}' project: {1}, isTarget: {2}",
  62. Func.getName().str(), projectName, isTarget ? "true" : "false");
  63. return isTarget;
  64. }
  65. }
  66. if (!Func.isDeclaration()) {
  67. LOG_WARNING("isTargetCode", "Function '{0}' has no project metadata, considering as non-target",
  68. Func.getName().str());
  69. }
  70. return false;
  71. }
  72. std::string getNodeId(BasicBlock *BB){
  73. std::string nodeId;
  74. raw_string_ostream idOS(nodeId);
  75. BB->printAsOperand(idOS, false);
  76. idOS.flush();
  77. return nodeId;
  78. }
  79. bool checkCriticalPoint(BasicBlock *BB) {
  80. LOG_DEBUG("checkCriticalPoint", "Starting critical point check: {0}", getNodeId(BB));
  81. if (!BB) {
  82. LOG_ERROR("checkCriticalPoint", "Null basic block provided");
  83. return false;
  84. }
  85. Function *F = BB->getParent();
  86. if (!F) {
  87. LOG_ERROR("checkCriticalPoint", "Cannot get parent function for block");
  88. return false;
  89. }
  90. // 获取起始和终止基本块
  91. BasicBlock *Entry = &F->getEntryBlock();
  92. BasicBlock *Exit = nullptr;
  93. for (BasicBlock &B : *F) {
  94. if (isa<ReturnInst>(B.getTerminator()) ||
  95. isa<UnreachableInst>(B.getTerminator())) {
  96. Exit = &B;
  97. break;
  98. }
  99. }
  100. if (!Exit) {
  101. LOG_WARNING("checkCriticalPoint", "No exit block found in function");
  102. return false;
  103. }
  104. // 存储已访问的节点
  105. std::set<BasicBlock*> visitedWithoutCurrent;
  106. // 不经过当前基本块的遍历函数
  107. std::function<bool(BasicBlock*, BasicBlock*)> canReachExitWithout;
  108. canReachExitWithout = [&canReachExitWithout, &visitedWithoutCurrent, BB]
  109. (BasicBlock *Start, BasicBlock *Target) -> bool {
  110. if (Start == BB) return false;
  111. if (Start == Target) return true;
  112. if (visitedWithoutCurrent.count(Start)) return false;
  113. visitedWithoutCurrent.insert(Start);
  114. for (BasicBlock *Succ : successors(Start)) {
  115. if (canReachExitWithout(Succ, Target)) {
  116. return true;
  117. }
  118. }
  119. return false;
  120. };
  121. // 输出块的基本信息
  122. LOG_TRACE("checkCriticalPoint", "Analyzing block: {0}", getNodeId(BB));
  123. LOG_TRACE("checkCriticalPoint", " Predecessors: {0}, Successors: {1}",
  124. pred_size(BB), succ_size(BB));
  125. // 列出前驱基本块
  126. LOG_TRACE("checkCriticalPoint", "Predecessor blocks:");
  127. for (BasicBlock *Pred : predecessors(BB)) {
  128. LOG_TRACE("checkCriticalPoint", " {0}", getNodeId(Pred));
  129. }
  130. // 列出后继基本块
  131. LOG_TRACE("checkCriticalPoint", "Successor blocks:");
  132. for (BasicBlock *Succ : successors(BB)) {
  133. LOG_TRACE("checkCriticalPoint", " {0}", getNodeId(Succ));
  134. }
  135. // 清空访问集合
  136. visitedWithoutCurrent.clear();
  137. // 判断从入口到出口是否必须经过当前基本块
  138. bool isCritical = !canReachExitWithout(Entry, Exit);
  139. LOG_DEBUG("checkCriticalPoint", "Block {0} critical status: {1}",
  140. BB->getName().str(), isCritical ? "yes" : "no");
  141. return isCritical;
  142. }
  143. void dumpControlFlowGraph(Function &F) {
  144. LOG_INFO("dumpControlFlowGraph", "Generating control flow graph for function: {0}",
  145. F.getName().str());
  146. if (F.empty()) {
  147. LOG_WARNING("dumpControlFlowGraph", "Function is empty!");
  148. return;
  149. }
  150. LOG_INFO("dumpControlFlowGraph", "Starting Mermaid graph generation");
  151. errs() << "```mermaid\n";
  152. errs() << "graph TD\n";
  153. // 为所有基本块创建节点
  154. for (BasicBlock &BB : F) {
  155. // 获取基本块的ID
  156. std::string nodeId = getNodeId(&BB);
  157. if (!nodeId.empty() && nodeId[0] == '%') {
  158. nodeId = nodeId.substr(1);
  159. }
  160. // 清理节点ID中的特殊字符
  161. std::replace(nodeId.begin(), nodeId.end(), '.', '_');
  162. std::replace(nodeId.begin(), nodeId.end(), ' ', '_');
  163. std::replace(nodeId.begin(), nodeId.end(), '%', '_');
  164. std::replace(nodeId.begin(), nodeId.end(), '-', '_');
  165. LOG_TRACE("dumpControlFlowGraph", "Processing block with ID: {0}", nodeId);
  166. // 构建基本块内容字符串
  167. std::string blockContent;
  168. raw_string_ostream contentOS(blockContent);
  169. // 添加基本块标签
  170. contentOS << "Block " << nodeId << ":\\n";
  171. // 添加基本块中的每条指令
  172. for (Instruction &I : BB) {
  173. std::string instStr;
  174. raw_string_ostream instOS(instStr);
  175. I.print(instOS);
  176. instOS.flush();
  177. // 清理指令字符串
  178. std::replace(instStr.begin(), instStr.end(), '"', '\'');
  179. std::replace(instStr.begin(), instStr.end(), '\n', ' ');
  180. // 如果指令太长,截断它
  181. if (instStr.length() > 50) {
  182. LOG_TRACE("dumpControlFlowGraph", "Truncating long instruction in block {0}", nodeId);
  183. instStr = instStr.substr(0, 47) + "...";
  184. }
  185. contentOS << instStr << "\\n";
  186. }
  187. contentOS.flush();
  188. // 生成节点
  189. if (criticalPoints.count(&BB)) {
  190. LOG_TRACE("dumpControlFlowGraph", "Marking block {0} as critical", nodeId);
  191. errs() << " " << nodeId << "[\"" << blockContent << "\"]:::critical\n";
  192. } else {
  193. errs() << " " << nodeId << "[\"" << blockContent << "\"]\n";
  194. }
  195. // 处理边
  196. for (BasicBlock *Succ : successors(&BB)) {
  197. std::string succId;
  198. raw_string_ostream succOS(succId);
  199. Succ->printAsOperand(succOS, false);
  200. succOS.flush();
  201. if (!succId.empty() && succId[0] == '%') {
  202. succId = succId.substr(1);
  203. }
  204. std::replace(succId.begin(), succId.end(), '.', '_');
  205. std::replace(succId.begin(), succId.end(), ' ', '_');
  206. std::replace(succId.begin(), succId.end(), '%', '_');
  207. std::replace(succId.begin(), succId.end(), '-', '_');
  208. LOG_TRACE("dumpControlFlowGraph", "Processing edge from {0} to {1}", nodeId, succId);
  209. // 检查分支类型并添加边
  210. if (BranchInst *BI = dyn_cast<BranchInst>(BB.getTerminator())) {
  211. if (BI->isConditional()) {
  212. // 获取条件表达式
  213. std::string condStr;
  214. raw_string_ostream condOS(condStr);
  215. BI->getCondition()->print(condOS);
  216. condOS.flush();
  217. bool isTrue = BI->getSuccessor(0) == Succ;
  218. LOG_TRACE("dumpControlFlowGraph", "Adding conditional branch edge: {0}",
  219. isTrue ? "true" : "false");
  220. errs() << " " << nodeId << " -->|"
  221. << (isTrue ? "true" : "false") << "| "
  222. << succId << "\n";
  223. } else {
  224. errs() << " " << nodeId << " --> " << succId << "\n";
  225. }
  226. } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator())) {
  227. // 处理switch指令
  228. if (Succ == SI->getDefaultDest()) {
  229. LOG_TRACE("dumpControlFlowGraph", "Adding switch default edge");
  230. errs() << " " << nodeId << " -->|default| " << succId << "\n";
  231. } else {
  232. for (auto Case : SI->cases()) {
  233. if (Case.getCaseSuccessor() == Succ) {
  234. std::string caseStr;
  235. raw_string_ostream caseOS(caseStr);
  236. Case.getCaseValue()->print(caseOS);
  237. LOG_TRACE("dumpControlFlowGraph", "Adding switch case edge: {0}",
  238. caseStr);
  239. errs() << " " << nodeId << " -->|case "
  240. << caseStr << "| " << succId << "\n";
  241. }
  242. }
  243. }
  244. } else {
  245. errs() << " " << nodeId << " --> " << succId << "\n";
  246. }
  247. }
  248. }
  249. errs() << " classDef critical fill:#f96,stroke:#333,stroke-width:4px\n";
  250. errs() << "```\n";
  251. LOG_INFO("dumpControlFlowGraph", "Completed graph generation");
  252. }
  253. void analyzeFunction(Function &F)
  254. {
  255. for (BasicBlock &BB : F)
  256. {
  257. for (Instruction &I : BB)
  258. {
  259. if (CallInst *CI = dyn_cast<CallInst>(&I))
  260. {
  261. // 分析函数调用
  262. Function *CalledF = CI->getCalledFunction();
  263. if (CalledF && CalledF->isDeclaration())
  264. {
  265. errs() << "External function call: " << CalledF->getName() << "\n";
  266. // 这里可以添加区分变量
  267. }
  268. }
  269. // 可以添加更多的分析,如全局变量访问等
  270. }
  271. }
  272. }
  273. void selectFusionPoints(Function &F) {
  274. errs() << "\n=== Starting Fusion Points Selection for function: "
  275. << F.getName() << " ===\n";
  276. fusionPoints.clear();
  277. criticalPoints.clear();
  278. std::set<BasicBlock*> visited;
  279. // 首先收集所有可能的关键点
  280. for (BasicBlock &BB : F) {
  281. if (checkCriticalPoint(&BB)) {
  282. criticalPoints.insert(&BB);
  283. }
  284. }
  285. errs() << "\nIdentified " << criticalPoints.size() << " potential critical points\n";
  286. // 按照支配关系排序关键点
  287. std::vector<BasicBlock*> orderedPoints;
  288. BasicBlock *entryBlock = &F.getEntryBlock();
  289. std::vector<BasicBlock*> workList;
  290. workList.push_back(entryBlock);
  291. while (!workList.empty()) {
  292. BasicBlock *current = workList.back();
  293. workList.pop_back();
  294. if (visited.count(current))
  295. continue;
  296. visited.insert(current);
  297. if (criticalPoints.count(current)) {
  298. orderedPoints.push_back(current);
  299. errs() << "Adding ordered critical point: " << current->getName() << "\n";
  300. }
  301. for (BasicBlock *succ : successors(current)) {
  302. if (!visited.count(succ)) {
  303. workList.push_back(succ);
  304. }
  305. }
  306. }
  307. fusionPoints = orderedPoints;
  308. errs() << "\nSelected " << fusionPoints.size() << " fusion points in order:\n";
  309. for (BasicBlock *BB : fusionPoints) {
  310. errs() << " " << BB->getName() << "\n";
  311. }
  312. errs() << "\nValidating fusion points:\n";
  313. for (size_t i = 0; i < fusionPoints.size(); ++i) {
  314. BasicBlock *current = fusionPoints[i];
  315. errs() << "Fusion point " << i << ": " << current->getName() << "\n";
  316. errs() << "Instructions in this fusion point:\n";
  317. for (Instruction &I : *current) {
  318. errs() << " " << I << "\n";
  319. }
  320. if (i < fusionPoints.size() - 1) {
  321. BasicBlock *next = fusionPoints[i + 1];
  322. errs() << "Distance to next fusion point: "
  323. << std::distance(current->getIterator(), next->getIterator())
  324. << " blocks\n";
  325. }
  326. }
  327. errs() << "=== Finished Fusion Points Selection ===\n\n";
  328. }
  329. void sliceTargetCode(Function &F)
  330. {
  331. errs() << "Slicing target code\n";
  332. // 目标代码分片逻辑
  333. }
  334. void fuseCode(Function &F)
  335. {
  336. errs() << "Fusing code\n";
  337. // 代码融合逻辑
  338. }
  339. };
  340. }
// Out-of-class definition of the pass identifier declared in CodeFusionPass.
char CodeFusionPass::ID = 0;
// Registers the pass under the argument name "codefusion" with the
// description "Code Fusion Pass".
static RegisterPass<CodeFusionPass> X("codefusion", "Code Fusion Pass");