test_fix_point.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. #
  17. # Copyright 2019 The FATE Authors. All Rights Reserved.
  18. #
  19. # Licensed under the Apache License, Version 2.0 (the "License");
  20. # you may not use this file except in compliance with the License.
  21. # You may obtain a copy of the License at
  22. #
  23. # http://www.apache.org/licenses/LICENSE-2.0
  24. #
  25. # Unless required by applicable law or agreed to in writing, software
  26. # distributed under the License is distributed on an "AS IS" BASIS,
  27. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  28. # See the License for the specific language governing permissions and
  29. # limitations under the License.
  30. #
  31. import random
  32. import unittest
  33. import uuid
  34. from concurrent.futures import ProcessPoolExecutor, as_completed
  35. import numpy as np
  36. from federatedml.secureprotol.spdz import SPDZ
  37. from federatedml.secureprotol.spdz.tensor.fixedpoint_numpy import FixedPointTensor
  38. from federatedml.transfer_variable.transfer_class.secret_share_transfer_variable import SecretShareTransferVariable
  # Number of host parties. `submit` launches NUM_HOSTS + 1 processes (the
  # guest at index 0 plus the hosts), and `session_init` gives the hosts
  # party ids starting at 10000.
  NUM_HOSTS = 1
  # Base absolute tolerance used when comparing reconstructed results with the
  # plaintext numpy results; individual tests scale it (e.g. by 2 or by the
  # size of the contracted dimensions).
  EPS = 1e-3
  def session_init(job_id, idx):
      """Initialise a FATE session and its federation for party number ``idx``.

      Party 0 is the single guest (party id 9999); parties 1..NUM_HOSTS are
      hosts with party ids 10000, 10001, ... in order.

      Returns:
          A ``(local_party, all_parties)`` pair read from ``sess.parties``.
      """
      # Imported inside the function rather than at module level; presumably so
      # that each worker process started by `submit` performs its own import.
      from fate_arch import session

      if idx < 1:
          role, party_id = "guest", 9999 + idx
      else:
          role, party_id = "host", 10000 + (idx - 1)

      role_parties = {
          "host": [10000 + host for host in range(NUM_HOSTS)],
          "guest": [9999],
      }
      federation_conf = {
          "local": {"role": role, "party_id": party_id},
          "role": role_parties,
      }

      sess = session.init(job_id)
      sess.init_federation(job_id, federation_conf)
      return sess.parties.local_party(), sess.parties.all_parties()
  def submit(func, *args, **kwargs):
      """Run ``func`` once per party in separate worker processes and collect the results.

      Party 0 is the guest and parties 1..NUM_HOSTS are the hosts. Each call
      receives ``idx=<party index>`` in addition to ``args``/``kwargs``; an
      ``idx`` already present in ``kwargs`` is overridden.

      Returns:
          list: the per-party results, ordered by party index.

      Raises:
          Any exception raised by ``func`` in a worker, re-raised by
          ``Future.result()``.
      """
      num = NUM_HOSTS + 1
      # One worker per party. The party functions wait on each other (e.g.
      # `from_source` with a remote party), so all of them must run at the same
      # time. The default pool size (os.cpu_count()) can be smaller than `num`,
      # which would leave a party queued while the others block forever.
      with ProcessPoolExecutor(max_workers=num) as pool:
          futures = {
              pool.submit(func, *args, **{**kwargs, "idx": party}): party
              for party in range(num)
          }
          result = [None] * num
          for future in as_completed(futures):
              result[futures[future]] = future.result()
      return result
  def create_and_get(job_id, idx, data):
      """Create fixed-point tensor ``x`` from the guest's ``data`` and reconstruct it.

      The guest (``idx == 0``) passes its plaintext ``data`` as the source; every
      other party passes the guest party (``all_parties[0]``) as the source.

      Returns:
          The value returned by ``x.get()`` on this party.
      """
      _local, parties = session_init(job_id, idx)
      with SPDZ():
          source = data if idx == 0 else parties[0]
          tensor = FixedPointTensor.from_source("x", source)
          return tensor.get()
  def add_and_sub(job_id, idx, data_list):
      """Add and subtract two tensors: ``x`` sourced from the guest, ``y`` from the host.

      The guest supplies ``data_list[0]`` for ``x`` and the host supplies
      ``data_list[1]`` for ``y``. Each party passes the other's party entry as
      the source of the tensor it does not own.

      Returns:
          ``((x + y).get(), (x - y).get())`` on this party.
      """
      _local, parties = session_init(job_id, idx)
      with SPDZ():
          is_guest = idx == 0
          x_source = data_list[0] if is_guest else parties[0]
          y_source = parties[1] if is_guest else data_list[1]
          x = FixedPointTensor.from_source("x", x_source)
          y = FixedPointTensor.from_source("y", y_source)
          total = (x + y).get()
          difference = (x - y).get()
          return total, difference
  def add_and_sub_plaintext(job_id, idx, data_list):
      """Add and subtract a plaintext operand to and from a tensor sourced from the guest.

      ``x`` comes from the guest's ``data_list[0]``. ``y = data_list[1]`` is
      used as a plain value by every party, on both sides of each operator.

      Returns:
          ``((x + y).get(), (y + x).get(), (x - y).get(), (y - x).get())``.
      """
      _local, parties = session_init(job_id, idx)
      with SPDZ():
          x = FixedPointTensor.from_source("x", data_list[0] if idx == 0 else parties[0])
          y = data_list[1]
          # Reconstruction order is kept identical on every party.
          sums = (x + y).get(), (y + x).get()
          diffs = (x - y).get(), (y - x).get()
          return sums[0], sums[1], diffs[0], diffs[1]
  def mul_plaintext(job_id, idx, data_list):
      """Multiply a tensor sourced from the guest by a plaintext value, on both sides.

      ``x`` comes from the guest's ``data_list[0]``. ``data_list[1]`` is read by
      every party and used directly as the other operand.

      Returns:
          ``((x * y).get(), (y * x).get())`` on this party.
      """
      _local, parties = session_init(job_id, idx)
      with SPDZ():
          source = data_list[0] if idx == 0 else parties[0]
          shared = FixedPointTensor.from_source("x", source)
          plain = data_list[1]
          return (shared * plain).get(), (plain * shared).get()
  def mat_mul(job_id, idx, data_list):
      """Matrix-multiply ``x`` (sourced from the guest) by ``y`` (sourced from the host).

      Returns:
          ``(x @ y).get()`` on this party.
      """
      _local, parties = session_init(job_id, idx)
      with SPDZ():
          if idx != 0:
              lhs = FixedPointTensor.from_source("x", parties[0])
              rhs = FixedPointTensor.from_source("y", data_list[1])
          else:
              lhs = FixedPointTensor.from_source("x", data_list[0])
              rhs = FixedPointTensor.from_source("y", parties[1])
          return (lhs @ rhs).get()
  def einsum(job_id, idx, einsum_expr, data_list):
      """Contract ``x`` (from the guest) with ``y`` (from the host) using ``einsum_expr``.

      Returns:
          ``x.einsum(y, einsum_expr).get()`` on this party.
      """
      _local, parties = session_init(job_id, idx)
      with SPDZ():
          is_guest = idx == 0
          x = FixedPointTensor.from_source("x", data_list[0] if is_guest else parties[0])
          y = FixedPointTensor.from_source("y", parties[1] if is_guest else data_list[1])
          contracted = x.einsum(y, einsum_expr)
          return contracted.get()
  class TestSyncBase(unittest.TestCase):
      """Multi-party tests for ``FixedPointTensor`` operations run under ``SPDZ``.

      Each test runs one process per party through ``submit``. It then checks
      that every party's reconstructed result matches the plaintext numpy result,
      within a tolerance derived from ``EPS``.
      """

      def setUp(self) -> None:
          # A fresh job id per test. It is set as the transfer-variable flow id
          # and is also passed to `submit`, which forwards it to `session_init`.
          self.transfer_variable = SecretShareTransferVariable()
          self.job_id = str(uuid.uuid1())
          self.transfer_variable.set_flowid(self.job_id)

      def _assert_close(self, expected, actual, delta):
          """Assert that ``np.linalg.norm(expected - actual)`` is within ``delta`` of 0.

          With the default arguments, ``np.linalg.norm`` is the 2-norm of the
          flattened difference (the Frobenius norm for matrices).
          """
          self.assertAlmostEqual(np.linalg.norm(expected - actual), 0, delta=delta)

      def test_create_and_get(self):
          data = np.random.rand(10, 15)
          rec = submit(create_and_get, self.job_id, data=data)
          for x in rec:
              self._assert_close(data, x, EPS)

      def test_add_and_sub(self):
          x = np.random.rand(10, 15)
          y = np.random.rand(10, 15)
          data_list = [x, y]
          rec = submit(add_and_sub, self.job_id, data_list=data_list)
          for a, b in rec:
              self._assert_close(x + y, a, 2 * EPS)
              self._assert_close(x - y, b, 2 * EPS)

      def test_add_and_sub_plaintext(self):
          x = np.array([1, 2, 3, 4])
          y = np.array([5, 6, 7, 8])
          data_list = [x, y]
          rec = submit(add_and_sub_plaintext, self.job_id, data_list=data_list)
          for a, a1, b, b1 in rec:
              self._assert_close(x + y, a, 2 * EPS)
              self._assert_close(x + y, a1, 2 * EPS)
              self._assert_close(x - y, b, 2 * EPS)
              self._assert_close(y - x, b1, 2 * EPS)

      def test_mul_plaintext(self):
          x = np.random.rand(10, 15)
          y = random.randint(1, 10000)
          data_list = [x, y]
          rec = submit(mul_plaintext, self.job_id, data_list=data_list)
          for a, b in rec:
              # The fixed-point error grows with the plaintext multiplier.
              self._assert_close(x * y, a, y * EPS)
              self._assert_close(x * y, b, y * EPS)

      def test_matmul(self):
          j_dim = 15
          x = np.random.rand(10, j_dim)
          y = np.random.rand(j_dim, 20)
          data_list = [x, y]
          rec = submit(mat_mul, self.job_id, data_list=data_list)
          for a in rec:
              # Tolerance scales with the length of the contracted dimension.
              self._assert_close(x @ y, a, j_dim * EPS)

      def test_einsum(self):
          j_dim = 5
          k_dim = 4
          x = np.random.rand(10, j_dim, k_dim)
          y = np.random.rand(k_dim, j_dim, 20)
          einsum_expr = "ijk,kjl->il"
          data_list = [x, y]
          rec = submit(einsum, self.job_id, einsum_expr=einsum_expr, data_list=data_list)
          for a in rec:
              # Both j and k are contracted, so the tolerance scales with j_dim * k_dim.
              self._assert_close(np.einsum(einsum_expr, x, y), a, j_dim * k_dim * EPS)