123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- import os
- import unittest
- from easyfl.tracking import client
- from easyfl.tracking import metric
- from easyfl.tracking import storage
class TrackingClientTest(unittest.TestCase):
    """Tests for the tracking client, backed by a local test database.

    Each test records metrics through a tracker obtained from
    ``client.init_tracking`` and then checks them twice: once through the
    tracker's own getters and once by reading the rows back from the store.
    """

    def setUp(self):
        """Open the test store and clear every metric table before each test.

        This lives in ``setUp`` rather than ``__init__`` because unittest
        instantiates one TestCase per test method while *loading* the tests;
        doing the truncation in ``__init__`` would run it at collection time
        instead of right before each test.
        """
        self._database = os.path.join(os.getcwd(), "tracker", "easyfl_test.db")
        self._store = storage.get_store(self._database)
        self._store.truncate_task_metric()
        self._store.truncate_round_metric()
        self._store.truncate_client_metric()

    def test_task(self):
        """A created task is retrievable; an unknown task id yields None."""
        task_id = "test_task"
        conf = {"task_id": task_id}

        tracker = client.init_tracking(self._database)
        tracker.create_task(task_id, conf)

        m = tracker.get_task_metric(task_id)
        self.assertEqual(m.task_id, task_id)
        self.assertEqual(m.configuration, conf)

        m = tracker.get_task_metric("not_exist_task")
        self.assertIsNone(m)

    def test_round(self):
        """Round metrics are tracked, saved, and kept separate per round."""
        task_id = "test_round"
        round_id = 0
        next_round_id = round_id + 1
        last_round_id = round_id + 2
        want_accuracy = 0.9
        want_loss = 0.1
        want_train_upload_size = 10
        want_train_time = 20
        want_extra = {"extra": "information"}

        tracker = client.init_tracking(self._database)
        tracker.create_task(task_id)

        # round 0: values are visible through the tracker before saving
        tracker.track_round(metric.TEST_ACCURACY, want_accuracy)
        tracker.track_round(metric.TEST_LOSS, want_loss)
        m = tracker.get_round_metric(round_id, task_id)
        self.assertEqual(m.task_id, task_id)
        self.assertEqual(m.round_id, round_id)
        self.assertEqual(m.test_accuracy, want_accuracy)
        self.assertEqual(m.test_loss, want_loss)

        # round 0: saved without moving on, then read back from the store
        tracker.save_round(increment=False)
        round_metrics = self._store.get_round_metrics(task_id, [round_id])
        m = metric.RoundMetric.from_sql(next(round_metrics))
        self.assertEqual(m.task_id, task_id)
        self.assertEqual(m.round_id, round_id)
        self.assertEqual(m.test_accuracy, want_accuracy)
        self.assertEqual(m.test_loss, want_loss)

        # round 1: selected explicitly; round 0 values must not carry over
        tracker.set_round(next_round_id)
        tracker.track_round(metric.TRAIN_UPLOAD_SIZE, want_train_upload_size)
        tracker.track_round(metric.EXTRA, want_extra)
        # NOTE(review): the round-2 checks below rely on save_round() with its
        # default arguments advancing the tracker to the next round.
        tracker.save_round()
        round_metrics = self._store.get_round_metrics(task_id, [next_round_id])
        m = metric.RoundMetric.from_sql(next(round_metrics))
        self.assertEqual(m.task_id, task_id)
        self.assertEqual(m.round_id, next_round_id)
        self.assertEqual(m.test_accuracy, 0)
        self.assertEqual(m.test_loss, 0)
        self.assertEqual(m.train_upload_size, want_train_upload_size)
        self.assertEqual(m.extra["extra"], want_extra["extra"])

        # round 2: tracked but not saved, so it is read through the tracker
        tracker.track_round(metric.TRAIN_TIME, want_train_time)
        m = tracker.get_round_metric(last_round_id, task_id)
        self.assertEqual(m.task_id, task_id)
        self.assertEqual(m.round_id, last_round_id)
        self.assertEqual(m.train_time, want_train_time)
        self.assertEqual(m.test_accuracy, 0)
        self.assertEqual(m.test_loss, 0)
        self.assertEqual(m.extra, {})

    def test_client(self):
        """Client metrics need a context, and can be saved singly or in bulk."""
        task_id = "test_client"
        round_id = 1
        client_id = "client_id_test"
        client_id_2 = "client_id_test_2"
        want_accuracy = 0.9123456789
        # The assertions below expect list-valued accuracies to come back
        # rounded to 4 decimal places.
        want_accuracy_rounded = [round(want_accuracy, 4)]
        want_mAP = 0.8
        want_rank1 = 0.7
        want_loss = 0.1
        want_extra = {"extra": "information"}

        # test error: no client context has been set on this tracker yet
        tracker = client.init_tracking(self._database)
        with self.assertRaises(LookupError):
            tracker.create_client(client_id)
        with self.assertRaises(LookupError):
            tracker.track_client(metric.TEST_ACCURACY, [want_accuracy])

        # test track and get client
        tracker.set_client_context(task_id, round_id, client_id)
        tracker.track_client(metric.TRAIN_ACCURACY, [want_accuracy])
        # "mAP" is not referenced through a metric constant; the assertion
        # below expects it to surface under ``extra``.
        tracker.track_client("mAP", want_mAP)
        m = tracker.get_client_metric(client_id, round_id, task_id)
        self.assertEqual(m.train_accuracy, want_accuracy_rounded)
        self.assertEqual(m.extra["mAP"], want_mAP)

        # test save client
        tracker.save_client()
        client_metrics = self._store.get_client_metrics(task_id, round_id, [client_id])
        m = metric.ClientMetric.from_sql(next(client_metrics))
        self.assertEqual(m.task_id, task_id)
        self.assertEqual(m.round_id, round_id)
        self.assertEqual(m.client_id, client_id)
        self.assertEqual(m.train_accuracy, want_accuracy_rounded)
        self.assertEqual(m.extra["mAP"], want_mAP)

        # Start the bulk-save scenario from an empty client table.
        self._store.truncate_client_metric()

        # test save multiple clients
        tracker2 = client.init_tracking(self._database)
        tracker2.set_client_context(task_id, round_id, client_id_2)
        tracker2.track_client(metric.TRAIN_LOSS, [want_loss])
        tracker2.track_client("rank1", want_rank1)
        tracker2.track_client(metric.EXTRA, want_extra)

        client_metrics = [tracker.get_client_metric(), tracker2.get_client_metric()]
        tracker.save_clients(client_metrics)

        results = self._store.get_client_metrics(task_id, round_id, [client_id, client_id_2])
        metrics = [metric.ClientMetric.from_sql(r) for r in results]
        self.assertEqual(len(metrics), 2)

        # The second row is expected to be the second client saved.
        second = metrics[1]
        self.assertEqual(len(second.extra), 2)
        self.assertEqual(second.task_id, task_id)
        self.assertEqual(second.round_id, round_id)
        self.assertEqual(second.client_id, client_id_2)
        self.assertEqual(second.train_loss, [want_loss])
        self.assertEqual(second.extra["rank1"], want_rank1)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|