linear_model_weight.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import numpy as np
  18. from federatedml.framework.weights import ListWeights, TransferableWeights
  19. from federatedml.util import LOGGER, paillier_check, ipcl_operator
  20. class LinearModelWeights(ListWeights):
  21. def __init__(self, l, fit_intercept, raise_overflow_error=True):
  22. l = np.array(l)
  23. if l.shape != (0,) and not paillier_check.is_paillier_encrypted_number(l):
  24. if np.max(np.abs(l)) > 1e8:
  25. if raise_overflow_error:
  26. raise RuntimeError(
  27. "The model weights are overflow, please check if the input data has been normalized")
  28. else:
  29. LOGGER.warning(
  30. f"LinearModelWeights contains entry greater than 1e8.")
  31. super().__init__(l)
  32. self.fit_intercept = fit_intercept
  33. self.raise_overflow_error = raise_overflow_error
  34. def for_remote(self):
  35. return TransferableWeights(self._weights, self.__class__, self.fit_intercept)
  36. @property
  37. def coef_(self):
  38. if self.fit_intercept:
  39. if paillier_check.is_single_ipcl_encrypted_number(self._weights):
  40. coeffs = ipcl_operator.get_coeffs(self._weights.item(0))
  41. return np.array(coeffs)
  42. return np.array(self._weights[:-1])
  43. return np.array(self._weights)
  44. @property
  45. def intercept_(self):
  46. if self.fit_intercept:
  47. if paillier_check.is_single_ipcl_encrypted_number(self._weights):
  48. return ipcl_operator.get_intercept(self._weights.item(0))
  49. return 0.0 if len(self._weights) == 0 else self._weights[-1]
  50. return 0.0
  51. def binary_op(self, other: 'LinearModelWeights', func, inplace):
  52. if inplace:
  53. for k, v in enumerate(self._weights):
  54. self._weights[k] = func(self._weights[k], other._weights[k])
  55. return self
  56. else:
  57. _w = []
  58. for k, v in enumerate(self._weights):
  59. _w.append(func(self._weights[k], other._weights[k]))
  60. return LinearModelWeights(_w, self.fit_intercept, self.raise_overflow_error)
  61. def map_values(self, func, inplace):
  62. if paillier_check.is_single_ipcl_encrypted_number(self._weights):
  63. if inplace:
  64. self._weights = np.array(func(self.unboxed.item(0)))
  65. return self
  66. else:
  67. _w = func(self.unboxed.item(0))
  68. return LinearModelWeights(_w, self.fit_intercept)
  69. if inplace:
  70. for k, v in enumerate(self._weights):
  71. self._weights[k] = func(v)
  72. return self
  73. else:
  74. _w = []
  75. for v in self._weights:
  76. _w.append(func(v))
  77. return LinearModelWeights(_w, self.fit_intercept)
  78. def __repr__(self):
  79. return f"weights: {self.coef_}, intercept: {self.intercept_}"