CallGraphPass.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. #include "llvm/Pass.h"
  2. #include "llvm/IR/Function.h"
  3. #include "llvm/IR/Module.h"
  4. #include "llvm/IR/BasicBlock.h"
  5. #include "llvm/IR/Instructions.h"
  6. #include "llvm/IR/LegacyPassManager.h"
  7. #include "llvm/Transforms/IPO/PassManagerBuilder.h"
  8. #include "llvm/Support/raw_ostream.h"
  9. #include "llvm/Analysis/CFG.h"
  10. #include "llvm/Analysis/CallGraph.h"
  11. #include "LogSystem.h"
  12. #include <vector>
  13. #include <map>
  14. #include <set>
  15. #include <queue>
  16. #include <algorithm>
  17. #include <string>
  18. using namespace llvm;
  19. namespace {
  // Structural fingerprint of one function's position in the call graph.
  // Two fingerprints are equal only when every field matches exactly.
  struct FunctionSignature {
    unsigned inDegree = 0;               // filled from the node's caller count
    unsigned outDegree = 0;              // filled from the node's callee count
    unsigned depth = 0;                  // the function's own recorded depth
    std::vector<unsigned> callerDepths;  // depths of its callers (stored sorted)
    std::vector<unsigned> calleeDepths;  // depths of its callees (stored sorted)

    // Field-by-field equality; bails out on the first mismatch.
    bool operator==(const FunctionSignature &other) const {
      if (inDegree != other.inDegree)
        return false;
      if (outDegree != other.outDegree)
        return false;
      if (depth != other.depth)
        return false;
      if (!(callerDepths == other.callerDepths))
        return false;
      return calleeDepths == other.calleeDepths;
    }
  };
  // Node of the function call graph: one entry per function, holding the
  // names of its direct callers and callees.
  struct CallNode {
    std::string name;                // function name (empty by default)
    bool isTarget = false;           // target-project flag supplied by the creator
    std::set<std::string> callers;   // names of functions that call this one
    std::set<std::string> callees;   // names of functions this one calls
    unsigned depth = 0;              // depth within the call tree

    // Unnamed, non-target node at depth 0.
    CallNode() = default;

    // Named node; `target` defaults to false, depth starts at 0.
    CallNode(std::string n, bool target = false) : name(n), isTarget(target) {}
  };
  46. struct CodeFusionPass : public ModulePass {
  47. public:
  48. static char ID;
  49. CodeFusionPass() : ModulePass(ID) {}
  50. bool runOnModule(Module &M) override {
  51. auto& logger = logging::LogSystem::getInstance();
  52. logger.setGlobalLevel(logging::LogLevel::DEBUG);
  53. LOG_INFO("runOnModule", "Starting analysis for module: {0}", M.getName().str());
  54. // 识别所有目标函数
  55. for (Function &F : M) {
  56. if (!F.isDeclaration() && isTargetCode(F)) {
  57. targetFunctions.insert(F.getName().str());
  58. }
  59. }
  60. // 构建完整的调用图
  61. buildCallGraph(M);
  62. // 计算每个函数的调用深度
  63. calculateCallDepths();
  64. // 生成调用图的可视化
  65. generateCallGraphs();
  66. // 分析相似结构
  67. findSimilarStructures();
  68. return false;
  69. }
  70. private:
  71. std::map<std::string, CallNode> callGraph;
  72. std::set<std::string> targetFunctions;
  73. std::map<std::string, unsigned> maxCallDepths;
  74. void buildCallGraph(Module &M) {
  75. LOG_INFO("buildCallGraph", "Building complete call graph");
  76. // 初始化所有函数节点
  77. for (Function &F : M) {
  78. if (!F.isDeclaration()) {
  79. std::string fname = F.getName().str();
  80. bool isTarget = targetFunctions.find(fname) != targetFunctions.end();
  81. callGraph.insert({fname, CallNode(fname, isTarget)});
  82. }
  83. }
  84. // 分析函数调用关系
  85. for (Function &F : M) {
  86. if (!F.isDeclaration()) {
  87. std::string callerName = F.getName().str();
  88. for (BasicBlock &BB : F) {
  89. for (Instruction &I : BB) {
  90. if (CallInst *CI = dyn_cast<CallInst>(&I)) {
  91. Function *CalledF = CI->getCalledFunction();
  92. if (CalledF && !CalledF->isDeclaration()) {
  93. std::string calleeName = CalledF->getName().str();
  94. // 更新调用关系
  95. callGraph[callerName].callees.insert(calleeName);
  96. callGraph[calleeName].callers.insert(callerName);
  97. LOG_DEBUG("buildCallGraph",
  98. "Found call: {0} -> {1}", callerName, calleeName);
  99. }
  100. }
  101. }
  102. }
  103. }
  104. }
  105. }
  106. void calculateCallDepths() {
  107. LOG_INFO("calculateCallDepths", "Calculating call depths");
  108. for (const std::string &targetFunc : targetFunctions) {
  109. std::map<std::string, unsigned> depths;
  110. std::queue<std::pair<std::string, unsigned>> queue;
  111. std::set<std::string> visited;
  112. queue.push(std::make_pair(targetFunc, 0));
  113. visited.insert(targetFunc);
  114. while (!queue.empty()) {
  115. std::string currentFunc = queue.front().first;
  116. unsigned depth = queue.front().second;
  117. queue.pop();
  118. depths[currentFunc] = depth;
  119. for (const std::string &caller : callGraph[currentFunc].callers) {
  120. if (visited.find(caller) == visited.end()) {
  121. queue.push(std::make_pair(caller, depth + 1));
  122. visited.insert(caller);
  123. }
  124. }
  125. for (const std::string &callee : callGraph[currentFunc].callees) {
  126. if (visited.find(callee) == visited.end()) {
  127. queue.push(std::make_pair(callee, depth + 1));
  128. visited.insert(callee);
  129. }
  130. }
  131. }
  132. // 更新最大深度
  133. for (const auto &pair : depths) {
  134. const std::string &func = pair.first;
  135. unsigned depth = pair.second;
  136. maxCallDepths[func] = std::max(maxCallDepths[func], depth);
  137. }
  138. }
  139. }
  140. void generateCallGraphs() {
  141. LOG_INFO("generateCallGraphs", "Generating complete call graphs");
  142. // 分别生成目标项目和掩体项目的调用图
  143. generateProjectCallGraph("Target Project Call Graph", true);
  144. generateProjectCallGraph("Cover Project Call Graph", false);
  145. }
  146. void generateProjectCallGraph(const std::string &title, bool isTarget) {
  147. errs() << "```mermaid\n";
  148. errs() << "graph TD\n";
  149. errs() << " %% " << title << "\n";
  150. // 添加节点
  151. for (const auto &pair : callGraph) {
  152. const std::string &name = pair.first;
  153. const CallNode &node = pair.second;
  154. if (node.isTarget == isTarget) {
  155. std::string nodeId = sanitizeNodeId(name);
  156. std::string depth = std::to_string(maxCallDepths[name]);
  157. std::string style = node.isTarget ? ":::target" : "";
  158. errs() << " " << nodeId << "[\"" << name
  159. << "\\nDepth: " << depth << "\"]" << style << "\n";
  160. }
  161. }
  162. // 添加边
  163. for (const auto &pair : callGraph) {
  164. const std::string &name = pair.first;
  165. const CallNode &node = pair.second;
  166. if (node.isTarget == isTarget) {
  167. std::string callerId = sanitizeNodeId(name);
  168. for (const auto &callee : node.callees) {
  169. if (callGraph.at(callee).isTarget == isTarget) {
  170. std::string calleeId = sanitizeNodeId(callee);
  171. errs() << " " << callerId << " --> " << calleeId << "\n";
  172. }
  173. }
  174. }
  175. }
  176. // 添加样式定义
  177. errs() << " classDef target fill:#f96,stroke:#333,stroke-width:4px\n";
  178. errs() << "```\n\n";
  179. }
  180. void findSimilarStructures() {
  181. LOG_INFO("findSimilarStructures", "Analyzing call graph similarities");
  182. // 为每个函数计算特征签名
  183. std::map<std::string, FunctionSignature> signatures;
  184. for (const auto &pair : callGraph) {
  185. const std::string &name = pair.first;
  186. const CallNode &node = pair.second;
  187. FunctionSignature sig;
  188. sig.inDegree = node.callers.size();
  189. sig.outDegree = node.callees.size();
  190. sig.depth = maxCallDepths[name];
  191. // 收集调用者深度
  192. for (const auto &caller : node.callers) {
  193. sig.callerDepths.push_back(maxCallDepths[caller]);
  194. }
  195. std::sort(sig.callerDepths.begin(), sig.callerDepths.end());
  196. // 收集被调用者深度
  197. for (const auto &callee : node.callees) {
  198. sig.calleeDepths.push_back(maxCallDepths[callee]);
  199. }
  200. std::sort(sig.calleeDepths.begin(), sig.calleeDepths.end());
  201. signatures[name] = sig;
  202. }
  203. // 比较目标函数和掩体函数的相似度
  204. for (const auto &targetFunc : targetFunctions) {
  205. const auto &targetSig = signatures[targetFunc];
  206. std::vector<std::pair<std::string, double>> similarities;
  207. for (const auto &pair : callGraph) {
  208. const std::string &name = pair.first;
  209. const CallNode &node = pair.second;
  210. if (!node.isTarget) {
  211. const auto &coverSig = signatures[name];
  212. // 计算相似度得分
  213. double similarity = calculateSignatureSimilarity(targetSig, coverSig);
  214. if (similarity > 0.8) { // 相似度阈值
  215. similarities.emplace_back(name, similarity);
  216. }
  217. }
  218. }
  219. // 输出相似函数
  220. if (!similarities.empty()) {
  221. LOG_INFO("findSimilarStructures",
  222. "Similar functions for {0}:", targetFunc);
  223. for (const auto &pair : similarities) {
  224. const std::string &name = pair.first;
  225. double similarity = pair.second;
  226. LOG_INFO("findSimilarStructures",
  227. " {0} (similarity: {1:.2f})", name, similarity);
  228. }
  229. }
  230. }
  231. }
  232. double calculateSignatureSimilarity(
  233. const FunctionSignature &sig1,
  234. const FunctionSignature &sig2) {
  235. double score = 0.0;
  236. unsigned totalFactors = 0;
  237. // 比较入度和出度
  238. if (sig1.inDegree > 0 || sig2.inDegree > 0) {
  239. score += 1.0 - std::abs(int(sig1.inDegree) - int(sig2.inDegree)) /
  240. double(std::max(sig1.inDegree, sig2.inDegree));
  241. totalFactors++;
  242. }
  243. if (sig1.outDegree > 0 || sig2.outDegree > 0) {
  244. score += 1.0 - std::abs(int(sig1.outDegree) - int(sig2.outDegree)) /
  245. double(std::max(sig1.outDegree, sig2.outDegree));
  246. totalFactors++;
  247. }
  248. // 比较深度
  249. if (sig1.depth > 0 || sig2.depth > 0) {
  250. score += 1.0 - std::abs(int(sig1.depth) - int(sig2.depth)) /
  251. double(std::max(sig1.depth, sig2.depth));
  252. totalFactors++;
  253. }
  254. // 比较调用者深度分布
  255. if (!sig1.callerDepths.empty() && !sig2.callerDepths.empty()) {
  256. score += compareDepthVectors(sig1.callerDepths, sig2.callerDepths);
  257. totalFactors++;
  258. }
  259. // 比较被调用者深度分布
  260. if (!sig1.calleeDepths.empty() && !sig2.calleeDepths.empty()) {
  261. score += compareDepthVectors(sig1.calleeDepths, sig2.calleeDepths);
  262. totalFactors++;
  263. }
  264. return totalFactors > 0 ? score / totalFactors : 0.0;
  265. }
  // Scores the similarity of two depth lists in the range [0, 1].
  //
  // The result is the average of:
  //   - size similarity:  min(len) / max(len)
  //   - value similarity: sum over the common prefix of
  //                       1 - |a - b| / max(a, b), divided by max(len)
  //
  // Two depths that are both 0 count as a perfect match; dividing by their
  // maximum would otherwise produce NaN and poison the whole score.
  // Two empty lists are treated as identical and score 1.0.
  double compareDepthVectors(
      const std::vector<unsigned> &v1,
      const std::vector<unsigned> &v2) {
    const size_t maxSize = std::max(v1.size(), v2.size());
    const size_t minSize = std::min(v1.size(), v2.size());

    // Nothing to compare on either side: avoid 0/0 below.
    if (maxSize == 0)
      return 1.0;

    // First compare how similar the list lengths are.
    const double sizeSimilarity = double(minSize) / double(maxSize);

    // Then compare the values position by position.
    double valueSimilarity = 0.0;
    for (size_t i = 0; i < minSize; ++i) {
      const unsigned hi = std::max(v1[i], v2[i]);
      const unsigned lo = std::min(v1[i], v2[i]);
      // hi == 0 means both entries are 0, i.e. identical.
      valueSimilarity += (hi == 0) ? 1.0 : 1.0 - double(hi - lo) / double(hi);
    }
    // Positions missing from the shorter list contribute 0.
    valueSimilarity /= double(maxSize);

    return (sizeSimilarity + valueSimilarity) / 2.0;
  }
  // Turns a function name into an identifier that is safe to use as a
  // Mermaid node id.
  //
  // Every character other than an ASCII letter, digit or underscore is
  // replaced by '_'. This covers the '.', ' ' and '-' cases and also symbols
  // such as '$', '@', '?', ':' or '<' that occur in mangled names.
  // Distinct names can map to the same id (e.g. "a.b" and "a_b").
  std::string sanitizeNodeId(const std::string &name) {
    std::string id = name;
    for (char &c : id) {
      const bool isDigit = (c >= '0' && c <= '9');
      const bool isLower = (c >= 'a' && c <= 'z');
      const bool isUpper = (c >= 'A' && c <= 'Z');
      if (!(isDigit || isLower || isUpper || c == '_')) {
        c = '_';
      }
    }
    return id;
  }
  289. bool isTargetCode(Function &F) {
  290. if (MDNode *MD = F.getMetadata("project_source")) {
  291. if (MDString *ProjectStr = dyn_cast<MDString>(MD->getOperand(0))) {
  292. std::string projectName = ProjectStr->getString().str();
  293. return (projectName == "Target");
  294. }
  295. }
  296. return false;
  297. }
  298. };
  299. }
  // Definition of the static pass identifier declared inside CodeFusionPass;
  // it is handed to the ModulePass base-class constructor.
  char CodeFusionPass::ID = 0;
  // Registers the pass under the command-line name "codefusion" with the
  // description "Code Fusion Pass".
  static RegisterPass<CodeFusionPass> X("codefusion", "Code Fusion Pass");