interface.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. from pipeline.param.base_param import BaseParam
  2. import sys
  3. def not_working_save_to_fate(*args, **kwargs):
  4. raise ValueError(
  5. 'save to fate not working, please check if your ipython is installed, '
  6. 'and if ipython.get_ipython() is working')
  7. try:
  8. import IPython as ipy
  9. from IPython.core.magic import register_cell_magic
  10. except ImportError as e:
  11. ipy = None
  12. register_cell_magic = None
  13. # check
  14. if register_cell_magic is not None:
  15. if ipy.get_ipython():
  16. @register_cell_magic
  17. def save_to_fate(line, cell):
  18. # search for federatedml path
  19. base_path = None
  20. for p in sys.path:
  21. if p.endswith('/fate/python'):
  22. base_path = p
  23. break
  24. if base_path is None:
  25. raise ValueError(
  26. 'cannot find fate/python in system path, please check your configuration')
  27. base_path = base_path + '/federatedml/'
  28. model_pth = 'nn/model_zoo/'
  29. dataset_pth = 'nn/dataset/'
  30. trainer_pth = 'nn/homo/trainer/'
  31. aggregator_pth = 'framework/homo/aggregator/'
  32. loss_path = 'nn/loss/'
  33. mode_map = {
  34. 'model': model_pth,
  35. 'trainer': trainer_pth,
  36. 'aggregator': aggregator_pth,
  37. 'dataset': dataset_pth,
  38. 'loss': loss_path
  39. }
  40. args = line.split()
  41. assert len(
  42. args) == 2, "input args len is not 2, got {} \n expect format: %%save_to_fate SAVE_MODE FILENAME \n SAVE_MODE in ['model', 'dataset', 'trainer', 'loss', 'aggregator'] FILE_NAME xxx.py".format(args)
  43. modes_avail = ['model', 'dataset', 'trainer', 'aggregator', 'loss']
  44. save_mode = args[0]
  45. file_name = args[1]
  46. assert save_mode in modes_avail, 'avail modes are {}, got {}'.format(
  47. modes_avail, save_mode)
  48. assert file_name.endswith('.py'), 'save file should be a .py'
  49. with open(base_path + mode_map[save_mode] + file_name, 'w') as f:
  50. f.write(cell)
  51. ipy.get_ipython().run_cell(cell)
  52. else:
  53. save_to_fate = not_working_save_to_fate
  54. else:
  55. save_to_fate = not_working_save_to_fate
  56. class TrainerParam(BaseParam):
  57. def __init__(self, trainer_name=None, **kwargs):
  58. super(TrainerParam, self).__init__()
  59. self.trainer_name = trainer_name
  60. self.param = kwargs
  61. def check(self):
  62. if self.trainer_name is None:
  63. raise ValueError(
  64. 'You did not specify the trainer name, please set the trainer name')
  65. self.check_string(self.trainer_name, 'trainer_name')
  66. def to_dict(self):
  67. ret = {'trainer_name': self.trainer_name, 'param': self.param}
  68. return ret
  69. class DatasetParam(BaseParam):
  70. def __init__(self, dataset_name=None, **kwargs):
  71. super(DatasetParam, self).__init__()
  72. self.dataset_name = dataset_name
  73. self.param = kwargs
  74. def check(self):
  75. self.check_string(self.dataset_name, 'dataset_name')
  76. def to_dict(self):
  77. ret = {'dataset_name': self.dataset_name, 'param': self.param}
  78. return ret