#
#  Copyright 2021 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import uuid

import gmpy2

from federatedml.statistic.intersect.rsa_intersect.rsa_intersect_base import RsaIntersect
from federatedml.util import consts, LOGGER
  20. class RsaIntersectionHost(RsaIntersect):
  21. def __init__(self):
  22. super().__init__()
  23. self.role = consts.HOST
  24. def split_calculation_process(self, data_instances):
  25. LOGGER.info("RSA intersect using split calculation.")
  26. # split data
  27. sid_hash_odd = data_instances.filter(lambda k, v: k & 1)
  28. sid_hash_even = data_instances.filter(lambda k, v: not k & 1)
  29. # LOGGER.debug(f"sid_hash_odd count: {sid_hash_odd.count()},"
  30. # f"odd fraction: {sid_hash_odd.count()/data_instances.count()}")
  31. # generate rsa keys
  32. # self.e, self.d, self.n = self.generate_protocol_key()
  33. self.generate_protocol_key()
  34. LOGGER.info("Generate host protocol key!")
  35. public_key = {"e": self.e, "n": self.n}
  36. # sends public key e & n to guest
  37. self.transfer_variable.host_pubkey.remote(public_key,
  38. role=consts.GUEST,
  39. idx=0)
  40. LOGGER.info("Remote public key to Guest.")
  41. # generate ri for even ids
  42. # count = sid_hash_even.count()
  43. # self.r = self.generate_r_base(self.random_bit, count, self.random_base_fraction)
  44. # LOGGER.info(f"Generate {len(self.r)} r values.")
  45. # receive guest key for even ids
  46. guest_public_key = self.transfer_variable.guest_pubkey.get(idx=0)
  47. # LOGGER.debug("Get guest_public_key:{} from Guest".format(guest_public_key))
  48. LOGGER.info(f"Get guest_public_key from Guest")
  49. self.rcv_e = int(guest_public_key["e"])
  50. self.rcv_n = int(guest_public_key["n"])
  51. # encrypt & send guest pubkey-encrypted odd ids
  52. pubkey_ids_process = self.pubkey_id_process(sid_hash_even,
  53. fraction=self.random_base_fraction,
  54. random_bit=self.random_bit,
  55. rsa_e=self.rcv_e,
  56. rsa_n=self.rcv_n)
  57. LOGGER.info(f"Finish pubkey_ids_process")
  58. mask_host_id = pubkey_ids_process.mapValues(lambda v: None)
  59. self.transfer_variable.host_pubkey_ids.remote(mask_host_id,
  60. role=consts.GUEST,
  61. idx=0)
  62. LOGGER.info("Remote host_pubkey_ids to Guest")
  63. # encrypt & send prvkey-encrypted host odd ids to guest
  64. prvkey_ids_process_pair = self.cal_prvkey_ids_process_pair(sid_hash_odd,
  65. self.d,
  66. self.n,
  67. self.p,
  68. self.q,
  69. self.cp,
  70. self.cq)
  71. prvkey_ids_process = prvkey_ids_process_pair.mapValues(lambda v: None)
  72. self.transfer_variable.host_prvkey_ids.remote(prvkey_ids_process,
  73. role=consts.GUEST,
  74. idx=0)
  75. LOGGER.info("Remote host_prvkey_ids to Guest.")
  76. # get & sign guest pubkey-encrypted odd ids
  77. guest_pubkey_ids = self.transfer_variable.guest_pubkey_ids.get(idx=0)
  78. LOGGER.info(f"Get guest_pubkey_ids from guest")
  79. host_sign_guest_ids = guest_pubkey_ids.map(lambda k, v: (k, self.sign_id(k,
  80. self.d,
  81. self.n,
  82. self.p,
  83. self.q,
  84. self.cp,
  85. self.cq)))
  86. LOGGER.debug(f"host sign guest_pubkey_ids")
  87. # send signed guest odd ids
  88. self.transfer_variable.host_sign_guest_ids.remote(host_sign_guest_ids,
  89. role=consts.GUEST,
  90. idx=0)
  91. LOGGER.info("Remote host_sign_guest_ids_process to Guest.")
  92. # recv guest privkey-encrypted even ids
  93. guest_prvkey_ids = self.transfer_variable.guest_prvkey_ids.get(idx=0)
  94. LOGGER.info("Get guest_prvkey_ids")
  95. # receive guest-signed host even ids
  96. recv_guest_sign_host_ids = self.transfer_variable.guest_sign_host_ids.get(idx=0)
  97. LOGGER.info(f"Get guest_sign_host_ids from Guest.")
  98. guest_sign_host_ids = pubkey_ids_process.join(recv_guest_sign_host_ids,
  99. lambda g, r: (g[0],
  100. RsaIntersectionHost.hash(gmpy2.divm(int(r),
  101. int(g[1]),
  102. self.rcv_n),
  103. self.final_hash_operator,
  104. self.rsa_params.salt)))
  105. sid_guest_sign_host_ids = guest_sign_host_ids.map(lambda k, v: (v[1], v[0]))
  106. encrypt_intersect_even_ids = sid_guest_sign_host_ids.join(guest_prvkey_ids, lambda sid, h: sid)
  107. # filter & send intersect even ids
  108. intersect_even_ids = self.filter_intersect_ids([encrypt_intersect_even_ids])
  109. remote_intersect_even_ids = encrypt_intersect_even_ids.mapValues(lambda v: None)
  110. self.transfer_variable.host_intersect_ids.remote(remote_intersect_even_ids, role=consts.GUEST, idx=0)
  111. LOGGER.info(f"Remote host intersect ids to Guest")
  112. # recv intersect ids
  113. intersect_ids = None
  114. if self.sync_intersect_ids:
  115. encrypt_intersect_odd_ids = self.transfer_variable.intersect_ids.get(idx=0)
  116. intersect_odd_ids_pair = encrypt_intersect_odd_ids.join(prvkey_ids_process_pair, lambda e, h: h)
  117. intersect_odd_ids = intersect_odd_ids_pair.map(lambda k, v: (v, None))
  118. intersect_ids = intersect_odd_ids.union(intersect_even_ids)
  119. LOGGER.info("Get intersect ids from Guest")
  120. return intersect_ids
  121. def unified_calculation_process(self, data_instances):
  122. LOGGER.info("RSA intersect using unified calculation.")
  123. # generate rsa keys
  124. # self.e, self.d, self.n = self.generate_protocol_key()
  125. self.generate_protocol_key()
  126. LOGGER.info("Generate protocol key!")
  127. public_key = {"e": self.e, "n": self.n}
  128. # sends public key e & n to guest
  129. self.transfer_variable.host_pubkey.remote(public_key,
  130. role=consts.GUEST,
  131. idx=0)
  132. LOGGER.info("Remote public key to Guest.")
  133. # hash host ids
  134. prvkey_ids_process_pair = self.cal_prvkey_ids_process_pair(data_instances,
  135. self.d,
  136. self.n,
  137. self.p,
  138. self.q,
  139. self.cp,
  140. self.cq,
  141. self.first_hash_operator)
  142. prvkey_ids_process = prvkey_ids_process_pair.mapValues(lambda v: None)
  143. self.transfer_variable.host_prvkey_ids.remote(prvkey_ids_process,
  144. role=consts.GUEST,
  145. idx=0)
  146. LOGGER.info("Remote host_ids_process to Guest.")
  147. # Recv guest ids
  148. guest_pubkey_ids = self.transfer_variable.guest_pubkey_ids.get(idx=0)
  149. LOGGER.info("Get guest_pubkey_ids from guest")
  150. # Process(signs) guest ids and return to guest
  151. host_sign_guest_ids = guest_pubkey_ids.map(lambda k, v: (k, self.sign_id(k,
  152. self.d,
  153. self.n,
  154. self.p,
  155. self.q,
  156. self.cp,
  157. self.cq)))
  158. self.transfer_variable.host_sign_guest_ids.remote(host_sign_guest_ids,
  159. role=consts.GUEST,
  160. idx=0)
  161. LOGGER.info("Remote host_sign_guest_ids_process to Guest.")
  162. # recv intersect ids
  163. intersect_ids = None
  164. if self.sync_intersect_ids:
  165. encrypt_intersect_ids = self.transfer_variable.intersect_ids.get(idx=0)
  166. intersect_ids_pair = encrypt_intersect_ids.join(prvkey_ids_process_pair, lambda e, h: h)
  167. intersect_ids = intersect_ids_pair.map(lambda k, v: (v, None))
  168. LOGGER.info("Get intersect ids from Guest")
  169. return intersect_ids
  170. def get_intersect_key(self, party_id=None):
  171. intersect_key = {"e": str(self.e),
  172. "d": str(self.d),
  173. "n": str(self.n),
  174. "p": str(self.p),
  175. "q": str(self.q),
  176. "cp": str(self.cp),
  177. "cq": str(self.cq)}
  178. return intersect_key
  179. def load_intersect_key(self, cache_meta):
  180. intersect_key = cache_meta[str(self.guest_party_id)]["intersect_key"]
  181. self.e = int(intersect_key["e"])
  182. self.d = int(intersect_key["d"])
  183. self.n = int(intersect_key["n"])
  184. self.p = int(intersect_key["p"])
  185. self.q = int(intersect_key["q"])
  186. self.cp = int(intersect_key["cp"])
  187. self.cq = int(intersect_key["cq"])
  188. def run_cardinality(self, data_instances):
  189. LOGGER.info(f"run cardinality_only with RSA")
  190. # generate rsa keys
  191. self.generate_protocol_key()
  192. LOGGER.info("Generate protocol key!")
  193. public_key = {"e": self.e, "n": self.n}
  194. # sends public key e & n to guest
  195. self.transfer_variable.host_pubkey.remote(public_key,
  196. role=consts.GUEST,
  197. idx=0)
  198. LOGGER.info("Remote public key to Guest.")
  199. # hash host ids
  200. prvkey_ids_process_pair = self.cal_prvkey_ids_process_pair(data_instances,
  201. self.d,
  202. self.n,
  203. self.p,
  204. self.q,
  205. self.cp,
  206. self.cq,
  207. self.first_hash_operator)
  208. filter = self.construct_filter(prvkey_ids_process_pair,
  209. false_positive_rate=self.intersect_preprocess_params.false_positive_rate,
  210. hash_method=self.intersect_preprocess_params.hash_method,
  211. random_state=self.intersect_preprocess_params.random_state)
  212. self.filter = filter
  213. self.transfer_variable.host_filter.remote(filter,
  214. role=consts.GUEST,
  215. idx=0)
  216. LOGGER.info("Remote host_filter to Guest.")
  217. # Recv guest ids
  218. guest_pubkey_ids = self.transfer_variable.guest_pubkey_ids.get(idx=0)
  219. LOGGER.info("Get guest_pubkey_ids from guest")
  220. # Process(signs) guest ids and return to guest
  221. host_sign_guest_ids = guest_pubkey_ids.map(lambda k, v: (k, self.sign_id(k,
  222. self.d,
  223. self.n,
  224. self.p,
  225. self.q,
  226. self.cp,
  227. self.cq)))
  228. self.transfer_variable.host_sign_guest_ids.remote(host_sign_guest_ids,
  229. role=consts.GUEST,
  230. idx=0)
  231. LOGGER.info("Remote host_sign_guest_ids_process to Guest.")
  232. if self.sync_cardinality:
  233. self.intersect_num = self.transfer_variable.cardinality.get(idx=0)
  234. LOGGER.info("Got intersect cardinality from guest.")
  235. def generate_cache(self, data_instances):
  236. LOGGER.info("Run RSA intersect cache.")
  237. # generate rsa keys
  238. # self.e, self.d, self.n = self.generate_protocol_key()
  239. self.generate_protocol_key()
  240. LOGGER.info("Generate protocol key!")
  241. public_key = {"e": self.e, "n": self.n}
  242. # sends public key e & n to guest
  243. self.transfer_variable.host_pubkey.remote(public_key,
  244. role=consts.GUEST,
  245. idx=0)
  246. LOGGER.info("Remote public key to Guest.")
  247. # hash host ids
  248. prvkey_ids_process_pair = self.cal_prvkey_ids_process_pair(data_instances,
  249. self.d,
  250. self.n,
  251. self.p,
  252. self.q,
  253. self.cp,
  254. self.cq,
  255. self.first_hash_operator)
  256. prvkey_ids_process = prvkey_ids_process_pair.mapValues(lambda v: None)
  257. cache_id = str(uuid.uuid4())
  258. # self.cache_id = {self.guest_party_id: cache_id}
  259. # cache_schema = {"cache_id": cache_id}
  260. # self.cache = prvkey_ids_process_pair
  261. # prvkey_ids_process.schema = cache_schema
  262. self.cache_transfer_variable.remote(cache_id, role=consts.GUEST, idx=0)
  263. LOGGER.info(f"remote cache_id to guest")
  264. self.transfer_variable.host_prvkey_ids.remote(prvkey_ids_process,
  265. role=consts.GUEST,
  266. idx=0)
  267. LOGGER.info("Remote host_ids_process to Guest.")
  268. # prvkey_ids_process_pair.schema = cache_schema
  269. cache_data = {self.guest_party_id: prvkey_ids_process_pair}
  270. cache_meta = {self.guest_party_id: {"cache_id": cache_id,
  271. "intersect_meta": self.get_intersect_method_meta(),
  272. "intersect_key": self.get_intersect_key()}}
  273. return cache_data, cache_meta
  274. def cache_unified_calculation_process(self, data_instances, cache_data):
  275. LOGGER.info("RSA intersect using cache.")
  276. cache = self.extract_cache_list(cache_data, self.guest_party_id)[0]
  277. # Recv guest ids
  278. guest_pubkey_ids = self.transfer_variable.guest_pubkey_ids.get(idx=0)
  279. LOGGER.info("Get guest_pubkey_ids from guest")
  280. # Process(signs) guest ids and return to guest
  281. host_sign_guest_ids = guest_pubkey_ids.map(lambda k, v: (k, self.sign_id(k,
  282. self.d,
  283. self.n,
  284. self.p,
  285. self.q,
  286. self.cp,
  287. self.cq)))
  288. self.transfer_variable.host_sign_guest_ids.remote(host_sign_guest_ids,
  289. role=consts.GUEST,
  290. idx=0)
  291. LOGGER.info("Remote host_sign_guest_ids_process to Guest.")
  292. # recv intersect ids
  293. intersect_ids = None
  294. if self.sync_intersect_ids:
  295. encrypt_intersect_ids = self.transfer_variable.intersect_ids.get(idx=0)
  296. intersect_ids_pair = encrypt_intersect_ids.join(cache, lambda e, h: h)
  297. intersect_ids = intersect_ids_pair.map(lambda k, v: (v, None))
  298. LOGGER.info("Get intersect ids from Guest")
  299. return intersect_ids