123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242 |
- from easyfl.tracking.metric import TaskMetric, RoundMetric, ClientMetric
- from easyfl.tracking.storage import get_store
class TrackingClient(object):
    """Client for tracking task metrics, round metrics, and client metrics.

    Task Tracking:
        client.create_task(task_id, conf)

    Round Tracking:
        client.track_round(name, value)
        client.save_round()  # auto increment to next round
        client.track_round(name, value)

    Client Tracking:
        client.set_client_context(task_id, round_id, client_id)
        client.track_client(name, value)
    """

    def __init__(self, db_path=None, db_address=None, init_store=True):
        """Create a tracking client with an empty task/round/client context.

        If storage is not initialized, the tracking client can only collect metrics but
        not save them.

        :param db_path: database path, forwarded to ``get_store``.
        :param db_address: remote storage address, forwarded to ``get_store``.
        :param init_store: whether to initialize storage.
        """
        self._task_id = None
        self._round_id = None
        self._client_id = None
        self._current_task = None
        self._current_round = None
        self._current_client = None
        self._cached_task_metrics = {}
        self._cached_round_metrics = {}
        # Always define ``_storage`` so a storage-less client still has the attribute;
        # it is ``None`` when no store was requested.
        self._storage = get_store(db_path, db_address) if init_store else None

    def get_task_metric(self, task_id):
        """Get a task metric from storage.

        :param task_id: task id.
        :return: ``TaskMetric`` built from the stored record, or ``None`` if not found.
        """
        task_metric = self._storage.get_task_metric(task_id)
        if task_metric is None:
            return None
        return TaskMetric.from_sql(task_metric)

    def get_round_metric(self, round_id, task_id):
        """Get a round metric.

        Returns the in-memory current ``RoundMetric`` when ``(task_id, round_id)`` matches
        the current context; otherwise returns the result of
        ``storage.get_round_metrics(task_id, [round_id])``.
        NOTE(review): the two branches presumably return different shapes (single metric
        vs. storage result for a list of ids) — confirm against the storage API.

        :param round_id: round id.
        :param task_id: task id.
        """
        if task_id == self._task_id and round_id == self._round_id:
            return self._current_round
        return self._storage.get_round_metrics(task_id, [round_id])

    def get_client_metric(self, client_id=None, round_id=None, task_id=None):
        """Get a client metric.

        Returns the in-memory current client metric when all ids match the current
        context, or when no ids are given at all; otherwise queries storage for
        ``[client_id]``.

        :param client_id: client id.
        :param round_id: round id.
        :param task_id: task id.
        """
        if (task_id == self._task_id and round_id == self._round_id and client_id == self._client_id) or \
                (client_id is None and round_id is None and task_id is None):
            return self._current_client
        return self._storage.get_client_metrics(task_id, round_id, [client_id])

    def get_client_metrics(self, client_ids, round_id, task_id):
        """Get list of client metrics from storage.

        :param client_ids: list of client ids.
        :param round_id: round id.
        :param task_id: task id.
        """
        return self._storage.get_client_metrics(task_id, round_id, client_ids)

    def create_task(self, task_id, conf=None, save=True):
        """Create a new task and make it the current task.

        :param task_id: task id; must not be None.
        :param conf: task configuration stored in the ``TaskMetric``.
        :param save: whether to persist the new task metric to storage immediately.
        :raises ValueError: if ``task_id`` is None.
        """
        if task_id is None:
            raise ValueError("task_id cannot be None to create task")
        self._task_id = task_id
        self._current_task = TaskMetric(task_id, conf)
        if save:
            self._storage.store_task_metric(self._current_task)

    def create_round(self, round_id, task_id=None):
        """Create a new round metric and make it the current round.

        A new ``RoundMetric`` is only created when ``round_id`` differs from the current
        round id; otherwise the current round is kept unchanged.

        :param round_id: round id; must not be None.
        :param task_id: task id for the round; defaults to the current task id.
        :raises ValueError: if ``round_id`` is None.
        """
        if task_id is None:
            task_id = self._task_id
        if round_id is None:
            raise ValueError("round_id cannot be None to create round")
        if round_id != self._round_id:
            self._round_id = round_id
            self._current_round = RoundMetric(task_id, self._round_id)

    def create_client(self, client_id, reset=True):
        """Create client under current round of task.

        Current implementation requires round and task exist to create client.

        :param client_id: client id; must not be None.
        :param reset: when True, always start a fresh ``ClientMetric``. When False, and the
            current client metric exists for the same ``client_id``, it is reused and only
            its task/round ids are refreshed.
        :raises LookupError: if the task or round context is not set.
        :raises ValueError: if ``client_id`` is None.
        """
        self._check_context()
        if client_id is None:
            raise ValueError("client_id cannot be None to create client.")
        if reset or not self._current_client or client_id != self._client_id:
            self._current_client = ClientMetric(self._task_id, self._round_id, client_id)
            self._client_id = client_id
            return
        self._current_client.task_id = self._task_id
        self._current_client.round_id = self._round_id

    def track_task(self, name, value, task_id=None):
        """Add a task-level metric and persist the task metric.

        When ``task_id`` differs from the current task, the current task metric (if any)
        is cached and a new ``TaskMetric`` becomes current.

        :param name: metric name.
        :param value: metric value.
        :param task_id: task id; defaults to the current task.
        """
        if self._diff_task_id(task_id):
            # Only cache a real task; avoid storing ``None`` under a ``None`` key.
            if self._current_task is not None:
                self._cached_task_metrics[self._task_id] = self._current_task
            self._task_id = task_id
            self._current_task = TaskMetric(task_id)
        self._current_task.add(name, value)
        self._storage.store_task_metric(self._current_task)

    def track_round(self, name, value, round_id=None, task_id=None):
        """Add a round-level metric to the current round (kept in memory until saved).

        Switches to ``task_id`` / ``round_id`` first when they differ from the current
        context. Round 0 is created if no round exists yet.

        :param name: metric name.
        :param value: metric value.
        :param round_id: round id; defaults to the current round.
        :param task_id: task id; defaults to the current task.
        """
        if self._diff_task_id(task_id):
            # Must be ``self.create_task``: a bare ``create_task`` resolves to the
            # module-level function, which acts on the global client, not on ``self``.
            self.create_task(task_id)
        if self._diff_round_id(round_id):
            self.create_round(round_id)
        if self._current_round is None:
            self.create_round(0)
        self._current_round.add(name, value)

    def track_client(self, name, value, client_id=None):
        """Track client under current round and task.

        Current implementation requires round and task exist to track client.

        :param name: metric name.
        :param value: metric value.
        :param client_id: client id; a new client metric is created when it differs from
            the current client or no client metric exists yet.
        :raises LookupError: if the task or round context is not set.
        """
        self._check_context()
        if self._diff_client_id(client_id) or self._current_client is None:
            self.create_client(client_id)
        self._current_client.add(name, value)

    def save_round(self, increment=True, cache=False):
        """Persist the current round metric.

        :param increment: whether to advance to round ``current + 1`` after saving.
        :param cache: whether to keep the saved round in memory, keyed by ``unique_round_id``.
        :raises ValueError: if no round metric exists.
        """
        if self._current_round is None:
            raise ValueError("Round metric is not initialized")
        self._storage.store_round_metric(self._current_round)
        if cache:
            self._cached_round_metrics[self.unique_round_id] = self._current_round
        if increment:
            self.create_round(self._round_id + 1)

    def save_client(self):
        """Persist the current client metric.

        :raises ValueError: if no client metric exists.
        """
        if self._current_client is None:
            raise ValueError("Client metric is not initialized")
        self._storage.store_client_metrics([self._current_client])

    def save_clients(self, client_metrics):
        """Persist a list of client metrics.

        :param client_metrics: list of client metrics.
        """
        self._storage.store_client_metrics(client_metrics)

    def set_task(self, task_id):
        """Set the current task context without persisting it.

        A new (unsaved) task is created when there is no current task or when ``task_id``
        differs from the current one. Otherwise the current task, including its conf,
        is kept.

        :param task_id: task id.
        """
        if self._current_task is None or self._diff_task_id(task_id):
            self.create_task(task_id, save=False)

    def set_round(self, round_id):
        """Set the current round context (see ``create_round``).

        :param round_id: round id.
        """
        self.create_round(round_id)

    def set_client_context(self, task_id, round_id, client_id, reset_client=True):
        """Set the client context for tracking.

        :param task_id: task id, indicating the current training task
        :param round_id: round id, indicating the current round of training/testing
        :param client_id: client id
        :param reset_client: resets and creates a new client.
        """
        self.set_task(task_id)
        self.set_round(round_id)
        self.create_client(client_id, reset=reset_client)

    @property
    def unique_round_id(self):
        """Round key combining task and round ids, formatted as ``"<task_id>_<round_id>"``."""
        return f"{self._task_id}_{self._round_id}"

    def _diff_task_id(self, task_id):
        """Return True if ``task_id`` is given and differs from the current task id."""
        return task_id is not None and task_id != self._task_id

    def _diff_round_id(self, round_id):
        """Return True if ``round_id`` is given and differs from the current round id."""
        return round_id is not None and round_id != self._round_id

    def _diff_client_id(self, client_id):
        """Return True if ``client_id`` is given and differs from the current client id."""
        return client_id is not None and client_id != self._client_id

    def _check_context(self):
        """Raise ``LookupError`` unless both task id and round id are set."""
        if self._task_id is None or self._round_id is None:
            raise LookupError("task_id or round_id of the client is not set")
