base.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import numpy as np
  2. from federatedml.param.hetero_nn_param import HeteroNNParam
  3. from federatedml.transfer_variable.base_transfer_variable import BaseTransferVariables
  class InteractiveLayerBase:
      """Common skeleton shared by the interactive-layer classes.

      The base class only stores the supplied parameters and a slot for a
      transfer variable. Apart from ``set_flow_id``, every method here is a
      no-op placeholder that returns ``None``; presumably concrete layers
      override them (not visible from this file).
      """

      def __init__(self, params: HeteroNNParam, **kwargs):
          """Store ``params`` and start with no transfer variable.

          Extra keyword arguments are accepted but not used here.
          """
          # Left as None here; nothing in this class assigns it afterwards.
          self.transfer_variable: BaseTransferVariables = None
          self.params = params

      def set_flow_id(self, flow_id):
          """Forward ``flow_id`` to the transfer variable when one is set."""
          if self.transfer_variable is None:
              # Nothing to propagate to yet.
              return
          self.transfer_variable.set_flowid(flow_id)

      def set_batch(self, batch_size):
          """Placeholder: does nothing with ``batch_size`` in the base class."""
          return None

      def forward(self, x, epoch: int, batch: int, train: bool = True, **kwargs) -> np.ndarray:
          """Placeholder forward pass.

          The base implementation performs no work and returns ``None``
          despite the ``np.ndarray`` return annotation.
          """
          return None

      def backward(self, *args, **kwargs):
          """Placeholder backward pass; accepts any arguments, does nothing."""
          return None

      def guest_backward(self, error, epoch: int, batch_idx: int, **kwargs):
          """Placeholder for the guest-side backward step; does nothing."""
          return None

      def host_backward(self, epoch: int, batch_idx: int, **kwargs):
          """Placeholder for the host-side backward step; does nothing."""
          return None

      def export_model(self) -> bytes:
          """Placeholder model export; the base implementation returns ``None``."""
          return None

      def restore_model(self, model_bytes: bytes):
          """Placeholder model restore; ``model_bytes`` is ignored here."""
          return None

      def set_backward_select_strategy(self):
          """Placeholder; no strategy is configured by the base class."""
          return None
  class InteractiveLayerGuest(InteractiveLayerBase):
      """Guest-side variant of ``InteractiveLayerBase``.

      Adds nothing beyond the base initialisation; ``backward`` is a no-op
      placeholder with a guest-specific signature.
      """

      def __init__(self, params: HeteroNNParam, **kwargs):
          """Delegate construction to ``InteractiveLayerBase.__init__``."""
          super().__init__(params, **kwargs)

      def backward(self, error, epoch: int, batch: int, **kwargs):
          """Placeholder backward pass taking ``error``; does nothing here."""
          return None
  class InteractiveLayerHost(InteractiveLayerBase):
      """Host-side variant of ``InteractiveLayerBase``.

      Adds nothing beyond the base initialisation; ``backward`` is a no-op
      placeholder with a host-specific signature (no ``error`` argument).
      """

      def __init__(self, params: HeteroNNParam, **kwargs):
          """Delegate construction to ``InteractiveLayerBase.__init__``."""
          super().__init__(params, **kwargs)

      def backward(self, epoch: int, batch: int, **kwargs):
          """Placeholder backward pass; does nothing here."""
          return None
  36. pass