extract_call_relations.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 分析代码函数的 caller 和 callee 关系,将有调用关系的函数合并为组。
  5. """
  6. import json
  7. import re
  8. import os
  9. import argparse
  10. from collections import defaultdict
  11. from typing import Dict, List, Set, Tuple, Optional
  12. # 常见的 C/C++ 库函数和系统调用,这些不应该作为连接不同函数组的依据
  13. COMMON_LIB_FUNCTIONS = {
  14. # 内存管理
  15. 'malloc', 'calloc', 'realloc', 'free', 'memcpy', 'memmove', 'memset',
  16. 'memcmp', 'memchr', 'alloca', 'aligned_alloc',
  17. # 字符串处理
  18. 'strlen', 'strcpy', 'strncpy', 'strcat', 'strncat', 'strcmp', 'strncmp',
  19. 'strchr', 'strrchr', 'strstr', 'strtok', 'strdup', 'strndup', 'strspn',
  20. 'strcspn', 'strpbrk', 'strerror', 'sprintf', 'snprintf', 'vsprintf',
  21. 'vsnprintf', 'sscanf',
  22. # 输入输出
  23. 'printf', 'fprintf', 'vprintf', 'vfprintf', 'puts', 'fputs', 'putc',
  24. 'fputc', 'putchar', 'gets', 'fgets', 'getc', 'fgetc', 'getchar',
  25. 'scanf', 'fscanf', 'fopen', 'fclose', 'fread', 'fwrite', 'fseek',
  26. 'ftell', 'rewind', 'fflush', 'feof', 'ferror', 'clearerr', 'perror',
  27. # 类型转换
  28. 'atoi', 'atol', 'atoll', 'atof', 'strtol', 'strtoll', 'strtoul',
  29. 'strtoull', 'strtof', 'strtod', 'strtold',
  30. # 数学函数
  31. 'abs', 'labs', 'llabs', 'fabs', 'floor', 'ceil', 'round', 'sqrt',
  32. 'pow', 'exp', 'log', 'log10', 'sin', 'cos', 'tan', 'asin', 'acos',
  33. 'atan', 'atan2', 'min', 'max',
  34. # 时间函数
  35. 'time', 'clock', 'difftime', 'mktime', 'strftime', 'localtime',
  36. 'gmtime', 'asctime', 'ctime', 'gettimeofday', 'sleep', 'usleep',
  37. 'nanosleep',
  38. # 进程和信号
  39. 'exit', 'abort', '_exit', 'atexit', 'system', 'getenv', 'setenv',
  40. 'fork', 'exec', 'execl', 'execv', 'execle', 'execve', 'execlp',
  41. 'execvp', 'wait', 'waitpid', 'kill', 'signal', 'raise',
  42. # 断言和错误处理
  43. 'assert', 'errno', 'setjmp', 'longjmp',
  44. # POSIX 和系统调用
  45. 'open', 'close', 'read', 'write', 'lseek', 'stat', 'fstat', 'lstat',
  46. 'access', 'chmod', 'chown', 'link', 'unlink', 'rename', 'mkdir',
  47. 'rmdir', 'opendir', 'closedir', 'readdir', 'getcwd', 'chdir',
  48. 'pipe', 'dup', 'dup2', 'fcntl', 'ioctl', 'select', 'poll', 'mmap',
  49. 'munmap', 'mprotect', 'socket', 'bind', 'listen', 'accept', 'connect',
  50. 'send', 'recv', 'sendto', 'recvfrom', 'shutdown', 'setsockopt',
  51. 'getsockopt', 'pthread_create', 'pthread_join', 'pthread_exit',
  52. 'pthread_mutex_lock', 'pthread_mutex_unlock', 'pthread_cond_wait',
  53. 'pthread_cond_signal',
  54. # C++ 常用
  55. 'std', 'make_shared', 'make_unique', 'move', 'forward', 'swap',
  56. 'begin', 'end', 'size', 'empty', 'push_back', 'pop_back', 'front',
  57. 'back', 'insert', 'erase', 'clear', 'find', 'count', 'sort',
  58. 'unique', 'reverse', 'copy', 'fill', 'transform', 'accumulate',
  59. # 类型检查
  60. 'static_assert', 'ASSERT', 'DCHECK', 'CHECK', 'EXPECT', 'VERIFY',
  61. # 日志
  62. 'LOG', 'DLOG', 'VLOG', 'ERR', 'WARN', 'INFO', 'DEBUG', 'TRACE',
  63. # 其他常见宏/函数
  64. 'DISALLOW_COPY_AND_ASSIGN', 'NOTREACHED', 'UNIMPLEMENTED',
  65. 'offsetof', 'container_of', 'likely', 'unlikely', 'BUG', 'BUG_ON',
  66. 'WARN_ON', 'IS_ERR', 'PTR_ERR', 'ERR_PTR', 'ERR_CAST',
  67. # 测试相关
  68. 'TEST', 'TEST_F', 'TEST_P', 'EXPECT_TRUE', 'EXPECT_FALSE',
  69. 'EXPECT_EQ', 'EXPECT_NE', 'EXPECT_LT', 'EXPECT_LE', 'EXPECT_GT',
  70. 'EXPECT_GE', 'ASSERT_TRUE', 'ASSERT_FALSE', 'ASSERT_EQ', 'ASSERT_NE',
  71. 'MOCK_METHOD', 'INSTANTIATE_TEST_SUITE_P',
  72. }
  73. def extract_function_name(func_code: str) -> Optional[str]:
  74. """
  75. 从函数代码中提取函数名。
  76. 支持 C/C++ 风格的函数定义。
  77. """
  78. # 移除注释
  79. code = re.sub(r'//.*?\n', '\n', func_code)
  80. code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
  81. # 匹配函数定义的模式
  82. # 格式: [返回类型] [类名::]函数名(参数列表)
  83. patterns = [
  84. # C++ 成员函数: ReturnType ClassName::FunctionName(...)
  85. r'(?:[\w\s\*&<>,]+?)\s+(\w+::~?\w+)\s*\([^)]*\)\s*(?:const)?\s*(?:override)?\s*(?:final)?\s*(?:\{|:)',
  86. # 构造函数/析构函数: ClassName::ClassName(...) 或 ClassName::~ClassName(...)
  87. r'^[\s]*(\w+::~?\w+)\s*\([^)]*\)\s*(?:\{|:)',
  88. # 普通 C 函数: ReturnType FunctionName(...)
  89. r'(?:[\w\s\*&<>,]+?)\s+(\w+)\s*\([^)]*\)\s*\{',
  90. # 简单模式
  91. r'^\s*(?:static\s+)?(?:inline\s+)?(?:virtual\s+)?(?:[\w\*&<>,\s]+)\s+(\w+)\s*\(',
  92. ]
  93. for pattern in patterns:
  94. match = re.search(pattern, code, re.MULTILINE)
  95. if match:
  96. func_name = match.group(1)
  97. # 如果是 ClassName::FunctionName 格式,只取函数名
  98. if '::' in func_name:
  99. func_name = func_name.split('::')[-1]
  100. return func_name
  101. return None
  102. def extract_function_calls(
  103. func_code: str,
  104. self_name: Optional[str] = None,
  105. exclude_common_libs: bool = True
  106. ) -> Set[str]:
  107. """
  108. 从函数代码中提取所有被调用的函数名(callees)。
  109. Args:
  110. func_code: 函数代码
  111. self_name: 当前函数名(会被排除)
  112. exclude_common_libs: 是否排除常见库函数
  113. """
  114. # 移除注释和字符串
  115. code = re.sub(r'//.*?\n', '\n', func_code)
  116. code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
  117. code = re.sub(r'"(?:[^"\\]|\\.)*"', '""', code) # 移除字符串
  118. code = re.sub(r"'(?:[^'\\]|\\.)*'", "''", code) # 移除字符
  119. # 提取函数调用: 函数名(
  120. # 排除关键字和常见的非函数调用
  121. keywords = {
  122. 'if', 'else', 'while', 'for', 'switch', 'case', 'return', 'break',
  123. 'continue', 'sizeof', 'typeof', 'alignof', 'decltype', 'static_cast',
  124. 'dynamic_cast', 'reinterpret_cast', 'const_cast', 'new', 'delete',
  125. 'throw', 'catch', 'try', 'namespace', 'class', 'struct', 'enum',
  126. 'union', 'typedef', 'using', 'template', 'typename', 'public',
  127. 'private', 'protected', 'virtual', 'override', 'final', 'explicit',
  128. 'inline', 'static', 'extern', 'const', 'volatile', 'mutable',
  129. 'register', 'auto', 'default', 'goto', 'asm', '__asm', '__asm__',
  130. }
  131. # 匹配函数调用
  132. pattern = r'\b([a-zA-Z_]\w*)\s*\('
  133. matches = re.findall(pattern, code)
  134. # 过滤关键字、自身和常见库函数
  135. callees = set()
  136. for name in matches:
  137. if name in keywords:
  138. continue
  139. if self_name is not None and name == self_name:
  140. continue
  141. if exclude_common_libs and name in COMMON_LIB_FUNCTIONS:
  142. continue
  143. callees.add(name)
  144. return callees
  145. def load_jsonl(file_path: str) -> List[Dict]:
  146. """
  147. 加载 JSONL 文件。
  148. """
  149. data = []
  150. with open(file_path, 'r', encoding='utf-8') as f:
  151. for line in f:
  152. line = line.strip()
  153. if line:
  154. data.append(json.loads(line))
  155. return data
  156. def build_call_graph(
  157. records: List[Dict],
  158. exclude_common_libs: bool = True
  159. ) -> Tuple[Dict[str, Set[str]], Dict[int, str], Dict[str, List[int]]]:
  160. """
  161. 构建函数调用图。
  162. Args:
  163. records: 数据记录列表
  164. exclude_common_libs: 是否排除常见库函数
  165. 返回:
  166. - call_graph: {函数名: {被调用的函数名集合}}
  167. - idx_to_func: {记录索引: 函数名}
  168. - func_to_idxs: {函数名: [记录索引列表]}(一个函数名可能对应多条记录)
  169. """
  170. call_graph = {}
  171. idx_to_func = {}
  172. func_to_idxs = defaultdict(list)
  173. for i, record in enumerate(records):
  174. func_code = record.get('func', '')
  175. func_name = extract_function_name(func_code)
  176. if func_name:
  177. callees = extract_function_calls(func_code, func_name, exclude_common_libs)
  178. call_graph[func_name] = callees
  179. idx_to_func[i] = func_name
  180. func_to_idxs[func_name].append(i)
  181. return call_graph, idx_to_func, func_to_idxs
  182. def find_high_frequency_functions(
  183. call_graph: Dict[str, Set[str]],
  184. all_funcs: Set[str],
  185. threshold_percentile: float = 99.0
  186. ) -> Set[str]:
  187. """
  188. 找出被高频调用的函数(可能是通用工具函数)。
  189. Args:
  190. call_graph: 函数调用图
  191. all_funcs: 数据集中的所有函数名
  192. threshold_percentile: 阈值百分位数(默认 99%)
  193. Returns:
  194. 高频被调用的函数集合
  195. """
  196. # 统计每个函数被调用的次数
  197. callee_count = defaultdict(int)
  198. for callees in call_graph.values():
  199. for callee in callees:
  200. if callee in all_funcs:
  201. callee_count[callee] += 1
  202. if not callee_count:
  203. return set()
  204. # 计算阈值
  205. counts = sorted(callee_count.values())
  206. threshold_idx = int(len(counts) * threshold_percentile / 100)
  207. threshold = counts[min(threshold_idx, len(counts) - 1)]
  208. # 只有当阈值大于某个最小值时才过滤(避免过滤掉正常的调用关系)
  209. if threshold < 10:
  210. return set()
  211. high_freq_funcs = {fn for fn, count in callee_count.items() if count >= threshold}
  212. return high_freq_funcs
  213. def find_related_groups(
  214. records: List[Dict],
  215. call_graph: Dict[str, Set[str]],
  216. func_to_idxs: Dict[str, List[int]],
  217. auto_filter_high_freq: bool = True,
  218. high_freq_threshold: float = 99.0
  219. ) -> List[List[Dict]]:
  220. """
  221. 找出有调用关系的函数组。
  222. 使用 Union-Find 算法将有调用关系的函数合并。
  223. Args:
  224. records: 数据记录列表
  225. call_graph: 函数调用图
  226. func_to_idxs: 函数名到记录索引的映射
  227. auto_filter_high_freq: 是否自动过滤高频调用的函数
  228. high_freq_threshold: 高频函数的阈值百分位数
  229. """
  230. # 获取所有函数名
  231. all_funcs = set(call_graph.keys())
  232. # 找出高频被调用的函数
  233. high_freq_funcs = set()
  234. if auto_filter_high_freq:
  235. high_freq_funcs = find_high_frequency_functions(
  236. call_graph, all_funcs, high_freq_threshold
  237. )
  238. if high_freq_funcs:
  239. print(f" 自动过滤 {len(high_freq_funcs)} 个高频被调用的函数")
  240. # 只保留在数据集中实际存在的调用关系
  241. # 构建双向关系图(caller -> callee, callee -> caller)
  242. related_graph = defaultdict(set)
  243. for caller, callees in call_graph.items():
  244. for callee in callees:
  245. # 只有当 callee 也在我们的数据集中时才建立关系
  246. # 排除高频被调用的函数
  247. if callee in all_funcs and callee not in high_freq_funcs:
  248. related_graph[caller].add(callee)
  249. related_graph[callee].add(caller)
  250. # 使用 BFS/DFS 找连通分量
  251. visited = set()
  252. groups = []
  253. for func_name in all_funcs:
  254. if func_name not in visited:
  255. # BFS 找到所有连通的函数
  256. group_funcs = set()
  257. queue = [func_name]
  258. while queue:
  259. current = queue.pop(0)
  260. if current in visited:
  261. continue
  262. visited.add(current)
  263. group_funcs.add(current)
  264. # 添加相关的函数
  265. for related in related_graph.get(current, []):
  266. if related not in visited:
  267. queue.append(related)
  268. # 将函数名转换为对应的记录
  269. group_records = []
  270. for fn in group_funcs:
  271. for idx in func_to_idxs.get(fn, []):
  272. group_records.append(records[idx])
  273. if group_records:
  274. groups.append(group_records)
  275. return groups
  276. def process_file(
  277. input_path: str,
  278. output_path: str,
  279. min_group_size: int = 1,
  280. max_group_size: int = 0,
  281. exclude_common_libs: bool = True
  282. ):
  283. """
  284. 处理单个 JSONL 文件。
  285. Args:
  286. input_path: 输入文件路径
  287. output_path: 输出文件路径
  288. min_group_size: 最小组大小(默认为1,可设置为2只保留有调用关系的组)
  289. max_group_size: 最大组大小(0表示不限制,超过此大小的组会被拆分为单独的记录)
  290. exclude_common_libs: 是否排除常见库函数
  291. """
  292. print(f"加载数据: {input_path}")
  293. records = load_jsonl(input_path)
  294. print(f"共加载 {len(records)} 条记录")
  295. print("构建函数调用图...")
  296. call_graph, idx_to_func, func_to_idxs = build_call_graph(records, exclude_common_libs)
  297. print(f"识别出 {len(call_graph)} 个函数")
  298. print("分析调用关系,合并相关函数...")
  299. groups = find_related_groups(
  300. records, call_graph, func_to_idxs,
  301. auto_filter_high_freq=True,
  302. high_freq_threshold=99.0
  303. )
  304. # 处理超大组:如果设置了 max_group_size,将超大组拆分为单独的记录
  305. if max_group_size > 0:
  306. new_groups = []
  307. oversized_count = 0
  308. for g in groups:
  309. if len(g) > max_group_size:
  310. oversized_count += 1
  311. # 将超大组中的每个记录拆分为单独的组
  312. for record in g:
  313. new_groups.append([record])
  314. else:
  315. new_groups.append(g)
  316. if oversized_count > 0:
  317. print(f" (已将 {oversized_count} 个超大组拆分为单独记录)")
  318. groups = new_groups
  319. # 按组大小过滤
  320. if min_group_size > 1:
  321. groups = [g for g in groups if len(g) >= min_group_size]
  322. # 统计信息
  323. total_funcs = sum(len(g) for g in groups)
  324. groups_with_relations = [g for g in groups if len(g) > 1]
  325. single_func_groups = len([g for g in groups if len(g) == 1])
  326. # 按组大小分布统计
  327. size_distribution = defaultdict(int)
  328. for g in groups:
  329. size = len(g)
  330. if size == 1:
  331. size_distribution["1 (单独函数)"] += 1
  332. elif size <= 5:
  333. size_distribution["2-5"] += 1
  334. elif size <= 10:
  335. size_distribution["6-10"] += 1
  336. elif size <= 50:
  337. size_distribution["11-50"] += 1
  338. elif size <= 100:
  339. size_distribution["51-100"] += 1
  340. elif size <= 500:
  341. size_distribution["101-500"] += 1
  342. elif size <= 1000:
  343. size_distribution["501-1000"] += 1
  344. else:
  345. size_distribution["1000+"] += 1
  346. print(f"\n==================== 统计信息 ====================")
  347. print(f" 总记录数(原始): {len(records)}")
  348. print(f" 总函数数(分组后): {total_funcs}")
  349. print(f" 总组数: {len(groups)}")
  350. print(f" - 单独函数组(无调用关系): {single_func_groups}")
  351. print(f" - 有调用关系的组(大小>1): {len(groups_with_relations)}")
  352. if groups_with_relations:
  353. actual_max_size = max(len(g) for g in groups_with_relations)
  354. avg_group_size = sum(len(g) for g in groups_with_relations) / len(groups_with_relations)
  355. print(f" 最大组大小: {actual_max_size}")
  356. print(f" 有关系组的平均大小: {avg_group_size:.2f}")
  357. print(f"\n 组大小分布:")
  358. # 按特定顺序输出
  359. order = ["1 (单独函数)", "2-5", "6-10", "11-50", "51-100", "101-500", "501-1000", "1000+"]
  360. for key in order:
  361. if key in size_distribution:
  362. count = size_distribution[key]
  363. percentage = count / len(groups) * 100
  364. print(f" - 大小 {key}: {count} 组 ({percentage:.1f}%)")
  365. print(f"====================================================")
  366. # 输出结果
  367. output_data = {
  368. "metadata": {
  369. "source_file": os.path.basename(input_path),
  370. "total_records": len(records),
  371. "total_functions_grouped": total_funcs,
  372. "total_groups": len(groups),
  373. "single_function_groups": single_func_groups,
  374. "groups_with_relations": len(groups_with_relations),
  375. "max_group_size": max(len(g) for g in groups) if groups else 0,
  376. "avg_related_group_size": round(sum(len(g) for g in groups_with_relations) / len(groups_with_relations), 2) if groups_with_relations else 0,
  377. "size_distribution": dict(size_distribution),
  378. },
  379. "groups": groups
  380. }
  381. os.makedirs(os.path.dirname(output_path), exist_ok=True)
  382. with open(output_path, 'w', encoding='utf-8') as f:
  383. json.dump(output_data, f, ensure_ascii=False, indent=2)
  384. print(f"\n结果已保存到: {output_path}")
  385. def main():
  386. parser = argparse.ArgumentParser(description='分析代码函数的调用关系')
  387. parser.add_argument(
  388. '--input', '-i',
  389. type=str,
  390. required=True,
  391. help='输入的 JSONL 文件路径'
  392. )
  393. parser.add_argument(
  394. '--output', '-o',
  395. type=str,
  396. default=None,
  397. help='输出的 JSON 文件路径(默认为 output/<输入文件名>_grouped.json)'
  398. )
  399. parser.add_argument(
  400. '--min-group-size', '-m',
  401. type=int,
  402. default=1,
  403. help='最小组大小,设为2可只保留有调用关系的组(默认为1)'
  404. )
  405. parser.add_argument(
  406. '--max-group-size', '-M',
  407. type=int,
  408. default=0,
  409. help='最大组大小,超过此大小的组会被拆分(0表示不限制,默认为0)'
  410. )
  411. parser.add_argument(
  412. '--include-common-libs',
  413. action='store_true',
  414. default=False,
  415. help='是否包含常见库函数作为调用关系(默认排除)'
  416. )
  417. args = parser.parse_args()
  418. # 设置默认输出路径
  419. if args.output is None:
  420. base_name = os.path.splitext(os.path.basename(args.input))[0]
  421. # 获取脚本所在目录的上两级(项目根目录)
  422. script_dir = os.path.dirname(os.path.abspath(__file__))
  423. project_root = os.path.dirname(os.path.dirname(script_dir))
  424. args.output = os.path.join(project_root, 'output', f'{base_name}_grouped.json')
  425. process_file(
  426. args.input,
  427. args.output,
  428. args.min_group_size,
  429. args.max_group_size,
  430. exclude_common_libs=not args.include_common_libs
  431. )
  432. if __name__ == '__main__':
  433. main()