# Module-level default client. It is created without storage; setup_tracking() rebinds it.
_client = TrackingClient(db_path=None, db_address=None, init_store=False)
"""easyfl.tracking.TrackingClient: The global tracking client object"""
def init_tracking(path=None, address=None, init_store=True):
    """Create a new, independent tracking client.

    The returned client is isolated from the module-level global tracking client, which
    is useful when an application needs to run multiple tasks.

    :param path: database path
    :param address: remote address of tracking service to connect to
    :param init_store: whether to initialize storage
    :return: a new ``TrackingClient``.
    """
    new_client = TrackingClient(db_path=path, db_address=address, init_store=init_store)
    return new_client
- # ------ following methods are not well tested yet ------
def setup_tracking(path=None, address=None):
    """Replace the global tracking client with one created by ``init_tracking``.

    :param path: database path
    :param address: remote address of tracking service to connect to
    """
    global _client
    replacement = init_tracking(path=path, address=address)
    _client = replacement
def get_task(task_id):
    """Fetch a stored task metric through the global tracking client.

    :param task_id: task id.
    """
    return _client.get_task_metric(task_id=task_id)
def get_round(round_id, task_id):
    """Fetch a round metric through the global tracking client.

    :param round_id: round id.
    :param task_id: task id.
    """
    return _client.get_round_metric(round_id=round_id, task_id=task_id)
