sequantial.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from pipeline.component.nn.backend.torch.base import Sequential as Seq
  17. from pipeline.component.nn.backend.torch.cust import CustModel
  18. from pipeline.component.nn.backend.torch.interactive import InteractiveLayer
  class Sequential(object):
      """Pipeline-side container for the network model of an NN component.

      The wrapped backend model is created by ``add``. Torch layers are stored
      in the fate_torch ``Sequential`` (imported here as ``Seq``); instances of
      ``tf.Module`` are stored in the legacy keras ``SequentialModel``.

      NOTE: the keras backend is deprecated -- hetero and homo NN no longer
      support it -- but this pipeline keras interface is kept.
      """

      def __init__(self):
          self._model = None  # wrapped backend model; None until add() is called
          self.__config_type = None  # "torch" or "keras" once a layer was added

      def is_empty(self):
          """Return True when no layer/model has been added yet."""
          return self._model is None

      def get_model(self):
          """Return the wrapped backend model, or None when nothing was added."""
          return self._model

      def add(self, layer):
          """Validate ``layer``, detect its backend and store it.

          When tensorflow is importable and ``layer`` is a ``tf.Module``, the
          keras backend is used and no further validation is done. Otherwise
          the torch backend is used, and ``layer`` must be defined in module
          ``pipeline.component.nn.backend.torch.nn`` or be an instance of the
          fate_torch ``Sequential``, ``CustModel`` or ``InteractiveLayer``.

          ``_add_layer`` is called with its default ``replace=True``, so each
          call replaces the previously stored model with a fresh one that
          receives only ``layer``.

          Raises:
              ValueError: if a non-keras ``layer`` is none of the supported
                  fate_torch types.
          """
          try:
              import tensorflow as tf
              backend = "keras" if isinstance(layer, tf.Module) else "torch"
          except ImportError:
              # tensorflow is optional; without it every layer is treated as torch
              backend = "torch"

          if backend == "torch":
              # fate_torch layers are recognised by the module that defines them
              defined_in_fate_torch_nn = "pipeline.component.nn.backend.torch.nn" == getattr(
                  layer, "__module__", None)
              supported_instance = isinstance(layer, (Seq, CustModel, InteractiveLayer))
              if not (defined_in_fate_torch_nn or supported_instance):
                  raise ValueError(
                      "Layer type {} not support yet, added layer must be a FateTorchLayer or a fate_torch "
                      "Sequential, remember to call fate_torch_hook() before using pipeline ".format(
                          type(layer)))

          self._add_layer(layer, backend)

      def _add_layer(self, layer, layer_type, replace=True):
          """Store ``layer`` in the backend model selected by ``layer_type``.

          "keras": a new legacy ``SequentialModel`` is always created.
          "torch": a new fate_torch ``Sequential`` is created when no model
          exists yet or ``replace`` is truthy; otherwise the current model is
          reused. In every case ``layer`` is then passed to the model's ``add``.
          """
          if layer_type == 'keras':
              # deprecated backend, kept only for the pipeline keras interface
              from pipeline.component.nn.models.keras_interface import SequentialModel
              self.__config_type = layer_type
              self._model = SequentialModel()
          elif layer_type == 'torch' and (replace or self._model is None):
              self._model = Seq()
              self.__config_type = layer_type

          self._model.add(layer)

      def get_layer_type(self):
          """Return the backend name ("torch" or "keras"), or None before add()."""
          return self.__config_type

      def get_loss_config(self, loss):
          """Delegate to the wrapped model's ``get_loss_config``."""
          model = self._model
          return model.get_loss_config(loss)

      def get_optimizer_config(self, optimizer):
          """Delegate to the wrapped model's ``get_optimizer_config``."""
          model = self._model
          return model.get_optimizer_config(optimizer)

      def get_network_config(self):
          """Return the wrapped model's network config.

          Raises:
              ValueError: if no backend has been selected yet (nothing added).
          """
          if self.__config_type:
              return self._model.get_network_config()
          raise ValueError("Empty layer find, can't get config")

      def __repr__(self):
          # delegate to the wrapped model; this yields 'None' while empty
          return self._model.__repr__()