1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import json
- _TF_KERAS_VALID = False
- try:
- from tensorflow.keras.models import Sequential
- _TF_KERAS_VALID = True
- except ImportError:
- pass
def build_model(model_type="sequential"):
    """Create a model wrapper for the requested model type.

    Args:
        model_type: Kind of model to build. Only ``"sequential"`` is
            accepted.

    Returns:
        A new ``SequentialModel`` instance.

    Raises:
        ValueError: If ``model_type`` is anything other than
            ``"sequential"``.
    """
    if model_type == "sequential":
        return SequentialModel()

    raise ValueError("Only support sequential model now")
class SequentialModel(object):
    """Wrapper around ``tensorflow.keras.models.Sequential``.

    The wrapper can be instantiated even when tensorflow could not be
    imported; the wrapped model is then ``None`` and the methods that need
    it (``add`` and ``get_network_config``) raise ``ImportError``.
    """

    def __init__(self):
        # Only build the underlying keras model when the optional
        # tensorflow import at module level succeeded.
        self._model = Sequential() if _TF_KERAS_VALID else None

    @staticmethod
    def _ensure_tf_keras():
        """Raise ``ImportError`` when tensorflow.keras is not importable."""
        if not _TF_KERAS_VALID:
            raise ImportError(
                "Please install tensorflow first, "
                "can not import sequential model from tensorflow.keras.model !!!")

    def add(self, layer):
        """Append ``layer`` to the wrapped sequential model.

        Args:
            layer: Layer object forwarded to the underlying model's ``add``.

        Raises:
            ImportError: If tensorflow is not installed.
        """
        self._ensure_tf_keras()
        self._model.add(layer)

    @staticmethod
    def get_loss_config(loss):
        """Return the configuration name of a loss.

        Args:
            loss: Either a string (returned unchanged) or an object defined
                in the ``tensorflow.python.keras.losses`` module, whose
                ``__name__`` is returned.

        Returns:
            The loss name as a string.

        Raises:
            ValueError: If ``loss`` is neither a string nor a named object
                from ``tensorflow.python.keras.losses``.
        """
        if isinstance(loss, str):
            return loss

        # Use getattr with a default so that arbitrary objects (None, numbers,
        # instances without ``__name__``) end in the ValueError below instead
        # of leaking an AttributeError.
        # NOTE(review): other TensorFlow/Keras releases may expose losses
        # under a different module path -- confirm against supported versions.
        loss_module = getattr(loss, "__module__", None)
        loss_name = getattr(loss, "__name__", None)
        if loss_module == "tensorflow.python.keras.losses" and loss_name is not None:
            return loss_name

        raise ValueError(
            "keras sequential model' loss should be string of losses function of tf_keras")

    @staticmethod
    def get_optimizer_config(optimizer):
        """Return the configuration of an optimizer.

        Args:
            optimizer: Either a string (returned unchanged) or an object
                exposing ``get_config()``.

        Returns:
            The string itself, or the dict returned by ``get_config()`` with
            its ``"name"`` key (when present) renamed to ``"optimizer"``.
        """
        if isinstance(optimizer, str):
            return optimizer

        opt_config = optimizer.get_config()
        if "name" in opt_config:
            # Rename the key in place: "name" -> "optimizer".
            opt_config["optimizer"] = opt_config.pop("name")

        return opt_config

    def get_network_config(self):
        """Return the wrapped model's architecture as parsed JSON.

        Returns:
            The result of ``json.loads`` applied to the model's ``to_json()``
            output.

        Raises:
            ImportError: If tensorflow is not installed.
        """
        self._ensure_tf_keras()
        return json.loads(self._model.to_json())
|