logistic_regression.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. #
  2. # Copyright 2021 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import numpy as np
  17. from sklearn.linear_model import LogisticRegression
  18. from ..component_converter import ComponentConverterBase
  19. class LRComponentConverter(ComponentConverterBase):
  20. @staticmethod
  21. def get_target_modules():
  22. return ['HomoLR']
  23. def convert(self, model_dict):
  24. param_obj = model_dict["HomoLogisticRegressionParam"]
  25. meta_obj = model_dict["HomoLogisticRegressionMeta"]
  26. sk_lr_model = LogisticRegression(penalty=meta_obj.penalty.lower(),
  27. tol=meta_obj.tol,
  28. fit_intercept=meta_obj.fit_intercept,
  29. max_iter=meta_obj.max_iter)
  30. coefficient = np.empty((1, len(param_obj.header)))
  31. for index in range(len(param_obj.header)):
  32. coefficient[0][index] = param_obj.weight[param_obj.header[index]]
  33. sk_lr_model.coef_ = coefficient
  34. sk_lr_model.intercept_ = np.array([param_obj.intercept])
  35. # hard-coded 0-1 classification as HomoLR only supports this for now
  36. sk_lr_model.classes_ = np.array([0., 1.])
  37. sk_lr_model.n_iter_ = [param_obj.iters]
  38. return sk_lr_model