fate_operator.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3 only; the alias was removed from `collections` in 3.10
    from collections import Iterable

import numpy as np
from scipy.sparse import csr_matrix

from federatedml.feature.instance import Instance
from federatedml.feature.sparse_vector import SparseVector
from federatedml.util import paillier_check
  22. def _one_dimension_dot(X, w):
  23. res = 0
  24. # LOGGER.debug("_one_dimension_dot, len of w: {}, len of X: {}".format(len(w), len(X)))
  25. # If all weights are in one single IPCL encrypted number
  26. if paillier_check.is_single_ipcl_encrypted_number(w):
  27. if isinstance(X, csr_matrix):
  28. res = w.item(0).dot(X.data)
  29. else:
  30. res = w.item(0).dot(X)
  31. return res
  32. if isinstance(X, csr_matrix):
  33. for idx, value in zip(X.indices, X.data):
  34. res += value * w[idx]
  35. else:
  36. for i in range(len(X)):
  37. if np.fabs(X[i]) < 1e-5:
  38. continue
  39. res += w[i] * X[i]
  40. if res == 0:
  41. if paillier_check.is_paillier_encrypted_number(w[0]):
  42. res = 0 * w[0]
  43. return res
  44. def dot(value, w):
  45. w_ndim = np.ndim(w)
  46. if paillier_check.is_single_ipcl_encrypted_number(w):
  47. w_ndim += 1
  48. if isinstance(value, Instance):
  49. X = value.features
  50. else:
  51. X = value
  52. # # dot(a, b)[i, j, k, m] = sum(a[i, j, :] * b[k, :, m])
  53. # # One-dimension dot, which is the inner product of these two arrays
  54. if np.ndim(X) == w_ndim == 1:
  55. return _one_dimension_dot(X, w)
  56. elif np.ndim(X) == 2 and w_ndim == 1:
  57. res = []
  58. for x in X:
  59. res.append(_one_dimension_dot(x, w))
  60. res = np.array(res)
  61. else:
  62. res = np.dot(X, w)
  63. return res
  64. def vec_dot(x, w):
  65. new_data = 0
  66. if isinstance(x, SparseVector):
  67. for idx, v in x.get_all_data():
  68. # if idx < len(w):
  69. new_data += v * w[idx]
  70. else:
  71. new_data = np.dot(x, w)
  72. return new_data
  73. def reduce_add(x, y):
  74. if x is None and y is None:
  75. return None
  76. if x is None:
  77. return y
  78. if y is None:
  79. return x
  80. if not isinstance(x, Iterable):
  81. result = x + y
  82. elif isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
  83. result = x + y
  84. else:
  85. result = []
  86. for idx, acc in enumerate(x):
  87. if acc is None:
  88. result.append(acc)
  89. continue
  90. result.append(acc + y[idx])
  91. return result
  92. def norm(vector, p=2):
  93. """
  94. Get p-norm of this vector
  95. Parameters
  96. ----------
  97. vector : numpy array, Input vector
  98. p: int, p-norm
  99. """
  100. if p < 1:
  101. raise ValueError('p should larger or equal to 1 in p-norm')
  102. if type(vector).__name__ != 'ndarray':
  103. vector = np.array(vector)
  104. return np.linalg.norm(vector, p)
  105. # def generate_anonymous(fid, party_id=None, role=None, model=None):
  106. # if model is None:
  107. # if party_id is None or role is None:
  108. # raise ValueError("party_id or role should be provided when generating"
  109. # "anonymous.")
  110. # if party_id is None:
  111. # party_id = model.component_properties.local_partyid
  112. # if role is None:
  113. # role = model.role
  114. #
  115. # party_id = str(party_id)
  116. # fid = str(fid)
  117. # return "_".join([role, party_id, fid])
  118. #
  119. #
  120. # def reconstruct_fid(encoded_name):
  121. # try:
  122. # col_index = int(encoded_name.split('_')[-1])
  123. # except IndexError or ValueError:
  124. # raise RuntimeError(f"Decode name: {encoded_name} is not a valid value")
  125. # return col_index