run_param_fusion.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 参数传递法融合 - 支持多参数传递和多组测试
  5. """
  6. import os
  7. import sys
  8. import json
  9. import re
  10. import argparse
  11. sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src'))
  12. from openai import OpenAI
  13. def get_llm_client():
  14. api_key = os.getenv("DASHSCOPE_API_KEY")
  15. if not api_key:
  16. raise ValueError("请设置 DASHSCOPE_API_KEY 环境变量")
  17. return OpenAI(api_key=api_key, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
  18. def get_original_functions(functions: list, call_chain: list) -> dict:
  19. result = {}
  20. for func_name in call_chain:
  21. for func in functions:
  22. code = func.get('func', '')
  23. if func_name in code:
  24. result[func_name] = code
  25. break
  26. return result
  27. def create_prompt(target_code: str, original_funcs: dict, call_chain: list) -> str:
  28. funcs_text = ""
  29. for name in call_chain:
  30. if name in original_funcs:
  31. funcs_text += f"=== {name} ===\n{original_funcs[name]}\n\n"
  32. n = len(call_chain)
  33. return f"""将目标代码通过参数传递方式融合到调用链函数中。
  34. 目标代码:
  35. {target_code}
  36. 调用链 ({n} 层): {' -> '.join(call_chain)}
  37. 原始函数:
  38. {funcs_text}
  39. 融合规则(参数传递法):
  40. 1. 分析目标代码中的所有变量和操作
  41. 2. 将变量初始化、计算、使用分散到调用链的不同层级
  42. 3. 通过添加函数参数(指针)在层级间传递变量
  43. 4. 每个函数可以传递多个参数
  44. 具体要求:
  45. - 第1层({call_chain[0]}):定义初始变量,通过指针传递给下一层
  46. - 中间层:接收上层参数,执行计算,传递结果给下一层
  47. - 最后层({call_chain[-1]}):接收参数,执行最终操作(如printf)
  48. 输出要求:
  49. - 每个函数输出完整代码
  50. - 不要添加任何注释
  51. - 保持原函数逻辑完整
  52. 返回格式:
  53. {{
  54. {', '.join([f'"{name}": "完整函数代码"' for name in call_chain])}
  55. }}"""
  56. def remove_comments(code: str) -> str:
  57. code = re.sub(r'//.*?$', '', code, flags=re.MULTILINE)
  58. code = re.sub(r'/\*[\s\S]*?\*/', '', code)
  59. code = re.sub(r'\n{3,}', '\n\n', code)
  60. return code.strip()
  61. def parse_response(response: str) -> dict:
  62. def try_parse(text):
  63. try:
  64. return json.loads(text)
  65. except json.JSONDecodeError:
  66. pass
  67. return None
  68. match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response)
  69. if match:
  70. result = try_parse(match.group(1))
  71. if result:
  72. return result
  73. result = try_parse(response)
  74. if result:
  75. return result
  76. match = re.search(r'\{[\s\S]*\}', response)
  77. if match:
  78. result = try_parse(match.group(0))
  79. if result:
  80. return result
  81. try:
  82. result = {}
  83. func_pattern = r'"(\w+)":\s*"((?:[^"\\]|\\.)*)(?:"|$)'
  84. for match in re.finditer(func_pattern, response, re.DOTALL):
  85. name = match.group(1)
  86. code = match.group(2)
  87. code = code.replace('\\n', '\n').replace('\\t', '\t').replace('\\"', '"')
  88. result[name] = code
  89. if result:
  90. return result
  91. except:
  92. pass
  93. return None
  94. def process_group(client, group: dict, target_code: str, group_idx: int) -> dict:
  95. """处理单个调用链组"""
  96. functions = group['functions']
  97. call_chain = group['longest_call_chain']
  98. original_funcs = get_original_functions(functions, call_chain)
  99. if len(original_funcs) < len(call_chain):
  100. return {"success": False, "error": "无法提取所有函数", "call_chain": call_chain}
  101. prompt = create_prompt(target_code, original_funcs, call_chain)
  102. try:
  103. completion = client.chat.completions.create(
  104. model="qwen-plus",
  105. messages=[
  106. {"role": "system", "content": "你是代码融合专家。只返回JSON,不要添加任何注释到代码中。"},
  107. {"role": "user", "content": prompt}
  108. ],
  109. temperature=0.2,
  110. )
  111. response = completion.choices[0].message.content
  112. result = parse_response(response)
  113. if not result:
  114. return {"success": False, "error": "JSON解析失败", "call_chain": call_chain}
  115. for name in result:
  116. result[name] = remove_comments(result[name])
  117. return {
  118. "success": True,
  119. "group_idx": group_idx,
  120. "call_chain": call_chain,
  121. "fused_functions": result
  122. }
  123. except Exception as e:
  124. return {"success": False, "error": str(e), "call_chain": call_chain}
  125. def generate_code_file(result: dict) -> str:
  126. """生成代码文件内容"""
  127. call_chain = result['call_chain']
  128. fused_functions = result['fused_functions']
  129. lines = ["#include <stdio.h>", "#include <stdlib.h>", "#include <string.h>", ""]
  130. for name in reversed(call_chain):
  131. if name in fused_functions:
  132. lines.append(fused_functions[name])
  133. lines.append("")
  134. return '\n'.join(lines)
  135. def main():
  136. parser = argparse.ArgumentParser(description='参数传递法融合')
  137. parser.add_argument('--target', '-t', type=str, default=None, help='目标代码')
  138. parser.add_argument('--groups', '-g', type=int, default=1, help='测试组数(默认1)')
  139. parser.add_argument('--multi', '-m', action='store_true', help='使用多参数测试用例')
  140. args = parser.parse_args()
  141. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  142. input_path = os.path.join(project_root, "output/primevul_valid_grouped_depth_4.json")
  143. output_dir = os.path.join(project_root, "output")
  144. code_dir = os.path.join(output_dir, "fused_code")
  145. if args.target:
  146. target_code = args.target
  147. elif args.multi:
  148. target_code = 'int a = 10; int b = 20; int c = a + b; printf("sum=%d, a=%d, b=%d", c, a, b);'
  149. else:
  150. target_code = 'int secret = 42; int key = secret ^ 0xABCD; printf("key=%d", key);'
  151. print("=" * 60)
  152. print(f"参数传递法融合 - 测试 {args.groups} 组")
  153. print("=" * 60)
  154. print(f"目标代码: {target_code}\n")
  155. with open(input_path, 'r', encoding='utf-8') as f:
  156. data = json.load(f)
  157. groups = data['groups']
  158. num_groups = min(args.groups, len(groups))
  159. print(f"可用调用链组: {len(groups)}")
  160. print(f"将测试: {num_groups} 组\n")
  161. client = get_llm_client()
  162. results = []
  163. success_count = 0
  164. for i in range(num_groups):
  165. group = groups[i]
  166. call_chain = group['longest_call_chain']
  167. print(f"[{i+1}/{num_groups}] 处理: {' -> '.join(call_chain[:2])}...")
  168. result = process_group(client, group, target_code, i)
  169. results.append(result)
  170. if result['success']:
  171. success_count += 1
  172. print(f" ✓ 成功")
  173. # 保存单独的代码文件
  174. chain_name = "_".join(call_chain[:2])
  175. code_file = os.path.join(code_dir, f"param_group_{i}_{chain_name}.c")
  176. code_content = generate_code_file(result)
  177. os.makedirs(code_dir, exist_ok=True)
  178. with open(code_file, 'w', encoding='utf-8') as f:
  179. f.write(code_content)
  180. else:
  181. print(f" ✗ 失败: {result['error']}")
  182. # 保存汇总 JSON
  183. output_json = os.path.join(output_dir, "fusion_param_results.json")
  184. output_data = {
  185. "metadata": {
  186. "target_code": target_code,
  187. "passing_method": "parameter",
  188. "total_groups": num_groups,
  189. "success_count": success_count,
  190. "failed_count": num_groups - success_count
  191. },
  192. "results": results
  193. }
  194. with open(output_json, 'w', encoding='utf-8') as f:
  195. json.dump(output_data, f, ensure_ascii=False, indent=2)
  196. print("\n" + "=" * 60)
  197. print("测试结果汇总")
  198. print("=" * 60)
  199. print(f"成功: {success_count}/{num_groups}")
  200. print(f"失败: {num_groups - success_count}/{num_groups}")
  201. print(f"JSON: {output_json}")
  202. print(f"代码目录: {code_dir}")
  203. # 显示成功的结果
  204. if success_count > 0:
  205. print("\n成功的调用链:")
  206. for r in results:
  207. if r['success']:
  208. print(f" - Group {r['group_idx']}: {' -> '.join(r['call_chain'])}")
  209. if __name__ == '__main__':
  210. main()