weights.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import abc
  17. import numpy as np
  18. import operator
  19. from federatedml.secureprotol.encrypt import Encrypt
  20. from federatedml.util import LOGGER
  class TransferableWeights:
      """Envelope pairing raw weights with the class used to rebuild them.

      Any extra positional/keyword arguments given at construction are replayed
      to ``cls`` when :attr:`weights` is read. They are stored only when
      non-empty, so an instance built without extras has no ``_args`` /
      ``_kwargs`` attribute at all.
      """

      def __init__(self, weights, cls, *args, **kwargs):
          self._weights = weights
          self._cls = cls
          # Only materialize the optional attributes when there is something to keep.
          for attr_name, payload in (("_args", args), ("_kwargs", kwargs)):
              if payload:
                  setattr(self, attr_name, payload)

      def with_degree(self, degree):
          """Attach ``degree`` to this envelope and return the envelope for chaining."""
          self._degree = degree
          return self

      def get_degree(self, default=None):
          """Return the degree set by :meth:`with_degree`, or ``default`` if none was set."""
          return getattr(self, "_degree", default)

      @property
      def unboxed(self):
          """The raw weights object, without rebuilding it through ``cls``."""
          return self._weights

      @property
      def weights(self):
          """Rebuild the weights as ``cls(raw_weights, *args, **kwargs)``."""
          extra_args = getattr(self, "_args", ())
          extra_kwargs = getattr(self, "_kwargs", {})
          return self._cls(self._weights, *extra_args, **extra_kwargs)
  class Weights:
      """Base class for weight containers supporting element-wise arithmetic.

      Subclasses wrap a concrete storage object in ``self._weights`` and
      implement ``map_values``, ``binary_op`` and ``axpy``. Every arithmetic
      operator below is expressed through those three methods.
      """

      def __init__(self, l):
          self._weights = l

      def for_remote(self):
          """Wrap the raw weights in a ``TransferableWeights`` tagged with this class."""
          return TransferableWeights(self._weights, self.__class__)

      @property
      def unboxed(self):
          """The raw storage object held by this container."""
          return self._weights

      # NOTE: this class does not use ``abc.ABCMeta``, so ``@abc.abstractmethod``
      # does not prevent instantiation. The bodies raise so that a missing
      # implementation fails loudly instead of making operators return ``None``.
      @abc.abstractmethod
      def map_values(self, func, inplace):
          """Apply ``func`` to every stored value.

          Returns ``self`` when ``inplace`` is true, otherwise a new container.
          """
          raise NotImplementedError

      @abc.abstractmethod
      def binary_op(self, other, func, inplace):
          """Combine with ``other`` element-wise.

          ``a - b`` is implemented as ``a.binary_op(b, operator.sub, ...)``, so
          implementations must call ``func(self_value, other_value)``.
          Returns ``self`` when ``inplace`` is true, otherwise a new container.
          """
          raise NotImplementedError

      @abc.abstractmethod
      def axpy(self, a, y):
          """Update in place to ``self + a * y``."""
          raise NotImplementedError

      def decrypted(self, cipher: 'Encrypt', inplace=True):
          """Apply ``cipher.decrypt`` to every stored value."""
          return self.map_values(cipher.decrypt, inplace=inplace)

      def encrypted(self, cipher: 'Encrypt', inplace=True):
          """Apply ``cipher.encrypt`` to every stored value."""
          return self.map_values(cipher.encrypt, inplace=inplace)

      def __imul__(self, other):
          return self.map_values(lambda x: x * other, inplace=True)

      def __mul__(self, other):
          return self.map_values(lambda x: x * other, inplace=False)

      def __rmul__(self, other):
          # Scalar multiplication is treated as commutative: ``k * w`` == ``w * k``.
          return self * other

      def __iadd__(self, other):
          return self.binary_op(other, operator.add, inplace=True)

      def __add__(self, other):
          return self.binary_op(other, operator.add, inplace=False)

      def __radd__(self, other):
          return self + other

      def __isub__(self, other):
          return self.binary_op(other, operator.sub, inplace=True)

      def __sub__(self, other):
          return self.binary_op(other, operator.sub, inplace=False)

      def __truediv__(self, other):
          return self.map_values(lambda x: x / other, inplace=False)

      def __itruediv__(self, other):
          return self.map_values(lambda x: x / other, inplace=True)
  class NumericWeights(Weights):
      """Weights held as one value that is operated on as a whole (no per-element loop)."""

      def __init__(self, v):
          super().__init__(v)

      def map_values(self, func, inplace):
          """Apply ``func`` to the wrapped value.

          Returns ``self`` when ``inplace`` is true, otherwise a new ``NumericWeights``.
          """
          v = func(self._weights)
          if inplace:
              self._weights = v
              return self
          return NumericWeights(v)

      def binary_op(self, other: 'NumericWeights', func, inplace):
          """Compute ``func(self_value, other_value)``.

          Returns ``self`` when ``inplace`` is true, otherwise a new ``NumericWeights``.
          """
          v = func(self._weights, other._weights)
          if inplace:
              self._weights = v
              return self
          return NumericWeights(v)

      def axpy(self, a, y: 'NumericWeights'):
          """Rebind the wrapped value to ``value + a * y_value`` and return ``self``."""
          self._weights = self._weights + a * y._weights
          return self
  class ListWeights(Weights):
      """Weights held as a Python list; every operation works position by position."""

      def __init__(self, l):
          super().__init__(l)

      def map_values(self, func, inplace):
          """Apply ``func`` to each element.

          In place, the list is overwritten slot by slot and ``self`` is returned;
          otherwise a new ``ListWeights`` wrapping a fresh list is returned.
          """
          values = self._weights
          if not inplace:
              return ListWeights([func(item) for item in values])
          for idx, item in enumerate(values):
              values[idx] = func(item)
          return self

      def binary_op(self, other: 'ListWeights', func, inplace):
          """Combine with ``other`` position-wise as ``func(mine, theirs)``.

          Positions are driven by ``self``: ``other`` must be at least as long,
          and any extra trailing elements of ``other`` are ignored.
          """
          mine, theirs = self._weights, other._weights
          if not inplace:
              return ListWeights([func(item, theirs[idx]) for idx, item in enumerate(mine)])
          for idx, item in enumerate(mine):
              mine[idx] = func(item, theirs[idx])
          return self

      def axpy(self, a, y: 'ListWeights'):
          """Accumulate ``a * y`` into each position in place and return ``self``."""
          mine, theirs = self._weights, y._weights
          for idx, _current in enumerate(mine):
              mine[idx] += a * theirs[idx]
          return self
  class DictWeights(Weights):
      """Weights held in a dict; every operation is applied key by key."""

      def __init__(self, d):
          super().__init__(d)

      def map_values(self, func, inplace):
          """Apply ``func`` to each value.

          Returns ``self`` when ``inplace`` is true, otherwise a new ``DictWeights``.
          """
          if inplace:
              # Reassigning existing keys does not resize the dict, so this is safe
              # while iterating over items().
              for k, v in self._weights.items():
                  self._weights[k] = func(v)
              return self
          return DictWeights({k: func(v) for k, v in self._weights.items()})

      def binary_op(self, other: 'DictWeights', func, inplace):
          """Combine with ``other`` key-wise as ``func(self_value, other_value)``.

          Keys come from ``self``; every one of them must exist in ``other``.
          ``self`` is the left operand, so ``a - b`` yields ``a[k] - b[k]``.
          (Previously the operands were swapped, which flipped the sign of ``-``/``-=``.)
          """
          if inplace:
              for k, v in self._weights.items():
                  self._weights[k] = func(v, other._weights[k])
              return self
          return DictWeights({k: func(v, other._weights[k]) for k, v in self._weights.items()})

      def axpy(self, a, y: 'DictWeights'):
          """Accumulate ``a * y[k]`` into every key in place and return ``self``."""
          for k in self._weights:
              self._weights[k] += a * y._weights[k]
          return self
  class OrderDictWeights(Weights):
      """
      This class provides a dict container like `DictWeights`, but with a fixed key order.
      The fixed order is useful for secure aggregation random padding generation, which is order sensitive.
      """

      def __init__(self, d):
          super().__init__(d)
          # Sorting by str() gives a deterministic order without comparing the
          # keys themselves (so keys of mixed types still sort).
          self.walking_order = sorted(d.keys(), key=str)

      def map_values(self, func, inplace):
          """Apply ``func`` to each value, visiting keys in ``walking_order``.

          Returns ``self`` when ``inplace`` is true, otherwise a new ``OrderDictWeights``.
          """
          if inplace:
              for k in self.walking_order:
                  self._weights[k] = func(self._weights[k])
              return self
          # The comprehension calls func in walking_order, preserving call order.
          return OrderDictWeights({k: func(self._weights[k]) for k in self.walking_order})

      def binary_op(self, other: 'OrderDictWeights', func, inplace):
          """Combine with ``other`` key-wise, in ``walking_order``, as ``func(self_value, other_value)``.

          ``self`` is the left operand, so ``a - b`` yields ``a[k] - b[k]``.
          (Previously the operands were swapped, which flipped the sign of ``-``/``-=``.)
          """
          if inplace:
              for k in self.walking_order:
                  self._weights[k] = func(self._weights[k], other._weights[k])
              return self
          return OrderDictWeights(
              {k: func(self._weights[k], other._weights[k]) for k in self.walking_order}
          )

      def axpy(self, a, y: 'OrderDictWeights'):
          """Accumulate ``a * y[k]`` into every key, in ``walking_order``, and return ``self``."""
          for k in self.walking_order:
              self._weights[k] += a * y._weights[k]
          return self
  class NumpyWeights(Weights):
      """Weights held in a numpy ndarray; operations act element-wise in C (row-major) order."""

      def __init__(self, arr):
          super().__init__(arr)

      def _check_same_size(self, other_arr):
          """Raise ValueError unless ``other_arr`` has as many elements as our array."""
          if other_arr.size != self._weights.size:
              raise ValueError(
                  f"size mismatch: {self._weights.size} elements vs {other_arr.size}"
              )

      def map_values(self, func, inplace):
          """Apply ``func`` to every element.

          In place, results are written back through ``ndarray.flat``, which
          addresses the original array even when it is not contiguous.
          (``reshape`` would silently return a copy in that case and drop the writes.)
          Otherwise ``np.vectorize`` builds a new array wrapped in ``NumpyWeights``.
          """
          if inplace:
              flat = self._weights.flat
              for i in range(self._weights.size):
                  flat[i] = func(flat[i])
              return self
          vec_func = np.vectorize(func)
          return NumpyWeights(vec_func(self._weights))

      def binary_op(self, other: 'NumpyWeights', func, inplace):
          """Combine with ``other`` element-wise as ``func(self_elem, other_elem)``.

          In place, elements are paired by flat (C-order) position, and both
          arrays must have the same number of elements (ValueError otherwise).
          Out of place, ``np.vectorize`` applies its usual broadcasting rules.
          """
          if inplace:
              self._check_same_size(other._weights)
              flat, other_flat = self._weights.flat, other._weights.flat
              for i in range(self._weights.size):
                  flat[i] = func(flat[i], other_flat[i])
              return self
          vec_func = np.vectorize(func)
          return NumpyWeights(vec_func(self._weights, other._weights))

      def axpy(self, a, y: 'NumpyWeights'):
          """In place ``self += a * y`` element by element; returns ``self``.

          Elements are paired by flat (C-order) position; the arrays must have
          the same number of elements (ValueError otherwise).
          """
          self._check_same_size(y._weights)
          flat, y_flat = self._weights.flat, y._weights.flat
          for i in range(self._weights.size):
              flat[i] += a * y_flat[i]
          return self

      def __repr__(self):
          return self._weights.__repr__()