profile.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import hashlib
  17. import time
  18. import typing
  19. import beautifultable
  20. from fate_arch.common.log import getLogger
  21. import inspect
  22. from functools import wraps
  23. from fate_arch.abc import CTableABC
  # Logger that receives every profiling record emitted by this module.
  profile_logger = getLogger("PROFILING")

  # Gates the per-call computing debug logs; switched by profile_start() / profile_ends().
  _PROFILE_LOG_ENABLED = False

  # Timestamps written by profile_start() / profile_ends(); unset until then.
  _START_TIME = _END_TIME = None
  class _TimerItem(object):
      """Running aggregate of elapsed-time samples: count, sum and maximum."""

      def __init__(self):
          self.count = 0
          self.total_time = 0.0
          self.max_time = 0.0

      def union(self, other: '_TimerItem'):
          """Fold the samples accumulated by ``other`` into this item, in place."""
          self.count = self.count + other.count
          self.total_time = self.total_time + other.total_time
          self.max_time = max(self.max_time, other.max_time)

      def add(self, elapse_time):
          """Record one more sample lasting ``elapse_time``."""
          self.count = self.count + 1
          self.total_time = self.total_time + elapse_time
          self.max_time = max(self.max_time, elapse_time)

      @property
      def mean(self):
          """Average sample duration; ``0.0`` while no sample has been recorded."""
          if self.count != 0:
              return self.total_time / self.count
          return 0.0

      def as_list(self):
          """Return ``[count, total, mean, max]``, the column order used by the tables."""
          return [self.count, self.total_time, self.mean, self.max_time]

      def __str__(self):
          return f"n={self.count}, sum={self.total_time:.4f}, mean={self.mean:.4f}, max={self.max_time:.4f}"

      def __repr__(self):
          return str(self)
  class _ComputingTimerItem(object):
      """Timing record for one distinct call stack of a profiled function.

      Holds the function name, the joined call-stack text and the ``_TimerItem``
      that accumulates the samples.
      """

      def __init__(self, function_name: str, function_stack):
          self.function_name, self.function_stack = function_name, function_stack
          self.item = _TimerItem()
  class _ComputingTimer(object):
      """Times one profiled computing call and keeps statistics per call stack.

      ``_STATS`` is shared by every instance. It maps a short hash of the call
      stack to the ``_ComputingTimerItem`` accumulating that stack's samples.
      """

      _STATS: typing.MutableMapping[str, _ComputingTimerItem] = {}

      def __init__(self, function_name: str, function_stack_list):
          """Start timing and register the call stack if it has not been seen yet.

          Args:
              function_name: name of the profiled function.
              function_stack_list: call-stack description, one string per frame.
          """
          self._start = time.time()
          function_stack = "\n".join(function_stack_list)
          digest = hashlib.blake2b(function_stack.encode('utf-8'), digest_size=5)
          self._hash = digest.hexdigest()

          first_seen = self._hash not in self._STATS
          if first_seen:
              self._STATS[self._hash] = _ComputingTimerItem(function_name, function_stack)

          if not _PROFILE_LOG_ENABLED:
              return
          if first_seen:
              # The full stack is logged only once per hash; later records refer to the hash.
              profile_logger.debug(f"[computing#{self._hash}]function_stack: {' <-'.join(function_stack_list)}")
          profile_logger.debug(f"[computing#{self._hash}]start")

      def done(self, function_string):
          """Stop timing, add the sample to the statistics and log the call.

          Args:
              function_string: human-readable description of the finished call,
                  used only in the debug log record.
          """
          elapse = time.time() - self._start
          stats_entry = self._STATS[self._hash]
          stats_entry.item.add(elapse)
          if _PROFILE_LOG_ENABLED:
              profile_logger.debug(f"[computing#{self._hash}]done, elapse: {elapse}, function: {function_string}")

      @classmethod
      def computing_statistics_table(cls, timer_aggregator: _TimerItem = None):
          """Render the collected computing statistics as text tables.

          Args:
              timer_aggregator: optional item that receives the grand total via
                  ``union``.

          Returns:
              A ``(per_function, per_stack)`` pair of table strings.
          """
          metric_headers = ["n", "sum(s)", "mean(s)", "max(s)"]

          # Detailed table: one row per distinct call stack.
          stack_table = beautifultable.BeautifulTable(110, precision=4, detect_numerics=False)
          stack_table.columns.header = ["function", *metric_headers, "stack_hash", "stack"]
          stack_table.columns.alignment["stack"] = beautifultable.ALIGN_LEFT
          stack_table.columns.header.alignment = beautifultable.ALIGN_CENTER
          for side in ("left", "right", "bottom", "top"):
              setattr(stack_table.border, side, '')

          # Summary table: one row per function name.
          function_table = beautifultable.BeautifulTable(110)
          function_table.set_style(beautifultable.STYLE_COMPACT)
          function_table.columns.header = ["function", *metric_headers]

          per_function = {}
          total = _TimerItem()
          for stack_hash, record in cls._STATS.items():
              stack_table.rows.append(
                  [record.function_name, *record.item.as_list(), stack_hash, record.function_stack])
              if record.function_name not in per_function:
                  per_function[record.function_name] = _TimerItem()
              per_function[record.function_name].union(record.item)
              total.union(record.item)
          stack_table.rows.sort("sum(s)", reverse=True)

          for function_name, item in per_function.items():
              function_table.rows.append([function_name, *item.as_list()])
          function_table.rows.sort("sum(s)", reverse=True)

          detailed_base_table = beautifultable.BeautifulTable(120)
          detailed_base_table.rows.append(["stack", stack_table])
          detailed_base_table.rows.append(["total", total])

          base_table = beautifultable.BeautifulTable(120)
          base_table.rows.append(["function", function_table])
          base_table.rows.append(["total", total])

          if timer_aggregator:
              timer_aggregator.union(total)
          return base_table.get_string(), detailed_base_table.get_string()
  class _FederationTimer(object):
      """Base of the federation timers; owns the shared get/remote statistics.

      ``_GET_STATS`` and ``_REMOTE_STATS`` each map a name to the ``_TimerItem``
      accumulating the timings recorded under that name.
      """

      _GET_STATS: typing.MutableMapping[str, _TimerItem] = {}
      _REMOTE_STATS: typing.MutableMapping[str, _TimerItem] = {}

      @classmethod
      def federation_statistics_table(cls, timer_aggregator: _TimerItem = None):
          """Render the collected federation statistics as one text table.

          Args:
              timer_aggregator: optional item that receives the grand total via
                  ``union``.

          Returns:
              The table string with a ``get``, a ``remote`` and a ``total`` row.
          """
          total = _TimerItem()

          def _render(stats):
              # One borderless table per direction; every item also feeds ``total``.
              table = beautifultable.BeautifulTable(110)
              table.columns.header = ["name", "n", "sum(s)", "mean(s)", "max(s)"]
              for name, item in stats.items():
                  table.rows.append([name, *item.as_list()])
                  total.union(item)
              table.rows.sort("sum(s)", reverse=True)
              for side in ("left", "right", "bottom", "top"):
                  setattr(table.border, side, '')
              return table

          get_table = _render(cls._GET_STATS)
          remote_table = _render(cls._REMOTE_STATS)

          base_table = beautifultable.BeautifulTable(120)
          for label, content in (("get", get_table), ("remote", remote_table), ("total", total)):
              base_table.rows.append([label, content])

          if timer_aggregator:
              timer_aggregator.union(total)
          return base_table.get_string()
  class _FederationRemoteTimer(_FederationTimer):
      """Times one federation ``remote`` call, recorded under ``full_name``."""

      def __init__(self, name, full_name, tag, local, parties):
          """Start timing and make sure a statistics slot exists for ``full_name``."""
          self._name, self._full_name, self._tag = name, full_name, tag
          self._local_party, self._parties = local, parties
          self._start_time, self._end_time = time.time(), None

          remote_stats = self._REMOTE_STATS
          if self._full_name not in remote_stats:
              remote_stats[self._full_name] = _TimerItem()

      def done(self, federation):
          """Stop timing, record the sample and optionally send the timing metadata.

          Args:
              federation: object whose ``remote`` method is used to send the
                  start/end timestamps when remote profiling is enabled.
          """
          self._end_time = time.time()
          self._REMOTE_STATS[self._full_name].add(self.elapse)
          profile_logger.debug(f"[federation.remote.{self._full_name}.{self._tag}]"
                               f"{self._local_party}->{self._parties} done")

          if not is_profile_remote_enable():
              return
          timing_meta = {"start_time": self._start_time, "end_time": self._end_time}
          federation.remote(v=timing_meta,
                            name=self._name,
                            tag=profile_remote_tag(self._tag),
                            parties=self._parties,
                            gc=None)

      @property
      def elapse(self):
          """Seconds between construction and ``done``; only valid after ``done``."""
          return self._end_time - self._start_time
  class _FederationGetTimer(_FederationTimer):
      """Times one federation ``get`` call, recorded under ``full_name``."""

      def __init__(self, name, full_name, tag, local, parties):
          """Start timing and make sure a statistics slot exists for ``full_name``."""
          self._name, self._full_name, self._tag = name, full_name, tag
          self._local_party, self._parties = local, parties
          self._start_time, self._end_time = time.time(), None

          get_stats = self._GET_STATS
          if self._full_name not in get_stats:
              get_stats[self._full_name] = _TimerItem()

      def done(self, federation):
          """Stop timing, record the sample and optionally fetch the timing metadata.

          Args:
              federation: object whose ``get`` method is used to fetch the
                  metadata when remote profiling is enabled.
          """
          self._end_time = time.time()
          self._GET_STATS[self._full_name].add(self.elapse)
          profile_logger.debug(f"[federation.get.{self._full_name}.{self._tag}]"
                               f"{self._local_party}<-{self._parties} done")

          if not is_profile_remote_enable():
              return
          remote_meta = federation.get(name=self._name,
                                       tag=profile_remote_tag(self._tag),
                                       parties=self._parties,
                                       gc=None)
          # One metadata entry is logged per party, paired by position.
          for party, meta in zip(self._parties, remote_meta):
              profile_logger.debug(f"[federation.meta.{self._full_name}.{self._tag}]{self._local_party}<-{party}]"
                                   f"meta={meta}")

      @property
      def elapse(self):
          """Seconds between construction and ``done``; only valid after ``done``."""
          return self._end_time - self._start_time
  def federation_remote_timer(name, full_name, tag, local, parties):
      """Log the start of a federation ``remote`` call and return its running timer."""
      start_record = f"[federation.remote.{full_name}.{tag}]{local}->{parties} start"
      profile_logger.debug(start_record)
      timer = _FederationRemoteTimer(name, full_name, tag, local, parties)
      return timer
  def federation_get_timer(name, full_name, tag, local, parties):
      """Log the start of a federation ``get`` call and return its running timer."""
      start_record = f"[federation.get.{full_name}.{tag}]{local}<-{parties} start"
      profile_logger.debug(start_record)
      timer = _FederationGetTimer(name, full_name, tag, local, parties)
      return timer
  def profile_start():
      """Open a profiling session: enable profiling debug logs and stamp the start time."""
      global _PROFILE_LOG_ENABLED, _START_TIME
      _PROFILE_LOG_ENABLED = True
      _START_TIME = time.time()
  def profile_ends():
      """Close the profiling session and log the collected statistics.

      Stamps the end time, gathers the computing and federation statistics and
      logs, at info level, a one-line summary (total / driver / federation /
      computing) followed by the statistics tables. The detailed per-stack
      computing table is logged at debug level. Time that no computing or
      federation timer accounts for is reported as driver time. Profiling debug
      logs are switched off at the end.

      If ``profile_start()`` was never called, the total is reported as 0
      instead of raising ``TypeError``. A zero total yields 0% ratios instead of
      raising ``ZeroDivisionError``.
      """
      global _END_TIME, _PROFILE_LOG_ENABLED
      _END_TIME = time.time()

      if _START_TIME is None:
          # No session was opened; report an empty one rather than fail on ``float - None``.
          profile_logger.warning("profile_ends called without profile_start, total time is reported as 0")
          profile_total_time = 0.0
      else:
          profile_total_time = _END_TIME - _START_TIME

      def _share(part):
          # Guard the ratio: a zero-length session must not abort the report.
          return part / profile_total_time if profile_total_time else 0.0

      # gather computing and federation profile statistics
      timer_aggregator = _TimerItem()
      computing_timer_aggregator = _TimerItem()
      federation_timer_aggregator = _TimerItem()
      computing_base_table, computing_detailed_table = _ComputingTimer.computing_statistics_table(
          timer_aggregator=computing_timer_aggregator)
      federation_base_table = _FederationTimer.federation_statistics_table(
          timer_aggregator=federation_timer_aggregator)
      timer_aggregator.union(computing_timer_aggregator)
      timer_aggregator.union(federation_timer_aggregator)

      # logging
      profile_driver_time = profile_total_time - timer_aggregator.total_time
      profile_logger.info(
          "Total: {:.4f}s, Driver: {:.4f}s({:.2%}), Federation: {:.4f}s({:.2%}), Computing: {:.4f}s({:.2%})".format(
              profile_total_time,
              profile_driver_time,
              _share(profile_driver_time),
              federation_timer_aggregator.total_time,
              _share(federation_timer_aggregator.total_time),
              computing_timer_aggregator.total_time,
              _share(computing_timer_aggregator.total_time)
          )
      )
      profile_logger.info(f"\nComputing:\n{computing_base_table}\n\nFederation:\n{federation_base_table}\n")
      profile_logger.debug(f"\nDetailed Computing:\n{computing_detailed_table}\n")

      _PROFILE_LOG_ENABLED = False
  def _pretty_table_str(v):
      """Describe ``v`` briefly for log output.

      Tables are shown with their partition count; anything else is shown as
      its type name only, so values themselves never reach the log.
      """
      if not isinstance(v, CTableABC):
          return f"{type(v).__name__}"
      return f"Table(partition={v.partitions})"
  def _func_annotated_string(func, *args, **kwargs):
      """Format a call as ``name(param: description, ...)`` for log output.

      The arguments are bound to ``func``'s signature so each one is labelled
      with its parameter name.
      """
      bound = inspect.signature(func).bind(*args, **kwargs)
      described = [f"{k}: {_pretty_table_str(v)}" for k, v in bound.arguments.items()]
      return f"{func.__name__}({', '.join(described)})"
  def _call_stack_strings():
      """Describe the caller's call stack, one ``[file:line]function`` string per frame.

      The two innermost frames (this function and its direct caller) and the two
      outermost frames are dropped. Only the last ``/``-separated component of
      each file name is kept.

      Returns:
          A list of strings ordered from the innermost kept frame outwards.
      """
      # context=0: only filename, lineno and function are used, so skip loading
      # source lines for every frame. This runs on every profiled call.
      frames = inspect.getouterframes(inspect.currentframe(), 0)[2:-2]
      return [
          f"[{frame.filename.split('/')[-1]}:{frame.lineno}]{frame.function}"
          for frame in frames
      ]
  def computing_profile(func):
      """Decorator that times every call of ``func`` and records it per call stack.

      The wrapper returns whatever ``func`` returns. A call that raises is not
      recorded, because the timer is only completed after ``func`` returns.
      """

      @wraps(func)
      def _fn(*args, **kwargs):
          # Must be called directly from the wrapper: _call_stack_strings drops
          # exactly its own frame and this one.
          timer = _ComputingTimer(func.__name__, _call_stack_strings())
          rtn = func(*args, **kwargs)
          call_description = _func_annotated_string(func, *args, **kwargs)
          timer.done(f"{call_description} -> {_pretty_table_str(rtn)}")
          return rtn

      return _fn
  # Off by default; enable_profile_remote() turns it on. Nothing here turns it off again.
  __META_REMOTE_ENABLE = False


  def enable_profile_remote():
      """Turn on remote profiling, as reported by ``is_profile_remote_enable``."""
      global __META_REMOTE_ENABLE
      __META_REMOTE_ENABLE = True


  def is_profile_remote_enable():
      """Return whether ``enable_profile_remote`` has been called."""
      return __META_REMOTE_ENABLE
  def profile_remote_tag(tag):
      """Derive the tag used for the profiling metadata that accompanies ``tag``."""
      marker = "<remote_profile>"
      return f"{marker}_{tag}"