def create_task(task_id, conf=None, save=True):
    """Create a task on the global tracking client.

    :param task_id: task id; must not be None.
    :param conf: task configuration.
    :param save: whether to persist the task metric to storage immediately (default True).
    """
    _client.create_task(task_id, conf, save)
def track_task(name, value, task_id=None):
    """Track a task-level metric on the global tracking client.

    :param name: metric name.
    :param value: metric value.
    :param task_id: task id; defaults to the client's current task.
    """
    _client.track_task(name, value, task_id=task_id)
def track_round(name, value, round_id=None, task_id=None):
    """Track a round-level metric on the global tracking client.

    :param name: metric name.
    :param value: metric value.
    :param round_id: round id; defaults to the client's current round.
    :param task_id: task id; defaults to the client's current task.
    """
    _client.track_round(name, value, round_id=round_id, task_id=task_id)
def track_client(name, value, client_id=None):
    """Track a client-level metric on the global tracking client.

    :param name: metric name.
    :param value: metric value.
    :param client_id: client id; defaults to the client's current client.
    """
    _client.track_client(name, value, client_id=client_id)
def set_task(task_id):
    """Set the task context of the global tracking client.

    :param task_id: task id.
    """
    _client.set_task(task_id=task_id)
def set_round(round_id):
    """Set the round context of the global tracking client.

    :param round_id: round id.
    """
    _client.set_round(round_id=round_id)
def save_round(increment=True, cache=False):
    """Persist the current round of the global tracking client.

    :param increment: whether to advance to the next round after saving.
    :param cache: whether to also keep the saved round in the client's in-memory cache.
    """
    _client.save_round(increment, cache)
|