homo_nn_param.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. from pipeline.param.base_param import BaseParam
  class TrainerParam(BaseParam):
      """Parameter holder that names a trainer and carries its options.

      Any extra keyword arguments passed to the constructor are stored
      unmodified in ``self.param``.
      """

      def __init__(self, trainer_name=None, **kwargs):
          super().__init__()
          self.trainer_name, self.param = trainer_name, kwargs

      def check(self):
          """Validate ``trainer_name`` via ``self.check_string`` when it is set."""
          if self.trainer_name is None:
              # No trainer named: there is nothing to validate.
              return
          self.check_string(self.trainer_name, 'trainer_name')

      def to_dict(self):
          """Return the trainer name and its keyword options as a dict."""
          return {'trainer_name': self.trainer_name, 'param': self.param}
  class DatasetParam(BaseParam):
      """Parameter holder that names a dataset and carries its options.

      Any extra keyword arguments passed to the constructor are stored
      unmodified in ``self.param``.
      """

      def __init__(self, dataset_name=None, **kwargs):
          super().__init__()
          self.dataset_name, self.param = dataset_name, kwargs

      def check(self):
          """Validate ``dataset_name`` via ``self.check_string`` when it is set."""
          if self.dataset_name is None:
              # No dataset named: there is nothing to validate.
              return
          self.check_string(self.dataset_name, 'dataset_name')

      def to_dict(self):
          """Return the dataset name and its keyword options as a dict."""
          return {'dataset_name': self.dataset_name, 'param': self.param}
  class HomoNNParam(BaseParam):
      """Parameter set bundling trainer, dataset, seed and model configs.

      Args:
          trainer: A ``TrainerParam``. When omitted (or ``None``) a fresh
              ``TrainerParam()`` is created for this instance.
          dataset: A ``DatasetParam``. When omitted (or ``None``) a fresh
              ``DatasetParam()`` is created for this instance.
          torch_seed: Integer seed; validated in ``check`` with
              ``check_positive_integer``.
          nn_define: Optional dict defining model structures.
          loss: Optional loss config dict.
          optimizer: Optional optimizer config dict.
      """

      def __init__(self,
                   trainer: TrainerParam = None,
                   dataset: DatasetParam = None,
                   torch_seed: int = 100,
                   nn_define: dict = None,
                   loss: dict = None,
                   optimizer: dict = None
                   ):
          super(HomoNNParam, self).__init__()
          # Defaults are built per call. A ``TrainerParam()`` / ``DatasetParam()``
          # written directly in the signature is evaluated once, at definition
          # time, and would then be shared (and mutated) across every
          # HomoNNParam created without these arguments.
          self.trainer = TrainerParam() if trainer is None else trainer
          self.dataset = DatasetParam() if dataset is None else dataset
          self.torch_seed = torch_seed
          self.nn_define = nn_define
          self.loss = loss
          self.optimizer = optimizer

      def check(self):
          """Validate all held parameters.

          Raises:
              AssertionError: if ``trainer`` / ``dataset`` are not of the
                  expected parameter classes, or if ``nn_define``, ``loss`` or
                  ``optimizer`` is set but is not a dict.

          The nested ``check`` calls and ``check_positive_integer`` may raise
          their own errors; those are defined outside this class.
          """
          assert isinstance(self.trainer, TrainerParam), 'trainer must be a TrainerParam()'
          assert isinstance(self.dataset, DatasetParam), 'dataset must be a DatasetParam()'

          # Delegate to the nested parameter objects before checking own fields.
          self.trainer.check()
          self.dataset.check()

          self.check_positive_integer(self.torch_seed, 'torch seed')

          # The three config fields are optional; only type-check them when set.
          if self.nn_define is not None:
              assert isinstance(self.nn_define, dict), 'nn define should be a dict defining model structures'
          if self.loss is not None:
              assert isinstance(self.loss, dict), 'loss parameter should be a loss config dict'
          if self.optimizer is not None:
              assert isinstance(self.optimizer, dict), 'optimizer parameter should be a config dict'