run_param_fusion.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 运行参数传递法融合,并生成不带注释标记的代码文件
  5. """
  6. import os
  7. import sys
  8. import json
  9. import re
  10. # 添加 src 目录到路径
  11. sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src'))
  12. from main import CodeFusionProcessor
  13. def remove_fusion_markers(code: str) -> str:
  14. """移除融合标记注释"""
  15. # 移除 /* === Fused Code Start === */ 和 /* === Fused Code End === */ 及其包裹的内容保持
  16. patterns = [
  17. r'/\*\s*===\s*Fused Code Start\s*===\s*\*/\s*\n?',
  18. r'/\*\s*===\s*Fused Code End\s*===\s*\*/\s*\n?',
  19. r'/\*\s*中间层函数.*?\*/\s*\n?',
  20. ]
  21. result = code
  22. for pattern in patterns:
  23. result = re.sub(pattern, '', result)
  24. # 清理多余的空行
  25. result = re.sub(r'\n{3,}', '\n\n', result)
  26. return result
  27. def generate_clean_code_file(result, target_code: str) -> str:
  28. """生成干净的代码文件(不带标记注释)"""
  29. lines = []
  30. # 文件头
  31. lines.append("/*")
  32. lines.append(" * 参数传递法融合代码")
  33. lines.append(f" * 调用链: {' -> '.join(result['call_chain'])}")
  34. lines.append(f" * 调用深度: {result['call_depth']}")
  35. lines.append(" *")
  36. lines.append(" * 原始目标代码:")
  37. for line in target_code.strip().split('\n'):
  38. lines.append(f" * {line.strip()}")
  39. lines.append(" */")
  40. lines.append("")
  41. # 头文件
  42. lines.append("#include <stdio.h>")
  43. lines.append("#include <stdlib.h>")
  44. lines.append("#include <string.h>")
  45. lines.append("")
  46. # 结构体定义(全局状态)
  47. lines.append("/* 共享状态结构体 */")
  48. lines.append("typedef struct {")
  49. lines.append(" int secret;")
  50. lines.append(" int key;")
  51. lines.append("} FusionState;")
  52. lines.append("")
  53. lines.append("/* 全局状态指针 */")
  54. lines.append("static FusionState* fusion_state = NULL;")
  55. lines.append("")
  56. # 函数定义(从最内层到最外层)
  57. lines.append("/* ========== 函数定义 ========== */")
  58. lines.append("")
  59. fused_code = result.get('fused_code', {})
  60. call_chain = result.get('call_chain', [])
  61. for func_name in reversed(call_chain):
  62. if func_name in fused_code:
  63. lines.append(f"/* {func_name} */")
  64. clean_code = remove_fusion_markers(fused_code[func_name])
  65. lines.append(clean_code)
  66. lines.append("")
  67. return '\n'.join(lines)
  68. def main():
  69. # 配置
  70. input_path = "output/primevul_valid_grouped_depth_4.json"
  71. output_json = "output/fusion_param_clean.json"
  72. output_code = "output/fused_code/param_fusion_clean.c"
  73. target_code = "int secret = 42; int key = secret ^ 0xABCD; printf(\"key=%d\", key);"
  74. # 检查输入文件
  75. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  76. input_full_path = os.path.join(project_root, input_path)
  77. if not os.path.exists(input_full_path):
  78. print(f"Error: Input file not found: {input_full_path}")
  79. sys.exit(1)
  80. print("=" * 60)
  81. print("参数传递法融合(无标记注释)")
  82. print("=" * 60)
  83. print(f"\n目标代码: {target_code}")
  84. print(f"输入文件: {input_path}")
  85. print(f"输出JSON: {output_json}")
  86. print(f"输出代码: {output_code}")
  87. print("")
  88. # 创建处理器
  89. processor = CodeFusionProcessor(
  90. enable_verification=False, # 禁用验证以加快速度
  91. enable_syntax_check=False,
  92. enable_semantic_check=False
  93. )
  94. # 加载数据
  95. data = processor.load_data(input_full_path)
  96. groups = data.get('groups', [])
  97. print(f"共有 {len(groups)} 个调用链组")
  98. print(f"选择第一个组进行融合...")
  99. print("")
  100. # 处理第一个组
  101. group = groups[0]
  102. result = processor.process_group(
  103. group,
  104. target_code,
  105. group_index=0,
  106. passing_method="parameter"
  107. )
  108. if not result.success:
  109. print(f"融合失败: {result.error_message}")
  110. sys.exit(1)
  111. print(f"融合成功!")
  112. print(f"调用链: {' -> '.join(result.call_chain)}")
  113. print(f"融合点数: {result.total_fusion_points}")
  114. print("")
  115. # 保存 JSON 结果
  116. output_data = {
  117. "metadata": {
  118. "target_code": target_code,
  119. "passing_method": "parameter",
  120. "total_processed": 1,
  121. "successful": 1
  122. },
  123. "results": [{
  124. "group_index": result.group_index,
  125. "call_chain": result.call_chain,
  126. "call_depth": result.call_depth,
  127. "functions_count": result.functions_count,
  128. "total_fusion_points": result.total_fusion_points,
  129. "success": result.success,
  130. "fused_code": result.fused_code
  131. }]
  132. }
  133. output_json_path = os.path.join(project_root, output_json)
  134. os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
  135. with open(output_json_path, 'w', encoding='utf-8') as f:
  136. json.dump(output_data, f, ensure_ascii=False, indent=2)
  137. print(f"JSON 结果已保存到: {output_json}")
  138. # 生成干净的代码文件
  139. clean_code = generate_clean_code_file(output_data['results'][0], target_code)
  140. output_code_path = os.path.join(project_root, output_code)
  141. os.makedirs(os.path.dirname(output_code_path), exist_ok=True)
  142. with open(output_code_path, 'w', encoding='utf-8') as f:
  143. f.write(clean_code)
  144. print(f"代码文件已保存到: {output_code}")
  145. print("")
  146. print("=" * 60)
  147. print("融合后的代码预览:")
  148. print("=" * 60)
  149. print(clean_code)
  150. if __name__ == '__main__':
  151. main()