keras_interface.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. _TF_KERAS_VALID = False
  18. try:
  19. from tensorflow.keras.models import Sequential
  20. _TF_KERAS_VALID = True
  21. except ImportError:
  22. pass
  23. def build_model(model_type="sequential"):
  24. if model_type != "sequential":
  25. raise ValueError("Only support sequential model now")
  26. return SequentialModel()
  27. class SequentialModel(object):
  28. def __init__(self):
  29. if _TF_KERAS_VALID:
  30. self._model = Sequential()
  31. else:
  32. self._model = None
  33. def add(self, layer):
  34. if not _TF_KERAS_VALID:
  35. raise ImportError(
  36. "Please install tensorflow first, "
  37. "can not import sequential model from tensorflow.keras.model !!!")
  38. self._model.add(layer)
  39. @staticmethod
  40. def get_loss_config(loss):
  41. if isinstance(loss, str):
  42. return loss
  43. if loss.__module__ == "tensorflow.python.keras.losses":
  44. return loss.__name__
  45. raise ValueError(
  46. "keras sequential model' loss should be string of losses function of tf_keras")
  47. @staticmethod
  48. def get_optimizer_config(optimizer):
  49. if isinstance(optimizer, str):
  50. return optimizer
  51. opt_config = optimizer.get_config()
  52. if "name" in opt_config:
  53. opt_config["optimizer"] = opt_config["name"]
  54. del opt_config["name"]
  55. return opt_config
  56. def get_network_config(self):
  57. if not _TF_KERAS_VALID:
  58. raise ImportError(
  59. "Please install tensorflow first, "
  60. "can not import sequential model from tensorflow.keras.model !!!")
  61. return json.loads(self._model.to_json())