# client.py — HomoNN client-side component (FATE federatedml)
  1. import json
  2. import torch
  3. import tempfile
  4. import inspect
  5. from fate_arch.computing.non_distributed import LocalData
  6. from fate_arch.computing._util import is_table
  7. from federatedml.model_base import ModelBase
  8. from federatedml.nn.homo.trainer.trainer_base import get_trainer_class, TrainerBase
  9. from federatedml.nn.backend.utils.data import load_dataset
  10. from federatedml.param.homo_nn_param import HomoNNParam
  11. from federatedml.nn.backend.torch import serialization as s
  12. from federatedml.nn.backend.torch.base import FateTorchOptimizer
  13. from federatedml.model_base import MetricMeta
  14. from federatedml.util import LOGGER
  15. from federatedml.util import consts
  16. from federatedml.nn.homo.trainer.trainer_base import StdReturnFormat
  17. from federatedml.nn.backend.utils.common import global_seed, get_homo_model_dict, get_homo_param_meta, recover_model_bytes, get_torch_model_bytes
  18. from federatedml.callbacks.model_checkpoint import ModelCheckpoint
  19. from federatedml.statistic.data_overview import check_with_inst_id
  20. from federatedml.nn.homo.trainer.trainer_base import ExporterBase
  21. from fate_arch.session import computing_session
  22. from federatedml.nn.backend.utils.data import get_ret_predict_table
  23. from federatedml.nn.dataset.table import TableDataset
  24. from federatedml.nn.backend.utils.data import add_match_id
  25. from federatedml.protobuf.generated.homo_nn_model_param_pb2 import HomoNNParam as HomoNNParamPB
  26. from federatedml.protobuf.generated.homo_nn_model_meta_pb2 import HomoNNMeta as HomoNNMetaPB
  27. class NNModelExporter(ExporterBase):
  28. def __init__(self, *args, **kwargs):
  29. super().__init__(*args, **kwargs)
  30. def export_model_dict(self, model=None, optimizer=None, model_define=None, optimizer_define=None, loss_define=None,
  31. epoch_idx=-1, converge_status=False, loss_history=None, best_epoch=-1, extra_data={}):
  32. if issubclass(type(model), torch.nn.Module):
  33. model_statedict = model.state_dict()
  34. else:
  35. model_statedict = None
  36. opt_state_dict = None
  37. if optimizer is not None:
  38. assert isinstance(optimizer, torch.optim.Optimizer), \
  39. 'optimizer must be an instance of torch.optim.Optimizer'
  40. opt_state_dict = optimizer.state_dict()
  41. model_status = {
  42. 'model': model_statedict,
  43. 'optimizer': opt_state_dict,
  44. }
  45. model_saved_bytes = get_torch_model_bytes(model_status)
  46. extra_data_bytes = get_torch_model_bytes(extra_data)
  47. param = HomoNNParamPB()
  48. meta = HomoNNMetaPB()
  49. # save param
  50. param.model_bytes = model_saved_bytes
  51. param.extra_data_bytes = extra_data_bytes
  52. param.epoch_idx = epoch_idx
  53. param.converge_status = converge_status
  54. param.best_epoch = best_epoch
  55. if loss_history is None:
  56. loss_history = []
  57. param.loss_history.extend(loss_history)
  58. # save meta
  59. meta.nn_define.append(json.dumps(model_define))
  60. meta.optimizer_define.append(json.dumps(optimizer_define))
  61. meta.loss_func_define.append(json.dumps(loss_define))
  62. return get_homo_model_dict(param, meta)
  63. class HomoNNClient(ModelBase):
  64. def __init__(self):
  65. super(HomoNNClient, self).__init__()
  66. self.model_param = HomoNNParam()
  67. self.trainer = consts.FEDAVG_TRAINER
  68. self.trainer_param = {}
  69. self.dataset_module = None
  70. self.dataset = None
  71. self.dataset_param = {}
  72. self.torch_seed = None
  73. self.loss = None
  74. self.optimizer = None
  75. self.nn_define = None
  76. # running varialbles
  77. self.trainer_inst = None
  78. # export model
  79. self.exporter = NNModelExporter()
  80. self.model_loaded = False
  81. self.model = None
  82. # cache dataset
  83. self.cache_dataset = {}
  84. # dtable partitions
  85. self.partitions = 4
  86. # warm start display iter
  87. self.warm_start_iter = None
  88. def _init_model(self, param: HomoNNParam):
  89. train_param = param.trainer.to_dict()
  90. dataset_param = param.dataset.to_dict()
  91. self.trainer = train_param['trainer_name']
  92. self.dataset = dataset_param['dataset_name']
  93. self.trainer_param = train_param['param']
  94. self.dataset_param = dataset_param['param']
  95. self.torch_seed = param.torch_seed
  96. self.nn_define = param.nn_define
  97. self.loss = param.loss
  98. self.optimizer = param.optimizer
  99. def init(self):
  100. # set random seed
  101. global_seed(self.torch_seed)
  102. # load trainer class
  103. if self.trainer is None:
  104. raise ValueError(
  105. 'Trainer is not specified, please specify your trainer')
  106. trainer_class = get_trainer_class(self.trainer)
  107. LOGGER.info('trainer class is {}'.format(trainer_class))
  108. # recover model from model config / or recover from saved model param
  109. loaded_model_dict = None
  110. # if has model protobuf, load model config from protobuf
  111. load_opt_state_dict = False
  112. if self.model_loaded:
  113. param, meta = get_homo_param_meta(self.model)
  114. self.warm_start_iter = param.epoch_idx
  115. if param is None or meta is None:
  116. raise ValueError(
  117. 'model protobuf is None, make sure'
  118. 'that your trainer calls export_model() function to save models')
  119. if meta.nn_define[0] is None:
  120. raise ValueError(
  121. 'nn_define is None, model protobuf has no nn-define, make sure'
  122. 'that your trainer calls export_model() function to save models')
  123. self.nn_define = json.loads(meta.nn_define[0])
  124. loss = json.loads(meta.loss_func_define[0])
  125. optimizer = json.loads(meta.optimizer_define[0])
  126. loaded_model_dict = recover_model_bytes(param.model_bytes)
  127. extra_data = recover_model_bytes(param.extra_data_bytes)
  128. if self.optimizer is not None and optimizer != self.optimizer:
  129. LOGGER.info('optimizer updated')
  130. else:
  131. self.optimizer = optimizer
  132. load_opt_state_dict = True
  133. if self.loss is not None and self.loss != loss:
  134. LOGGER.info('loss updated')
  135. else:
  136. self.loss = loss
  137. else:
  138. extra_data = {}
  139. # check key param
  140. if self.nn_define is None:
  141. raise ValueError(
  142. 'Model structure is not defined, nn_define is None, please check your param')
  143. # get model from nn define
  144. model = s.recover_sequential_from_dict(self.nn_define)
  145. if loaded_model_dict:
  146. model.load_state_dict(loaded_model_dict['model'])
  147. LOGGER.info('load model state dict from check point')
  148. LOGGER.info('model structure is {}'.format(model))
  149. # init optimizer
  150. if self.optimizer is not None:
  151. optimizer_: FateTorchOptimizer = s.recover_optimizer_from_dict(
  152. self.optimizer)
  153. # pass model parameters to optimizer
  154. optimizer = optimizer_.to_torch_instance(model.parameters())
  155. if load_opt_state_dict:
  156. LOGGER.info('load optimizer state dict')
  157. optimizer.load_state_dict(loaded_model_dict['optimizer'])
  158. LOGGER.info('optimizer is {}'.format(optimizer))
  159. else:
  160. optimizer = None
  161. LOGGER.info('optimizer is not specified')
  162. # init loss
  163. if self.loss is not None:
  164. loss_fn = s.recover_loss_fn_from_dict(self.loss)
  165. LOGGER.info('loss function is {}'.format(loss_fn))
  166. else:
  167. loss_fn = None
  168. LOGGER.info('loss function is not specified')
  169. # init trainer
  170. trainer_inst: TrainerBase = trainer_class(**self.trainer_param)
  171. trainer_train_args = inspect.getfullargspec(trainer_inst.train).args
  172. args_format = [
  173. 'self',
  174. 'train_set',
  175. 'validate_set',
  176. 'optimizer',
  177. 'loss',
  178. 'extra_data'
  179. ]
  180. if len(trainer_train_args) < 6:
  181. raise ValueError(
  182. 'Train function of trainer should take 6 arguments :{}, but current trainer.train '
  183. 'only takes {} arguments: {}'.format(
  184. args_format, len(trainer_train_args), trainer_train_args))
  185. trainer_inst.set_nn_config(self.nn_define, self.optimizer, self.loss)
  186. trainer_inst.fed_mode = True
  187. return trainer_inst, model, optimizer, loss_fn, extra_data
  188. def fit(self, train_input, validate_input=None):
  189. LOGGER.debug('train input is {}'.format(train_input))
  190. # train input & validate input are DTables or path str
  191. if not is_table(train_input):
  192. if isinstance(train_input, LocalData):
  193. train_input = train_input.path
  194. assert train_input is not None, 'input train path is None!'
  195. if not is_table(validate_input):
  196. if isinstance(validate_input, LocalData):
  197. validate_input = validate_input.path
  198. assert validate_input is not None, 'input validate path is None!'
  199. # fate loss callback setting
  200. self.callback_meta(
  201. "loss",
  202. "train",
  203. MetricMeta(
  204. name="train",
  205. metric_type="LOSS",
  206. extra_metas={
  207. "unit_name": "epochs"}))
  208. # set random seed
  209. global_seed(self.torch_seed)
  210. self.trainer_inst, model, optimizer, loss_fn, extra_data = self.init()
  211. self.trainer_inst.set_model(model)
  212. self.trainer_inst.set_tracker(self.tracker)
  213. self.trainer_inst.set_model_exporter(self.exporter)
  214. # load dataset class
  215. dataset_inst = load_dataset(
  216. dataset_name=self.dataset,
  217. data_path_or_dtable=train_input,
  218. dataset_cache=self.cache_dataset,
  219. param=self.dataset_param
  220. )
  221. # set dataset prefix
  222. dataset_inst.set_type('train')
  223. LOGGER.info('train dataset instance is {}'.format(dataset_inst))
  224. if validate_input:
  225. val_dataset_inst = load_dataset(
  226. dataset_name=self.dataset,
  227. data_path_or_dtable=validate_input,
  228. dataset_cache=self.cache_dataset,
  229. param=self.dataset_param
  230. )
  231. if id(val_dataset_inst) != id(dataset_inst):
  232. dataset_inst.set_type('validate')
  233. LOGGER.info('validate dataset instance is {}'.format(dataset_inst))
  234. else:
  235. val_dataset_inst = None
  236. # display warmstart iter
  237. if self.component_properties.is_warm_start:
  238. self.callback_warm_start_init_iter(self.warm_start_iter)
  239. # set model check point
  240. self.trainer_inst.set_checkpoint(ModelCheckpoint(self, save_freq=1))
  241. # training
  242. self.trainer_inst.train(
  243. dataset_inst,
  244. val_dataset_inst,
  245. optimizer,
  246. loss_fn,
  247. extra_data
  248. )
  249. # training is done, get exported model
  250. self.model = self.trainer_inst.get_cached_model()
  251. self.set_summary(self.trainer_inst.get_summary())
  252. def predict(self, cpn_input):
  253. with_inst_id = False
  254. schema = None
  255. if not is_table(cpn_input):
  256. if isinstance(cpn_input, LocalData):
  257. cpn_input = cpn_input.path
  258. assert cpn_input is not None, 'input path is None!'
  259. elif is_table(cpn_input):
  260. with_inst_id = check_with_inst_id(cpn_input)
  261. schema = cpn_input.schema
  262. LOGGER.info('running predict')
  263. if self.trainer_inst is None:
  264. # init model
  265. self.trainer_inst, model, optimizer, loss_fn, _ = self.init()
  266. self.trainer_inst.set_model(model)
  267. self.trainer_inst.set_tracker(self.tracker)
  268. dataset_inst = load_dataset(
  269. dataset_name=self.dataset,
  270. data_path_or_dtable=cpn_input,
  271. dataset_cache=self.cache_dataset,
  272. param=self.dataset_param)
  273. if not dataset_inst.has_dataset_type():
  274. dataset_inst.set_type('predict')
  275. trainer_ret = self.trainer_inst.predict(dataset_inst)
  276. if trainer_ret is None or not isinstance(trainer_ret, StdReturnFormat):
  277. LOGGER.info(
  278. 'trainer did not return formatted predicted result, skip predict')
  279. return None
  280. id_table, pred_table, classes = trainer_ret()
  281. if with_inst_id: # set match id
  282. add_match_id(id_table=id_table, dataset_inst=dataset_inst)
  283. id_dtable, pred_dtable = get_ret_predict_table(
  284. id_table, pred_table, classes, self.partitions, computing_session)
  285. ret_table = self.predict_score_to_output(
  286. id_dtable, pred_dtable, classes)
  287. if schema is not None:
  288. self.set_predict_data_schema(ret_table, schema)
  289. return ret_table
  290. def export_model(self):
  291. if self.model is None:
  292. LOGGER.debug('export an empty model')
  293. return self.exporter.export_model_dict() # return an empty model
  294. return self.model
  295. def load_model(self, model_dict):
  296. model_dict = list(model_dict["model"].values())[0]
  297. self.model = model_dict
  298. self.model_loaded = True
  299. # override function
  300. @staticmethod
  301. def set_predict_data_schema(predict_datas, schemas):
  302. if predict_datas is None:
  303. return predict_datas
  304. if isinstance(predict_datas, list):
  305. predict_data = predict_datas[0]
  306. schema = schemas[0]
  307. else:
  308. predict_data = predict_datas
  309. schema = schemas
  310. if predict_data is not None:
  311. predict_data.schema = {
  312. "header": [
  313. "label",
  314. "predict_result",
  315. "predict_score",
  316. "predict_detail",
  317. "type",
  318. ],
  319. "sid": 'id',
  320. "content_type": "predict_result"
  321. }
  322. if schema.get("match_id_name") is not None:
  323. predict_data.schema["match_id_name"] = schema.get(
  324. "match_id_name")
  325. return predict_data