// CppFusion.cpp (4.6 KB)
  1. #include "CppFusion.h"
  2. #include "../Util/LogSystem.h"
  3. #include "llvm/IR/IRBuilder.h"
  4. #include "llvm/Transforms/Utils/Cloning.h"
  5. #include "llvm/IR/CFG.h"
  6. #include "llvm/IR/Verifier.h"
  7. #include <random>
  8. #include <algorithm>
  9. using namespace llvm;
  10. namespace slicefusion {
  11. bool CppFusion::matchFunctionsForFusion(const std::set<std::string>& targetFunctions,
  12. const std::set<std::string>& bunkerFunctions) {
  13. auto& logger = logging::LogSystem::getInstance();
  14. logger.enableFunction("matchFunctionsForFusion");
  15. LOG_INFO("matchFunctionsForFusion", "Starting C++ function matching process");
  16. // 直接指定要融合的函数对
  17. const std::string SPECIFIED_TARGET = "_ZN8ProjectB10testPointsEi";
  18. const std::string SPECIFIED_BUNKER = "_ZN8ProjectA9expandKeyEPhS0_NS_7keySizeEm";
  19. LOG_INFO("matchFunctionsForFusion", "Using specified function pair:");
  20. LOG_INFO("matchFunctionsForFusion", " Target: {0}", SPECIFIED_TARGET);
  21. LOG_INFO("matchFunctionsForFusion", " Bunker: {0}", SPECIFIED_BUNKER);
  22. // 检查指定的函数是否存在于集合中
  23. if (targetFunctions.find(SPECIFIED_TARGET) == targetFunctions.end()) {
  24. LOG_ERROR("matchFunctionsForFusion", "Specified target function {0} not found in target functions", SPECIFIED_TARGET);
  25. return false;
  26. }
  27. if (bunkerFunctions.find(SPECIFIED_BUNKER) == bunkerFunctions.end()) {
  28. LOG_ERROR("matchFunctionsForFusion", "Specified bunker function {0} not found in bunker functions", SPECIFIED_BUNKER);
  29. return false;
  30. }
  31. // 检查函数的融合点是否足够
  32. const auto& targetNode = callGraph[SPECIFIED_TARGET];
  33. const auto& bunkerNode = callGraph[SPECIFIED_BUNKER];
  34. LOG_INFO("matchFunctionsForFusion", "Target function {0} has {1} slices", SPECIFIED_TARGET, targetNode.slices_num);
  35. LOG_INFO("matchFunctionsForFusion", "Bunker function {0} has {1} fusion points", SPECIFIED_BUNKER, bunkerNode.points_num);
  36. if (bunkerNode.points_num < targetNode.slices_num) {
  37. LOG_ERROR("matchFunctionsForFusion",
  38. "Insufficient fusion points in bunker function {0} ({1}) for target function {2} ({3})",
  39. SPECIFIED_BUNKER, bunkerNode.points_num,
  40. SPECIFIED_TARGET, targetNode.slices_num);
  41. return false;
  42. }
  43. // 建立匹配关系
  44. fusionPairs[SPECIFIED_TARGET] = SPECIFIED_BUNKER;
  45. LOG_INFO("matchFunctionsForFusion",
  46. "Successfully matched target {0} ({1} slices) with bunker {2} ({3} fusion points)",
  47. SPECIFIED_TARGET, targetNode.slices_num,
  48. SPECIFIED_BUNKER, bunkerNode.points_num);
  49. return true;
  50. }
  51. void CppFusion::performCodeFusion(Module &M) {
  52. auto& logger = logging::LogSystem::getInstance();
  53. logger.enableFunction("performCodeFusion");
  54. LOG_INFO("performCodeFusion", "Starting C++ code fusion");
  55. // 调用基类的performCodeFusion来融合除main之外的函数
  56. Fusion::performCodeFusion(M);
  57. // 处理main函数:将projectB_main重命名为main,并删除projectA_main
  58. const std::string CPP_TARGET_MAIN = "_Z13projectB_mainv";
  59. const std::string CPP_BUNKER_MAIN = "_Z13projectA_mainv";
  60. if (Function* targetMain = M.getFunction(CPP_TARGET_MAIN)) {
  61. LOG_INFO("performCodeFusion", "Renaming {0} to main", CPP_TARGET_MAIN);
  62. targetMain->setName("main");
  63. } else {
  64. LOG_ERROR("performCodeFusion", "Could not find target main function {0} to rename", CPP_TARGET_MAIN);
  65. }
  66. if (Function* bunkerMain = M.getFunction(CPP_BUNKER_MAIN)) {
  67. LOG_INFO("performCodeFusion", "Removing bunker main function {0}", CPP_BUNKER_MAIN);
  68. bunkerMain->eraseFromParent();
  69. } else {
  70. LOG_WARNING("performCodeFusion", "Could not find bunker main function {0} to remove", CPP_BUNKER_MAIN);
  71. }
  72. // // 最终的模块验证
  73. // LOG_INFO("performCodeFusion", "Performing final C++ module verification");
  74. // std::string errorStr;
  75. // llvm::raw_string_ostream errorStream(errorStr);
  76. // if (llvm::verifyModule(M, &errorStream)) {
  77. // LOG_ERROR("performCodeFusion", "C++ module verification failed: {0}", errorStr);
  78. // // 同时输出到标准错误流
  79. // llvm::errs() << "C++ module verification failed:\n" << errorStr << "\n";
  80. // } else {
  81. // LOG_INFO("performCodeFusion", "C++ module verification passed successfully");
  82. // }
  83. }
  84. } // namespace slicefusion