#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""Profiling helpers for computing calls and federation get/remote operations.

Timings are accumulated in class-level registries while the job runs and are
summarised as tables and logged by ``profile_ends``.
"""
import hashlib
import inspect
import os
import time
import typing
from functools import wraps

import beautifultable

from fate_arch.common.log import getLogger
from fate_arch.abc import CTableABC

profile_logger = getLogger("PROFILING")

# Whether per-call debug logging is enabled; set by profile_start, cleared by profile_ends.
_PROFILE_LOG_ENABLED = False
# Wall-clock timestamps (time.time()) recorded by profile_start / profile_ends.
_START_TIME = None
_END_TIME = None


def _strip_borders(table):
    """Remove the outer border of ``table`` so it nests cleanly inside another table; return it."""
    table.border.left = ''
    table.border.right = ''
    table.border.bottom = ''
    table.border.top = ''
    return table


class _TimerItem(object):
    """Accumulated statistics (count, total, max) of elapsed times in seconds."""

    def __init__(self):
        self.count = 0
        self.total_time = 0.0
        self.max_time = 0.0

    def union(self, other: '_TimerItem'):
        """Merge the statistics of ``other`` into this item."""
        self.count += other.count
        self.total_time += other.total_time
        self.max_time = max(self.max_time, other.max_time)

    def add(self, elapse_time):
        """Record a single measurement of ``elapse_time`` seconds."""
        self.count += 1
        self.total_time += elapse_time
        self.max_time = max(self.max_time, elapse_time)

    @property
    def mean(self):
        """Mean elapsed time, or 0.0 when nothing has been recorded."""
        if self.count == 0:
            return 0.0
        return self.total_time / self.count

    def as_list(self):
        """Return ``[count, total, mean, max]``, matching the n/sum/mean/max table columns."""
        return [self.count, self.total_time, self.mean, self.max_time]

    def __str__(self):
        return f"n={self.count}, sum={self.total_time:.4f}, mean={self.mean:.4f}, max={self.max_time:.4f}"

    def __repr__(self):
        return self.__str__()


class _ComputingTimerItem(object):
    """Statistics for one computing function reached through one distinct call stack."""

    def __init__(self, function_name: str, function_stack):
        self.function_name = function_name
        self.function_stack = function_stack
        self.item = _TimerItem()


class _ComputingTimer(object):
    """Times a single computing call; statistics are keyed by a hash of the caller's stack."""

    # Process-wide registry: stack hash -> accumulated statistics.
    _STATS: typing.MutableMapping[str, _ComputingTimerItem] = {}

    def __init__(self, function_name: str, function_stack_list):
        self._start = time.time()

        function_stack = "\n".join(function_stack_list)
        # A short digest of the stack identifies it compactly in logs and tables.
        self._hash = hashlib.blake2b(function_stack.encode('utf-8'), digest_size=5).hexdigest()

        if self._hash not in self._STATS:
            self._STATS[self._hash] = _ComputingTimerItem(function_name, function_stack)
            # The full stack is logged only the first time it is seen; later logs use the hash.
            if _PROFILE_LOG_ENABLED:
                profile_logger.debug(f"[computing#{self._hash}]function_stack: {' <-'.join(function_stack_list)}")

        if _PROFILE_LOG_ENABLED:
            profile_logger.debug(f"[computing#{self._hash}]start")

    def done(self, function_string):
        """Record the time elapsed since construction.

        ``function_string`` only appears in the debug log when profiling logging is enabled.
        """
        elapse = time.time() - self._start
        self._STATS[self._hash].item.add(elapse)
        if _PROFILE_LOG_ENABLED:
            profile_logger.debug(f"[computing#{self._hash}]done, elapse: {elapse}, function: {function_string}")

    @classmethod
    def computing_statistics_table(cls, timer_aggregator: _TimerItem = None):
        """Render the computing statistics.

        :param timer_aggregator: if given, the grand total of all computing calls is merged into it.
        :return: ``(per-function summary table, per-stack detailed table)`` as strings,
            each sorted by total time descending.
        """
        stack_table = beautifultable.BeautifulTable(110, precision=4, detect_numerics=False)
        stack_table.columns.header = ["function", "n", "sum(s)", "mean(s)", "max(s)", "stack_hash", "stack"]
        stack_table.columns.alignment["stack"] = beautifultable.ALIGN_LEFT
        stack_table.columns.header.alignment = beautifultable.ALIGN_CENTER
        _strip_borders(stack_table)

        function_table = beautifultable.BeautifulTable(110)
        function_table.set_style(beautifultable.STYLE_COMPACT)
        function_table.columns.header = ["function", "n", "sum(s)", "mean(s)", "max(s)"]

        # Per-function totals merge every stack that reached the same function.
        aggregate = {}
        total = _TimerItem()
        for hash_id, timer in cls._STATS.items():
            stack_table.rows.append([timer.function_name, *timer.item.as_list(), hash_id, timer.function_stack])
            aggregate.setdefault(timer.function_name, _TimerItem()).union(timer.item)
            total.union(timer.item)

        for function_name, item in aggregate.items():
            function_table.rows.append([function_name, *item.as_list()])

        stack_table.rows.sort("sum(s)", reverse=True)
        detailed_base_table = beautifultable.BeautifulTable(120)
        detailed_base_table.rows.append(["stack", stack_table])
        detailed_base_table.rows.append(["total", total])

        function_table.rows.sort("sum(s)", reverse=True)
        base_table = beautifultable.BeautifulTable(120)
        base_table.rows.append(["function", function_table])
        base_table.rows.append(["total", total])

        if timer_aggregator is not None:
            timer_aggregator.union(total)

        return base_table.get_string(), detailed_base_table.get_string()


