# storage.py (13 KB)
  1. import logging
  2. import os
  3. import random
  4. import sqlite3
  5. import time
  6. from easyfl.communication import grpc_wrapper
  7. from easyfl.pb import common_pb2 as common_pb
  8. from easyfl.pb import tracking_service_pb2 as tracking_pb
  9. from easyfl.protocol.codec import marshal
  10. logger = logging.getLogger(__name__)
  11. DEFAULT_SQLITE_DB = "easyfl.db"
  12. STORAGE_SQLITE = "sqlite"
  13. STORAGE_REMOTE = "remote"
  14. TYPE_ROUND = "round"
  15. TYPE_CLIENT = "client"
  16. DEFAULT_TIMEOUT = 10
  17. CREATE_TASK_METRIC_SQL = '''
  18. CREATE TABLE IF NOT EXISTS task_metric
  19. (task_id CHAR(50) NOT NULL PRIMARY KEY,
  20. config TEXT);'''
  21. CREATE_ROUND_METRIC_SQL = '''
  22. CREATE TABLE IF NOT EXISTS round_metric
  23. (task_id CHAR(50) NOT NULL,
  24. round_id INT NOT NULL,
  25. accuracy REAL NOT NULL,
  26. loss REAL NOT NULL,
  27. round_time REAL NOT NULL,
  28. train_time REAL NOT NULL,
  29. test_time REAL NOT NULL,
  30. train_distribute_time REAL,
  31. test_distribute_time REAL,
  32. train_upload_size REAL,
  33. train_download_size REAL,
  34. test_upload_size REAL,
  35. test_download_size REAL,
  36. extra TEXT,
  37. PRIMARY KEY (task_id, round_id));'''
  38. CREATE_CLIENT_METRIC_SQL = '''
  39. CREATE TABLE IF NOT EXISTS client_metric
  40. (task_id CHAR(50) NOT NULL,
  41. round_id INT NOT NULL,
  42. client_id CHAR(20) NOT NULL,
  43. train_accuracy TEXT ,
  44. train_loss TEXT ,
  45. test_accuracy REAL ,
  46. test_loss REAL ,
  47. train_time REAL ,
  48. test_time REAL ,
  49. train_upload_time REAL ,
  50. test_upload_time REAL ,
  51. train_upload_size REAL ,
  52. train_download_size REAL ,
  53. test_upload_size REAL ,
  54. test_download_size REAL ,
  55. extra TEXT ,
  56. PRIMARY KEY (task_id, round_id, client_id));'''
  57. def get_store(path=None, address=None):
  58. if address:
  59. return RemoteStorage(address)
  60. else:
  61. return SqliteStorage(path)
  62. def get_storage_type(is_remote=True):
  63. if is_remote:
  64. return STORAGE_REMOTE
  65. else:
  66. return STORAGE_SQLITE
  67. class SqliteStorage(object):
  68. """SqliteStorage uses sqlite to save tracking metrics
  69. """
  70. def __init__(self, database=None):
  71. if database is None:
  72. database = os.path.join(os.getcwd(), "tracker", DEFAULT_SQLITE_DB)
  73. self._conn = sqlite3.connect(database, check_same_thread=False)
  74. self.setup()
  75. def __del__(self):
  76. self._conn.close()
  77. def setup(self):
  78. with self._conn:
  79. try:
  80. self._retry_execute(CREATE_TASK_METRIC_SQL)
  81. logger.info("Setup task metric table")
  82. self._retry_execute(CREATE_ROUND_METRIC_SQL)
  83. logger.info("Setup round metric table")
  84. self._retry_execute(CREATE_CLIENT_METRIC_SQL)
  85. logger.info("Setup client metric table")
  86. except sqlite3.OperationalError as e:
  87. logger.error(f"Failed to setup table, error: {e}")
  88. # ------------------ store metrics ------------------
  89. def store_task_metric(self, metric):
  90. sql = "INSERT INTO task_metric(task_id, config) VALUES (?, ?)"
  91. try:
  92. self._retry_execute(sql, metric.to_sql_param())
  93. logger.debug("Task metric saved successfully")
  94. except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
  95. logger.error(f"Failed to store round metric, error: {e}")
  96. def store_round_metric(self, metric):
  97. sql = '''
  98. INSERT INTO round_metric (
  99. task_id,
  100. round_id,
  101. accuracy,
  102. loss,
  103. round_time,
  104. train_time,
  105. test_time,
  106. train_distribute_time,
  107. test_distribute_time,
  108. train_upload_size,
  109. train_download_size,
  110. test_upload_size,
  111. test_download_size,
  112. extra) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);'''
  113. try:
  114. self._retry_execute(sql, metric.to_sql_param())
  115. logger.debug("Round metric saved successfully")
  116. except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
  117. logger.error(f"Failed to store round metric {metric.task_id} {metric.round_id}, error: {e}")
  118. def store_client_metrics(self, metrics):
  119. """Store a list of client metrics. If the client exists, replace the values.
  120. :param metrics, list of client metrics to store, [].
  121. """
  122. sql = '''
  123. INSERT INTO client_metric (
  124. task_id,
  125. round_id,
  126. client_id,
  127. train_accuracy,
  128. train_loss,
  129. test_accuracy,
  130. test_loss,
  131. train_time,
  132. test_time,
  133. train_upload_time,
  134. test_upload_time,
  135. train_upload_size,
  136. train_download_size,
  137. test_upload_size,
  138. test_download_size,
  139. extra) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);'''
  140. params = [metric.to_sql_param() for metric in metrics]
  141. try:
  142. with self._conn:
  143. self._conn.executemany(sql, params)
  144. logger.debug("Client metrics saved successfully")
  145. except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
  146. logger.error(f"Failed to store client metrics, error: {e}")
  147. def store_client_train_metric(self, tid, rid, cid, train_loss, train_time, train_upload_time,
  148. train_download_size, train_upload_size):
  149. sql = "INSERT INTO client_metric (task_id, round_id, client_id, train_loss, train_time, " \
  150. "train_upload_size, train_download_size, train_upload_size) VALUES (?, ? ,? ,?, ?, ?, ?, ?);"
  151. param = (tid, rid, cid, train_loss, train_time, train_upload_time, train_download_size, train_upload_size)
  152. try:
  153. self._retry_execute(sql, param)
  154. except sqlite3.OperationalError as e:
  155. logger.error("Failed to store client train metric, error: {}".format(e))
  156. def store_client_test_metric(self, tid, rid, cid, test_acc, test_loss, test_time,
  157. test_upload_time, test_download_size):
  158. sql = "UPDATE client_metric SET test_accuracy=?, test_loss=?, test_time=? ,test_upload_size=?, " \
  159. "test_download_size=? WHERE task_id=? AND round_id=? AND client_id=?;"
  160. param = (test_acc, test_loss, test_time, test_upload_time, test_download_size, tid, rid, cid)
  161. try:
  162. self._retry_execute(sql, param)
  163. except sqlite3.OperationalError as e:
  164. logger.error("Failed to store client test metric, error: {}".format(e))
  165. # ------------------ get metrics ------------------
  166. def get_task_metric(self, task_id):
  167. sql = "SELECT * FROM task_metric WHERE task_id=?"
  168. with self._conn:
  169. result = self._conn.execute(sql, (task_id,))
  170. for r in result:
  171. return r
  172. def get_round_metrics(self, task_id, rounds):
  173. if rounds:
  174. sql = "SELECT * FROM round_metric WHERE task_id=? AND round_id IN (%s)" % ("?," * len(rounds))[:-1]
  175. param = [task_id] + rounds
  176. else:
  177. sql = "SELECT * FROM round_metric WHERE task_id=?"
  178. param = (task_id,)
  179. with self._conn:
  180. result = self._conn.execute(sql, param)
  181. return result
  182. def get_client_metrics(self, task_id, round_id, client_ids=None):
  183. if client_ids:
  184. sql = "SELECT * FROM client_metric WHERE task_id=? AND round_id=? \
  185. AND client_id IN (%s)" % ("?," * len(client_ids))[:-1]
  186. param = [task_id, round_id] + client_ids
  187. else:
  188. sql = "SELECT * FROM client_metric WHERE task_id=? AND round_id=?"
  189. param = (task_id, round_id)
  190. with self._conn:
  191. result = self._conn.execute(sql, param)
  192. return result
  193. def get_round_train_test_time(self, tid, rounds, interval=1):
  194. sql = "SELECT SUM(train_time+test_time) FROM round_metric WHERE task_id=? AND round_id<?"
  195. result = []
  196. for r in range(interval, rounds + interval, interval):
  197. param = (tid, r)
  198. with self._conn:
  199. res = self._conn.execute(sql, param)
  200. for i in res:
  201. result.append((r, i[0]))
  202. return result
  203. # ------------------ delete metrics ------------------
  204. def truncate_task_metric(self):
  205. sql = "DELETE FROM task_metric"
  206. try:
  207. self._retry_execute(sql)
  208. except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
  209. logger.error(f"Failed to truncate task metric, error: {e}")
  210. def truncate_round_metric(self):
  211. sql = "DELETE FROM round_metric"
  212. try:
  213. self._retry_execute(sql)
  214. except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
  215. logger.error(f"Failed to truncate round metric, error: {e}")
  216. def truncate_client_metric(self):
  217. sql = "DELETE FROM client_metric"
  218. try:
  219. self._retry_execute(sql)
  220. except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
  221. logger.error(f"Failed to truncate round metric, error: {e}")
  222. def delete_round_metric(self, task_id, round_id):
  223. sql = "DELETE FROM round_metric WHERE task_id=? AND round_id=?"
  224. try:
  225. self._retry_execute(sql, (task_id, round_id))
  226. except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
  227. logger.error(f"Failed to delete round metric {task_id} {round_id}, error: {e}")
  228. def _retry_execute(self, sql, param=(), timeout=DEFAULT_TIMEOUT):
  229. for t in range(0, timeout + 1):
  230. try:
  231. with self._conn:
  232. self._conn.execute(sql, param)
  233. break
  234. except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
  235. logger.info("retry tracking, error: {}".format(e))
  236. if t == timeout:
  237. raise e
  238. sleep_time = random.uniform(0, 0.2)
  239. time.sleep(sleep_time)
  240. continue
  241. class RemoteStorage(object):
  242. """RemoteStorage sends request to remote service to store tracking metrics
  243. """
  244. def __init__(self, address="localhost:12666"):
  245. # TODO: put the remote address in config
  246. self.tracking_stub = grpc_wrapper.init_stub(grpc_wrapper.TYPE_TRACKING, address)
  247. def store_task_metric(self, metric):
  248. response = self.tracking_stub.TrackTaskMetric(tracking_pb.TrackTaskMetricRequest(task_metric=metric.to_proto()))
  249. if response.status == common_pb.SC_UNKNOWN:
  250. logger.error("Failed to store task metric.")
  251. return response.status
  252. def store_round_metric(self, metric):
  253. req = tracking_pb.TrackRoundMetricRequest(round_metric=metric.to_proto())
  254. response = self.tracking_stub.TrackRoundMetric(req)
  255. if response.status == common_pb.SC_UNKNOWN:
  256. logger.error(f"Failed to store round metric, task_id: {metric.task_id} round_id: {metric.round_id}.")
  257. return response.status
  258. def store_client_metrics(self, metrics):
  259. client_metrics = [m.to_proto() for m in metrics]
  260. req = tracking_pb.TrackClientMetricRequest(client_metrics=client_metrics)
  261. response = self.tracking_stub.TrackClientMetric(req)
  262. if response.status == common_pb.SC_UNKNOWN:
  263. logger.error(f"Failed to store client metrics.")
  264. return response.status
  265. def store_client_train_metric(self, tid, rid, cid, train_loss, train_time, train_upload_time,
  266. train_download_size, train_upload_size):
  267. req = tracking_pb.TrackClientTrainMetricRequest(task_id=tid,
  268. round_id=rid,
  269. client_id=cid,
  270. train_loss=train_loss,
  271. train_time=train_time,
  272. train_upload_time=train_upload_time,
  273. train_download_size=train_download_size,
  274. train_upload_size=train_upload_size)
  275. response = self.tracking_stub.TrackClientTrainMetric(req)
  276. if response.status == common_pb.SC_UNKNOWN:
  277. logger.error("Failed to store client metric, task id: {} round id: {} client id: {}.".format(tid, rid, cid))
  278. return response.status
  279. def store_client_test_metric(self, tid, rid, cid, test_acc, test_loss, test_time,
  280. test_upload_time, test_download_size):
  281. req = tracking_pb.TrackClientTestMetricRequest(task_id=tid,
  282. round_id=rid,
  283. client_id=cid,
  284. test_accuracy=test_acc,
  285. test_loss=test_loss,
  286. test_time=test_time,
  287. test_upload_time=test_upload_time,
  288. test_download_size=test_download_size)
  289. response = self.tracking_stub.TrackClientTestMetric(req)
  290. if response.status == common_pb.SC_UNKNOWN:
  291. logger.error("Failed to store client metric, task id: {} round id: {} client id: {}.".format(tid, rid, cid))
  292. return response.status