spdz_test.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import numpy as np
  19. import time
  20. from prettytable import PrettyTable, ORGMODE
  21. from fate_arch.session import computing_session as session, get_parties
  22. from federatedml.secureprotol.spdz import SPDZ
  23. from federatedml.model_base import ModelBase, ComponentOutput
  24. from federatedml.test.spdz_test.spdz_test_param import SPDZTestParam
  25. from federatedml.util import LOGGER
  26. from federatedml.secureprotol.spdz.tensor.fixedpoint_table import FixedPointTensor as TableTensor
  27. from federatedml.secureprotol.spdz.tensor.fixedpoint_numpy import FixedPointTensor as NumpyTensor
  28. class SPDZTest(ModelBase):
  29. def __init__(self):
  30. super(SPDZTest, self).__init__()
  31. self.data_num = None
  32. self.data_partition = None
  33. self.seed = None
  34. self.test_round = None
  35. self.tracker = None
  36. """plaintest data"""
  37. self.int_data_x = None
  38. self.float_data_x = None
  39. self.int_data_y = None
  40. self.float_data_y = None
  41. self.model_param = SPDZTestParam()
  42. self.parties = None
  43. self.local_party = None
  44. self.other_party = None
  45. self._set_parties()
  46. self.metric = None
  47. self.operation = None
  48. self.test_count = None
  49. self.op_test_list = ["float_add", "int_add", "float_sub", "int_sub", "float_dot", "int_dot"]
  50. self._summary = {"op_test_list": self.op_test_list,
  51. "tensor_type": ["numpy", "table"],
  52. "numpy": {},
  53. "table": {}}
  54. def _init_runtime_parameters(self, cpn_input):
  55. self.model_param.update(cpn_input.parameters)
  56. self.tracker = cpn_input.tracker
  57. self._init_model()
  58. def _init_model(self):
  59. self.data_num = self.model_param.data_num
  60. self.data_partition = self.model_param.data_partition
  61. self.seed = self.model_param.seed
  62. self.test_round = self.model_param.test_round
  63. self.data_lower_bound = self.model_param.data_lower_bound
  64. self.data_upper_bound = self.model_param.data_upper_bound
  65. self.data_lower_bound = 0
  66. self.data_upper_bound = 100
  67. def _set_parties(self):
  68. parties = []
  69. guest_parties = get_parties().roles_to_parties(["guest"])
  70. host_parties = get_parties().roles_to_parties(["host"])
  71. parties.extend(guest_parties)
  72. parties.extend(host_parties)
  73. local_party = get_parties().local_party
  74. other_party = parties[0] if parties[0] != local_party else parties[1]
  75. self.parties = parties
  76. self.local_party = local_party
  77. self.other_party = other_party
  78. def _init_data(self):
  79. np.random.seed(self.seed)
  80. self.int_data_x = np.random.randint(int(self.data_lower_bound), int(self.data_upper_bound), size=self.data_num)
  81. self.float_data_x = np.random.uniform(self.data_lower_bound, self.data_upper_bound, size=self.data_num)
  82. self.int_data_y = np.random.randint(int(self.data_lower_bound), int(self.data_upper_bound), size=self.data_num)
  83. self.float_data_y = np.random.uniform(self.data_lower_bound, self.data_upper_bound, size=self.data_num)
  84. def _test_spdz(self):
  85. table_list = []
  86. table_int_data_x, table_float_data_x = None, None
  87. table_int_data_y, table_float_data_y = None, None
  88. if self.local_party.role == "guest":
  89. table_int_data_x = session.parallelize(self.int_data_x,
  90. include_key=False,
  91. partition=self.data_partition)
  92. table_int_data_x = table_int_data_x.mapValues(lambda x: np.array([x]))
  93. table_float_data_x = session.parallelize(self.float_data_x,
  94. include_key=False,
  95. partition=self.data_partition)
  96. table_float_data_x = table_float_data_x.mapValues(lambda x: np.array([x]))
  97. else:
  98. table_int_data_y = session.parallelize(self.int_data_y,
  99. include_key=False,
  100. partition=self.data_partition)
  101. table_int_data_y = table_int_data_y.mapValues(lambda y: np.array([y]))
  102. table_float_data_y = session.parallelize(self.float_data_y,
  103. include_key=False,
  104. partition=self.data_partition)
  105. table_float_data_y = table_float_data_y.mapValues(lambda y: np.array([y]))
  106. for tensor_type in ["numpy", "table"]:
  107. table = PrettyTable()
  108. table.set_style(ORGMODE)
  109. field_name = ["DataType", "One time consumption", f"{self.data_num} times consumption",
  110. "relative acc", "log2 acc", "operations per second"]
  111. self._summary["field_name"] = field_name
  112. table.field_names = field_name
  113. with SPDZ(local_party=self.local_party, all_parties=self.parties) as spdz:
  114. for op_type in self.op_test_list:
  115. start_time = time.time()
  116. for epoch in range(self.test_round):
  117. LOGGER.info(f"test spdz, tensor_type: {tensor_type}, op_type: {op_type}, epoch: {epoch}")
  118. tag = "_".join([tensor_type, op_type, str(epoch)])
  119. spdz.set_flowid(tag)
  120. if self.local_party.role == "guest":
  121. if tensor_type == "table":
  122. if op_type.startswith("int"):
  123. fixed_point_x = TableTensor.from_source("int_x_" + tag, table_int_data_x)
  124. fixed_point_y = TableTensor.from_source("int_y_" + tag, self.other_party)
  125. else:
  126. fixed_point_x = TableTensor.from_source("float_x_" + tag, table_float_data_x)
  127. fixed_point_y = TableTensor.from_source("float_y_" + tag, self.other_party)
  128. else:
  129. if op_type.startswith("int"):
  130. fixed_point_x = NumpyTensor.from_source("int_x_" + tag, self.int_data_x)
  131. fixed_point_y = NumpyTensor.from_source("int_y_" + tag, self.other_party)
  132. else:
  133. fixed_point_x = NumpyTensor.from_source("float_x_" + tag, self.float_data_x)
  134. fixed_point_y = NumpyTensor.from_source("float_y_" + tag, self.other_party)
  135. else:
  136. if tensor_type == "table":
  137. if op_type.startswith("int"):
  138. fixed_point_y = TableTensor.from_source("int_y_" + tag, table_int_data_y)
  139. fixed_point_x = TableTensor.from_source("int_x_" + tag, self.other_party)
  140. else:
  141. fixed_point_y = TableTensor.from_source("float_y_" + tag, table_float_data_y)
  142. fixed_point_x = TableTensor.from_source("float_x_" + tag, self.other_party)
  143. else:
  144. if op_type.startswith("int"):
  145. fixed_point_y = NumpyTensor.from_source("int_y_" + tag, self.int_data_y)
  146. fixed_point_x = NumpyTensor.from_source("int_x_" + tag, self.other_party)
  147. else:
  148. fixed_point_y = NumpyTensor.from_source("float_y_" + tag, self.float_data_y)
  149. fixed_point_x = NumpyTensor.from_source("float_x_" + tag, self.other_party)
  150. ret = self.calculate_ret(op_type, tensor_type, fixed_point_x, fixed_point_y)
  151. total_time = time.time() - start_time
  152. self.output_table(op_type, table, tensor_type, total_time, ret)
  153. table_list.append(table)
  154. self.tracker.log_component_summary(self._summary)
  155. for table in table_list:
  156. LOGGER.info(table)
  157. def calculate_ret(self, op_type, tensor_type,
  158. fixed_point_x, fixed_point_y,
  159. ):
  160. if op_type.endswith("add"):
  161. ret = (fixed_point_x + fixed_point_y).get()
  162. elif op_type.endswith("sub"):
  163. ret = (fixed_point_x - fixed_point_y).get()
  164. else:
  165. ret = (fixed_point_x.dot(fixed_point_y)).get()[0]
  166. if tensor_type == "table":
  167. ret = ret[0]
  168. if tensor_type == "table" and not op_type.endswith("dot"):
  169. arr = [None] * self.data_num
  170. for k, v in ret.collect():
  171. arr[k] = v[0]
  172. ret = np.array(arr)
  173. return ret
  174. def output_table(self, op_type, table, tensor_type, total_time, spdz_ret):
  175. if op_type.startswith("int"):
  176. data_x = self.int_data_x
  177. data_y = self.int_data_y
  178. else:
  179. data_x = self.float_data_x
  180. data_y = self.float_data_y
  181. numpy_ret = None
  182. if op_type.endswith("add") or op_type.endswith("sub"):
  183. start = time.time()
  184. for i in range(self.test_round):
  185. if op_type.endswith("add"):
  186. numpy_ret = data_x + data_y
  187. else:
  188. numpy_ret = data_x - data_y
  189. plain_text_time = time.time() - start
  190. relative_acc = 0
  191. for np_x, spdz_x in zip(numpy_ret, spdz_ret):
  192. relative_acc += abs(np_x - spdz_x) / max(abs(np_x), abs(spdz_x) + 1e-15)
  193. else:
  194. start = time.time()
  195. for i in range(self.test_round):
  196. numpy_ret = np.dot(data_x, data_y)
  197. plain_text_time = time.time() - start
  198. relative_acc = abs(numpy_ret - spdz_ret) / max(abs(numpy_ret), abs(spdz_ret))
  199. relative_acc /= self.data_num
  200. log2_acc = -np.log2(relative_acc) if relative_acc != 0 else 0
  201. row_info = [op_type, total_time / self.data_num / self.test_round, total_time / self.test_round,
  202. relative_acc, log2_acc, int(self.data_num * self.test_round / total_time)]
  203. table.add_row(row_info)
  204. self._summary[tensor_type][op_type] = row_info
  205. return table.get_string(title=f"SPDZ {tensor_type} Computational performance")
  206. def run(self, cpn_input):
  207. LOGGER.info("begin to init parameters of secure add example")
  208. self._init_runtime_parameters(cpn_input)
  209. LOGGER.info("begin to make data")
  210. self._init_data()
  211. self._test_spdz()
  212. return ComponentOutput(self.save_data(), self.export_model(), self.save_cache())