class _FederationTimer(object):
    """Registries of federation get/remote statistics and their summary table."""

    # full_name -> accumulated statistics of federation ``get`` / ``remote`` calls.
    _GET_STATS: typing.MutableMapping[str, _TimerItem] = {}
    _REMOTE_STATS: typing.MutableMapping[str, _TimerItem] = {}

    @staticmethod
    def _stats_table(stats: typing.Mapping[str, _TimerItem], total: _TimerItem):
        """Build a borderless table of ``stats`` sorted by total time, merging each item into ``total``."""
        table = beautifultable.BeautifulTable(110)
        table.columns.header = ["name", "n", "sum(s)", "mean(s)", "max(s)"]
        for name, item in stats.items():
            table.rows.append([name, *item.as_list()])
            total.union(item)
        table.rows.sort("sum(s)", reverse=True)
        return _strip_borders(table)

    @classmethod
    def federation_statistics_table(cls, timer_aggregator: _TimerItem = None):
        """Render the federation get/remote statistics as a string.

        :param timer_aggregator: if given, the grand total of all get and remote calls is merged into it.
        """
        total = _TimerItem()
        # get first, then remote: the order in which items are summed into ``total``.
        get_table = cls._stats_table(cls._GET_STATS, total)
        remote_table = cls._stats_table(cls._REMOTE_STATS, total)

        base_table = beautifultable.BeautifulTable(120)
        base_table.rows.append(["get", get_table])
        base_table.rows.append(["remote", remote_table])
        base_table.rows.append(["total", total])

        if timer_aggregator is not None:
            timer_aggregator.union(total)

        return base_table.get_string()


class _FederationCallTimer(_FederationTimer):
    """State shared by the timers of a single federation get/remote call."""

    def __init__(self, name, full_name, tag, local, parties):
        self._name = name
        self._full_name = full_name
        self._tag = tag
        self._local_party = local
        self._parties = parties
        self._start_time = time.time()
        self._end_time = None
        self._stats().setdefault(self._full_name, _TimerItem())

    def _stats(self) -> typing.MutableMapping[str, _TimerItem]:
        """Registry this timer records into; provided by subclasses."""
        raise NotImplementedError

    def _stop(self):
        """Stamp the end time and record the elapsed time under ``full_name``."""
        self._end_time = time.time()
        self._stats()[self._full_name].add(self.elapse)

    @property
    def elapse(self):
        """Seconds between construction and ``done``; only valid after ``done`` was called."""
        return self._end_time - self._start_time


class _FederationRemoteTimer(_FederationCallTimer):
    """Times one federation ``remote`` call."""

    def _stats(self):
        return self._REMOTE_STATS

    def done(self, federation):
        """Record the call; when remote profiling is enabled, also send this call's timestamps to ``parties``."""
        self._stop()
        profile_logger.debug(f"[federation.remote.{self._full_name}.{self._tag}]"
                             f"{self._local_party}->{self._parties} done")

        if is_profile_remote_enable():
            federation.remote(v={"start_time": self._start_time, "end_time": self._end_time},
                              name=self._name,
                              tag=profile_remote_tag(self._tag),
                              parties=self._parties,
                              gc=None)


class _FederationGetTimer(_FederationCallTimer):
    """Times one federation ``get`` call."""

    def _stats(self):
        return self._GET_STATS

    def done(self, federation):
        """Record the call; when remote profiling is enabled, also fetch and log each party's timing meta."""
        self._stop()
        profile_logger.debug(f"[federation.get.{self._full_name}.{self._tag}]"
                             f"{self._local_party}<-{self._parties} done")

        if is_profile_remote_enable():
            # Presumably the {"start_time", "end_time"} dicts sent by the peers'
            # _FederationRemoteTimer.done under the same profiling tag — confirm.
            remote_meta = federation.get(name=self._name,
                                         tag=profile_remote_tag(self._tag),
                                         parties=self._parties,
                                         gc=None)
            for party, meta in zip(self._parties, remote_meta):
                profile_logger.debug(f"[federation.meta.{self._full_name}.{self._tag}]{self._local_party}<-{party} "
                                     f"meta={meta}")


def federation_remote_timer(name, full_name, tag, local, parties):
    """Log the start of a federation ``remote`` and return its timer; call ``done(federation)`` afterwards."""
    profile_logger.debug(f"[federation.remote.{full_name}.{tag}]{local}->{parties} start")
    return _FederationRemoteTimer(name, full_name, tag, local, parties)


