# trainer_base.py
  1. import abc
  2. import importlib
  3. import torch as t
  4. import numpy as np
  5. from torch.nn import Module
  6. from typing import List
  7. from federatedml.util import consts
  8. from federatedml.util import LOGGER
  9. from federatedml.model_base import serialize_models
  10. from federatedml.nn.backend.utils.common import ML_PATH
  11. from federatedml.feature.instance import Instance
  12. from federatedml.evaluation.evaluation import Evaluation
  13. from federatedml.model_base import Metric, MetricMeta
  14. from federatedml.param import EvaluateParam
  15. class StdReturnFormat(object):
  16. def __init__(self, id_table_list, pred_table, classes):
  17. self.id = id_table_list
  18. self.pred_table = pred_table
  19. self.classes = classes
  20. def __call__(self,):
  21. return self.id, self.pred_table, self.classes
  22. class ExporterBase(object):
  23. def __init__(self, *args, **kwargs):
  24. pass
  25. def export_model_dict(self, model=None, optimizer=None, model_define=None, optimizer_define=None, loss_define=None,
  26. epoch_idx=-1, converge_status=False, loss_history=None, best_epoch=-1, extra_data={}):
  27. pass
  28. class TrainerBase(object):
  29. def __init__(self, **kwargs):
  30. self._fed_mode = True
  31. self.role = None
  32. self.party_id = None
  33. self.party_id_list = None
  34. self._flowid = None
  35. self._cache_model = None
  36. self._model = None
  37. self._tracker = None
  38. self._model_checkpoint = None
  39. self._exporter = None
  40. self._evaluation_summary = {}
  41. # running status
  42. self._set_model_checkpoint_epoch = set()
  43. # nn config
  44. self.nn_define, self.opt_define, self.loss_define = {}, {}, {}
  45. # ret summary
  46. self._summary = {}
  47. @staticmethod
  48. def is_pos_int(val):
  49. return val > 0 and isinstance(val, int)
  50. @staticmethod
  51. def is_float(val):
  52. return isinstance(val, float)
  53. @staticmethod
  54. def is_bool(val):
  55. return isinstance(val, bool)
  56. @staticmethod
  57. def check_trainer_param(
  58. var_list,
  59. name_list,
  60. judge_func,
  61. warning_str,
  62. allow_none=True):
  63. for var, name in zip(var_list, name_list):
  64. if allow_none and var is None:
  65. continue
  66. assert judge_func(var), warning_str.format(name)
  67. @property
  68. def model(self):
  69. if not hasattr(self, '_model'):
  70. raise AttributeError(
  71. 'model variable is not initialized, remember to call'
  72. ' super(your_class, self).__init__()')
  73. if self._model is None:
  74. raise AttributeError(
  75. 'model is not set, use set_model() function to set training model')
  76. return self._model
  77. @model.setter
  78. def model(self, val):
  79. self._model = val
  80. @property
  81. def fed_mode(self):
  82. if not hasattr(self, '_fed_mode'):
  83. raise AttributeError(
  84. 'run_local_mode variable is not initialized, remember to call'
  85. ' super(your_class, self).__init__()')
  86. return self._fed_mode
  87. @fed_mode.setter
  88. def fed_mode(self, val):
  89. assert isinstance(val, bool), 'fed mode must be a bool'
  90. self._fed_mode = val
  91. def local_mode(self):
  92. self.fed_mode = False
  93. def set_nn_config(self, nn_define, optimizer_define, loss_define):
  94. self.nn_define = nn_define
  95. self.opt_define = optimizer_define
  96. self.loss_define = loss_define
  97. def set_tracker(self, tracker):
  98. self._tracker = tracker
  99. def set_checkpoint(self, chkp):
  100. self._model_checkpoint = chkp
  101. def set_party_id_list(self, party_id_list):
  102. self.party_id_list = party_id_list
  103. def set_model_exporter(self, exporter):
  104. assert isinstance(
  105. exporter, ExporterBase), 'exporter is not an instance of ExporterBase'
  106. self._exporter = exporter
  107. def get_cached_model(self):
  108. return self._cache_model
  109. @staticmethod
  110. def task_type_infer(predict_result: t.Tensor, true_label):
  111. # infer task type and classes(of classification task)
  112. predict_result = predict_result.cpu()
  113. true_label = true_label.cpu()
  114. pred_shape = predict_result.shape
  115. with t.no_grad():
  116. if true_label.max() == 1.0 and true_label.min() == 0.0:
  117. return consts.BINARY
  118. if (len(pred_shape) > 1) and (pred_shape[1] > 1):
  119. if t.isclose(
  120. predict_result.sum(
  121. axis=1).cpu(), t.Tensor(
  122. [1.0])).all():
  123. return consts.MULTY
  124. else:
  125. return None
  126. elif (len(pred_shape) == 1) or (pred_shape[1] == 1):
  127. return consts.REGRESSION
  128. return None
  129. def _update_metric_summary(self, metric_dict):
  130. if len(metric_dict) == 0:
  131. return
  132. iter_name = list(metric_dict.keys())[0]
  133. metric_dict = metric_dict[iter_name]
  134. if len(self._evaluation_summary) == 0:
  135. self._evaluation_summary = {namespace: {}
  136. for namespace in metric_dict}
  137. for namespace in metric_dict:
  138. for metric_name in metric_dict[namespace]:
  139. epoch_metric = metric_dict[namespace][metric_name]
  140. if namespace not in self._evaluation_summary:
  141. self._evaluation_summary[namespace] = {}
  142. if metric_name not in self._evaluation_summary[namespace]:
  143. self._evaluation_summary[namespace][metric_name] = []
  144. self._evaluation_summary[namespace][metric_name].append(
  145. epoch_metric)
  146. def get_evaluation_summary(self):
  147. return self._evaluation_summary
  148. def get_summary(self):
  149. return self._summary
  150. """
  151. User Interfaces
  152. """
  153. def set_model(self, model: Module):
  154. if not issubclass(type(model), Module):
  155. raise ValueError('model must be a subclass of pytorch nn.Module')
  156. self.model = model
  157. def save(
  158. self,
  159. model=None,
  160. epoch_idx=-1,
  161. optimizer=None,
  162. converge_status=False,
  163. loss_history=None,
  164. best_epoch=-1,
  165. extra_data={}):
  166. assert isinstance(
  167. epoch_idx, int) and epoch_idx >= 0, 'epoch idx must be an int >= 0'
  168. if self._exporter:
  169. model_dict = self._exporter.export_model_dict(model=model,
  170. optimizer=optimizer,
  171. model_define=self.nn_define,
  172. optimizer_define=self.opt_define,
  173. loss_define=self.loss_define,
  174. epoch_idx=epoch_idx,
  175. converge_status=converge_status,
  176. loss_history=loss_history,
  177. best_epoch=best_epoch,
  178. extra_data=extra_data
  179. )
  180. self._cache_model = model_dict
  181. def checkpoint(
  182. self,
  183. epoch_idx,
  184. model=None,
  185. optimizer=None,
  186. converge_status=False,
  187. loss_history=None,
  188. best_epoch=-1,
  189. extra_data={}):
  190. assert isinstance(
  191. epoch_idx, int) and epoch_idx >= 0, 'epoch idx must be an int >= 0'
  192. if self._model_checkpoint:
  193. if self._exporter is None:
  194. raise RuntimeError('exporter is None, cannot save checkpoint')
  195. if epoch_idx in self._set_model_checkpoint_epoch:
  196. LOGGER.info(
  197. 'checkpoint at epoch {} set, skip setting checkpoint'.format(epoch_idx))
  198. return
  199. self.save(model=model, epoch_idx=epoch_idx, optimizer=optimizer, converge_status=converge_status,
  200. loss_history=loss_history, best_epoch=best_epoch, extra_data=extra_data)
  201. self._model_checkpoint.add_checkpoint(len(self._set_model_checkpoint_epoch),
  202. to_save_model=serialize_models(self._cache_model)) # step_index, to_save_model
  203. self._set_model_checkpoint_epoch.add(epoch_idx)
  204. LOGGER.info('checkpoint at epoch {} saved'.format(epoch_idx))
  205. def format_predict_result(self, sample_ids: List, predict_result: t.Tensor,
  206. true_label: t.Tensor, task_type: str = None):
  207. predict_result = predict_result.cpu().detach()
  208. if task_type == 'auto':
  209. task_type = self.task_type_infer(predict_result, true_label)
  210. if task_type is None:
  211. LOGGER.warning(
  212. 'unable to infer predict result type, predict process will be skipped')
  213. return None
  214. classes = None
  215. if task_type == consts.BINARY:
  216. classes = [0, 1]
  217. elif task_type == consts.MULTY:
  218. classes = [i for i in range(predict_result.shape[1])]
  219. true_label = true_label.cpu().detach().flatten().tolist()
  220. if task_type == consts.MULTY:
  221. predict_result = predict_result.tolist()
  222. else:
  223. predict_result = predict_result.flatten().tolist()
  224. id_table = [(id_, Instance(label=l))
  225. for id_, l in zip(sample_ids, true_label)]
  226. score_table = [(id_, pred)
  227. for id_, pred in zip(sample_ids, predict_result)]
  228. return StdReturnFormat(id_table, score_table, classes)
  229. def callback_metric(self, metric_name: str, value: float, metric_type='train', epoch_idx=0):
  230. assert metric_type in [
  231. 'train', 'validate'], 'metric_type should be train or validate'
  232. iter_name = 'iteration_{}'.format(epoch_idx)
  233. if self._tracker is not None:
  234. self._tracker.log_metric_data(
  235. metric_type, iter_name, [
  236. Metric(
  237. metric_name, np.round(
  238. value, 6))])
  239. self._tracker.set_metric_meta(
  240. metric_type, iter_name, MetricMeta(
  241. name=metric_name, metric_type='EVALUATION_SUMMARY'))
  242. def callback_loss(self, loss: float, epoch_idx: int):
  243. if self._tracker is not None:
  244. self._tracker.log_metric_data(
  245. metric_name="loss",
  246. metric_namespace="train",
  247. metrics=[Metric(epoch_idx, loss)],
  248. )
  249. def summary(self, summary_dict: dict):
  250. assert isinstance(summary_dict, dict), 'summary must be a dict'
  251. self._summary = summary_dict
  252. def evaluation(self, sample_ids: list, pred_scores: t.Tensor, label: t.Tensor, dataset_type='train',
  253. metric_list=None, epoch_idx=0, task_type=None):
  254. eval_obj = Evaluation()
  255. if task_type == 'auto':
  256. task_type = self.task_type_infer(pred_scores, label)
  257. if task_type is None:
  258. LOGGER.debug('cannot infer task type, return')
  259. return
  260. assert dataset_type in [
  261. 'train', 'validate'], 'dataset_type must in ["train", "validate"]'
  262. eval_param = EvaluateParam(eval_type=task_type)
  263. if task_type == consts.BINARY:
  264. eval_param.metrics = ['auc', 'ks']
  265. elif task_type == consts.MULTY:
  266. eval_param.metrics = ['accuracy', 'precision', 'recall']
  267. eval_param.check_single_value_default_metric()
  268. eval_obj._init_model(eval_param)
  269. pred_scores = pred_scores.cpu().detach().numpy()
  270. label = label.cpu().detach().numpy().flatten()
  271. if task_type == consts.REGRESSION or task_type == consts.BINARY:
  272. pred_scores = pred_scores.flatten()
  273. label = label.flatten()
  274. pred_scores = pred_scores.tolist()
  275. label = label.tolist()
  276. assert len(pred_scores) == len(
  277. label), 'the length of predict score != the length of label, pred {} and label {}'.format(len(pred_scores), len(label))
  278. eval_data = []
  279. for id_, s, l in zip(sample_ids, pred_scores, label):
  280. if task_type == consts.REGRESSION:
  281. eval_data.append([id_, (l, s, s)])
  282. if task_type == consts.MULTY:
  283. pred_label = np.argmax(s)
  284. eval_data.append([id_, (l, pred_label, s)])
  285. elif task_type == consts.BINARY:
  286. pred_label = (s > 0.5) + 1
  287. eval_data.append([id_, (l, pred_label, s)])
  288. eval_result = eval_obj.evaluate_metrics(dataset_type, eval_data)
  289. if self._tracker is not None:
  290. eval_obj.set_tracker(self._tracker)
  291. # send result to fate-board
  292. eval_obj.callback_metric_data(
  293. {'iteration_{}'.format(epoch_idx): [eval_result]})
  294. self._update_metric_summary(eval_obj.metric_summaries)
  295. return self._evaluation_summary
  296. def to_cuda(self, var):
  297. if hasattr(var, 'cuda'):
  298. return var.cuda()
  299. elif isinstance(var, tuple) or isinstance(var, list):
  300. ret = tuple(self.to_cuda(i) for i in var)
  301. return ret
  302. elif isinstance(var, dict):
  303. for k in var:
  304. if hasattr(var[k], 'cuda'):
  305. var[k] = var[k].cuda()
  306. return var
  307. else:
  308. return var
  309. @abc.abstractmethod
  310. def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):
  311. """
  312. train_set : A Dataset Instance, must be a instance of subclass of Dataset (federatedml.nn.dataset.base),
  313. for example, TableDataset() (from federatedml.nn.dataset.table)
  314. validate_set : A Dataset Instance, but optional must be a instance of subclass of Dataset
  315. (federatedml.nn.dataset.base), for example, TableDataset() (from federatedml.nn.dataset.table)
  316. optimizer : A pytorch optimizer class instance, for example, t.optim.Adam(), t.optim.SGD()
  317. loss : A pytorch Loss class, for example, nn.BECLoss(), nn.CrossEntropyLoss()
  318. """
  319. pass
  320. @abc.abstractmethod
  321. def predict(self, dataset):
  322. pass
  323. @abc.abstractmethod
  324. def server_aggregate_procedure(self, extra_data={}):
  325. pass
  326. """
  327. Load Trainer
  328. """
  329. def get_trainer_class(trainer_module_name: str):
  330. if trainer_module_name.endswith('.py'):
  331. trainer_module_name = trainer_module_name.replace('.py', '')
  332. ds_modules = importlib.import_module(
  333. '{}.homo.trainer.{}'.format(
  334. ML_PATH, trainer_module_name))
  335. try:
  336. for k, v in ds_modules.__dict__.items():
  337. if isinstance(v, type):
  338. if issubclass(v, TrainerBase) and v is not TrainerBase:
  339. return v
  340. raise ValueError('Did not find any class in {}.py that is the subclass of Trainer class'.
  341. format(trainer_module_name))
  342. except ValueError as e:
  343. raise e