def federation_get_timer(name, full_name, tag, local, parties):
    """Log the start of a federation ``get`` and return its timer; call ``done(federation)`` afterwards."""
    profile_logger.debug(f"[federation.get.{full_name}.{tag}]{local}<-{parties} start")
    return _FederationGetTimer(name, full_name, tag, local, parties)


def profile_start():
    """Enable profiling debug logging and record the profiling start time."""
    global _PROFILE_LOG_ENABLED, _START_TIME
    _PROFILE_LOG_ENABLED = True
    _START_TIME = time.time()


def profile_ends():
    """Log the profiling summary (driver / federation / computing split and tables) and disable logging.

    If ``profile_start`` was never called, a warning is logged instead of failing.
    """
    global _END_TIME, _PROFILE_LOG_ENABLED
    if _START_TIME is None:
        profile_logger.warning("profile_ends called without profile_start, skip profile summary")
        _PROFILE_LOG_ENABLED = False
        return

    _END_TIME = time.time()
    profile_total_time = _END_TIME - _START_TIME

    # gather computing and federation profile statistics
    timer_aggregator = _TimerItem()
    computing_timer_aggregator = _TimerItem()
    federation_timer_aggregator = _TimerItem()
    computing_base_table, computing_detailed_table = _ComputingTimer.computing_statistics_table(
        timer_aggregator=computing_timer_aggregator)
    federation_base_table = _FederationTimer.federation_statistics_table(timer_aggregator=federation_timer_aggregator)
    timer_aggregator.union(computing_timer_aggregator)
    timer_aggregator.union(federation_timer_aggregator)

    def _ratio(part):
        # time.time() resolution can make the total 0.0 for very short runs; avoid dividing by zero.
        return part / profile_total_time if profile_total_time > 0 else 0.0

    # Driver time: whatever was not spent inside timed computing or federation calls.
    profile_driver_time = profile_total_time - timer_aggregator.total_time
    profile_logger.info(
        "Total: {:.4f}s, Driver: {:.4f}s({:.2%}), Federation: {:.4f}s({:.2%}), Computing: {:.4f}s({:.2%})".format(
            profile_total_time,
            profile_driver_time,
            _ratio(profile_driver_time),
            federation_timer_aggregator.total_time,
            _ratio(federation_timer_aggregator.total_time),
            computing_timer_aggregator.total_time,
            _ratio(computing_timer_aggregator.total_time)
        )
    )
    profile_logger.info(f"\nComputing:\n{computing_base_table}\n\nFederation:\n{federation_base_table}\n")
    profile_logger.debug(f"\nDetailed Computing:\n{computing_detailed_table}\n")

    _PROFILE_LOG_ENABLED = False


def _pretty_table_str(v):
    """Short description of ``v``: partition count for computing tables, otherwise the type name."""
    if isinstance(v, CTableABC):
        return f"Table(partition={v.partitions})"
    return f"{type(v).__name__}"


def _func_annotated_string(func, *args, **kwargs):
    """Render ``func(arg: <description>, ...)`` using the names the arguments bind to."""
    bound_arguments = inspect.signature(func).bind(*args, **kwargs).arguments
    pretty_args = ", ".join(f"{k}: {_pretty_table_str(v)}" for k, v in bound_arguments.items())
    return f"{func.__name__}({pretty_args})"


def _call_stack_strings():
    """Return ``[file:line]function`` strings for the current call stack, innermost first.

    Skips this helper and its direct caller (the ``computing_profile`` wrapper) and the two
    outermost frames.
    """
    # context=0: only filename/lineno/function are used, so don't read source lines.
    frames = inspect.getouterframes(inspect.currentframe(), context=0)[2:-2]
    return [f"[{os.path.basename(frame.filename)}:{frame.lineno}]{frame.function}" for frame in frames]


def computing_profile(func):
    """Decorator timing every call of ``func``, keyed by the caller's stack."""

    @wraps(func)
    def _fn(*args, **kwargs):
        # Must call _call_stack_strings directly from here: it skips exactly two inner frames.
        function_call_stack = _call_stack_strings()
        timer = _ComputingTimer(func.__name__, function_call_stack)
        rtn = func(*args, **kwargs)
        # The annotated call string is only logged when profiling logging is on, so skip the
        # signature binding otherwise.
        if _PROFILE_LOG_ENABLED:
            function_string = f"{_func_annotated_string(func, *args, **kwargs)} -> {_pretty_table_str(rtn)}"
        else:
            function_string = ""
        timer.done(function_string)
        return rtn

    return _fn


# Whether federation timers also exchange their timestamps with the peer parties.
__META_REMOTE_ENABLE = False


def enable_profile_remote():
    """Turn on exchange of get/remote timing meta between parties."""
    global __META_REMOTE_ENABLE
    __META_REMOTE_ENABLE = True


def is_profile_remote_enable():
    """Whether ``enable_profile_remote`` has been called."""
    return __META_REMOTE_ENABLE


def profile_remote_tag(tag):
    """Tag used for the timing-meta exchange: the data tag prefixed with ``_``."""
    return f"_{tag}"