rsa_intersect_guest.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. #
  2. # Copyright 2021 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import gmpy2
  17. from federatedml.statistic.intersect.rsa_intersect.rsa_intersect_base import RsaIntersect
  18. from federatedml.util import consts, LOGGER
  19. class RsaIntersectionGuest(RsaIntersect):
  20. def __init__(self):
  21. super().__init__()
  22. self.role = consts.GUEST
  23. def get_host_prvkey_ids(self):
  24. host_prvkey_ids_list = self.transfer_variable.host_prvkey_ids.get(idx=-1)
  25. LOGGER.info("Get host_prvkey_ids from all host")
  26. return host_prvkey_ids_list
  27. def get_host_filter(self):
  28. host_filter_list = self.transfer_variable.host_filter.get(idx=-1)
  29. LOGGER.info("Get host_filter from all host")
  30. return host_filter_list
  31. def get_host_pubkey_ids(self):
  32. host_pubkey_ids_list = self.transfer_variable.host_pubkey_ids.get(idx=-1)
  33. LOGGER.info("Get host_pubkey_ids from all host")
  34. return host_pubkey_ids_list
  35. def sign_host_ids(self, host_pubkey_ids_list):
  36. # Process(signs) hosts' ids
  37. guest_sign_host_ids_list = [host_pubkey_ids.map(lambda k, v:
  38. (k, self.sign_id(k,
  39. self.d[i],
  40. self.n[i],
  41. self.p[i],
  42. self.q[i],
  43. self.cp[i],
  44. self.cq[i])))
  45. for i, host_pubkey_ids in enumerate(host_pubkey_ids_list)]
  46. LOGGER.info("Sign host_pubkey_ids with guest prv_keys")
  47. return guest_sign_host_ids_list
  48. def send_intersect_ids(self, encrypt_intersect_ids_list, intersect_ids):
  49. if len(self.host_party_id_list) > 1:
  50. for i, host_party_id in enumerate(self.host_party_id_list):
  51. remote_intersect_id = intersect_ids.map(lambda k, v: (v[i], None))
  52. self.transfer_variable.intersect_ids.remote(remote_intersect_id,
  53. role=consts.HOST,
  54. idx=i)
  55. LOGGER.info(f"Remote intersect ids to Host {host_party_id}!")
  56. else:
  57. remote_intersect_id = encrypt_intersect_ids_list[0].mapValues(lambda v: None)
  58. self.transfer_variable.intersect_ids.remote(remote_intersect_id,
  59. role=consts.HOST,
  60. idx=0)
  61. LOGGER.info(f"Remote intersect ids to Host!")
  62. def get_host_intersect_ids(self, guest_prvkey_ids_list):
  63. encrypt_intersect_ids_list = self.transfer_variable.host_intersect_ids.get(idx=-1)
  64. LOGGER.info("Get intersect ids from Host")
  65. intersect_ids_pair_list = [self.extract_intersect_ids(ids,
  66. guest_prvkey_ids_list[i]) for i, ids in
  67. enumerate(encrypt_intersect_ids_list)]
  68. intersect_ids = self.filter_intersect_ids(intersect_ids_pair_list, keep_encrypt_ids=True)
  69. return intersect_ids
  70. def split_calculation_process(self, data_instances):
  71. LOGGER.info("RSA intersect using split calculation.")
  72. # split data
  73. sid_hash_odd = data_instances.filter(lambda k, v: k & 1)
  74. sid_hash_even = data_instances.filter(lambda k, v: not k & 1)
  75. # LOGGER.debug(f"sid_hash_odd count: {sid_hash_odd.count()},"
  76. # f"odd fraction: {sid_hash_odd.count()/data_instances.count()}")
  77. # generate pub keys for even ids
  78. self.generate_protocol_key()
  79. LOGGER.info("Generate guest protocol key!")
  80. # send public key e & n to all host
  81. for i, host_party_id in enumerate(self.host_party_id_list):
  82. guest_public_key = {"e": self.e[i], "n": self.n[i]}
  83. self.transfer_variable.guest_pubkey.remote(guest_public_key,
  84. role=consts.HOST,
  85. idx=i)
  86. LOGGER.info(f"Remote public key to Host {host_party_id}.")
  87. # receive host pub keys for odd ids
  88. host_public_keys = self.transfer_variable.host_pubkey.get(idx=-1)
  89. # LOGGER.debug("Get host_public_key:{} from Host".format(host_public_keys))
  90. LOGGER.info(f"Get host_public_key from Host")
  91. self.rcv_e = [int(public_key["e"]) for public_key in host_public_keys]
  92. self.rcv_n = [int(public_key["n"]) for public_key in host_public_keys]
  93. # encrypt own odd ids with pub keys from host
  94. pubkey_ids_process_list = [self.pubkey_id_process(sid_hash_odd,
  95. fraction=self.random_base_fraction,
  96. random_bit=self.random_bit,
  97. rsa_e=self.rcv_e[i],
  98. rsa_n=self.rcv_n[i]) for i in range(len(self.rcv_e))]
  99. LOGGER.info(f"Perform pubkey_ids_process")
  100. for i, guest_id in enumerate(pubkey_ids_process_list):
  101. mask_guest_id = guest_id.mapValues(lambda v: None)
  102. self.transfer_variable.guest_pubkey_ids.remote(mask_guest_id,
  103. role=consts.HOST,
  104. idx=i)
  105. LOGGER.info(f"Remote guest_pubkey_ids to Host {i}")
  106. # encrypt & send prvkey encrypted guest even ids to host
  107. prvkey_ids_process_pair_list = []
  108. for i, host_party_id in enumerate(self.host_party_id_list):
  109. prvkey_ids_process_pair = self.cal_prvkey_ids_process_pair(sid_hash_even,
  110. self.d[i],
  111. self.n[i],
  112. self.p[i],
  113. self.q[i],
  114. self.cp[i],
  115. self.cq[i])
  116. prvkey_ids_process = prvkey_ids_process_pair.mapValues(lambda v: None)
  117. self.transfer_variable.guest_prvkey_ids.remote(prvkey_ids_process,
  118. role=consts.HOST,
  119. idx=i)
  120. prvkey_ids_process_pair_list.append(prvkey_ids_process_pair)
  121. LOGGER.info(f"Remote guest_prvkey_ids to host {host_party_id}")
  122. # get & sign host pub key encrypted even ids
  123. host_pubkey_ids_list = self.get_host_pubkey_ids()
  124. guest_sign_host_ids_list = self.sign_host_ids(host_pubkey_ids_list)
  125. # send signed host even ids
  126. for i, host_party_id in enumerate(self.host_party_id_list):
  127. self.transfer_variable.guest_sign_host_ids.remote(guest_sign_host_ids_list[i],
  128. role=consts.HOST,
  129. idx=i)
  130. LOGGER.info(f"Remote guest_sign_host_ids to Host {host_party_id}.")
  131. # get prvkey encrypted odd ids from host
  132. host_prvkey_ids_list = self.get_host_prvkey_ids()
  133. # Recv host signed odd ids
  134. # table(guest_pubkey_id, host signed odd ids)
  135. recv_host_sign_guest_ids_list = self.transfer_variable.host_sign_guest_ids.get(idx=-1)
  136. LOGGER.info("Get host_sign_guest_ids from Host")
  137. # table(r^e % n *hash(sid), sid, hash(guest_ids_process/r))
  138. # g[0]=(r^e % n *hash(sid), sid), g[1]=random bits r
  139. host_sign_guest_ids_list = [v.join(recv_host_sign_guest_ids_list[i],
  140. lambda g, r: (g[0], RsaIntersectionGuest.hash(gmpy2.divm(int(r),
  141. int(g[1]),
  142. self.rcv_n[i]),
  143. self.final_hash_operator,
  144. self.rsa_params.salt)))
  145. for i, v in enumerate(pubkey_ids_process_list)]
  146. # table(hash(guest_ids_process/r), sid))
  147. sid_host_sign_guest_ids_list = [g.map(lambda k, v: (v[1], v[0])) for g in host_sign_guest_ids_list]
  148. # get intersect odd ids
  149. # intersect table(hash(guest_ids_process/r), sid)
  150. encrypt_intersect_odd_ids_list = [v.join(host_prvkey_ids_list[i], lambda sid, h: sid) for i, v in
  151. enumerate(sid_host_sign_guest_ids_list)]
  152. intersect_odd_ids = self.filter_intersect_ids(encrypt_intersect_odd_ids_list, keep_encrypt_ids=True)
  153. intersect_even_ids = self.get_host_intersect_ids(prvkey_ids_process_pair_list)
  154. intersect_ids = intersect_odd_ids.union(intersect_even_ids)
  155. if self.sync_intersect_ids:
  156. self.send_intersect_ids(encrypt_intersect_odd_ids_list, intersect_odd_ids)
  157. else:
  158. LOGGER.info("Skip sync intersect ids with Host(s).")
  159. return intersect_ids
  160. def unified_calculation_process(self, data_instances):
  161. LOGGER.info("RSA intersect using unified calculation.")
  162. # receives public key e & n
  163. public_keys = self.transfer_variable.host_pubkey.get(idx=-1)
  164. # LOGGER.debug(f"Get RSA host_public_key:{public_keys} from Host")
  165. LOGGER.info(f"Get RSA host_public_key from Host")
  166. self.rcv_e = [int(public_key["e"]) for public_key in public_keys]
  167. self.rcv_n = [int(public_key["n"]) for public_key in public_keys]
  168. pubkey_ids_process_list = [self.pubkey_id_process(data_instances,
  169. fraction=self.random_base_fraction,
  170. random_bit=self.random_bit,
  171. rsa_e=self.rcv_e[i],
  172. rsa_n=self.rcv_n[i],
  173. hash_operator=self.first_hash_operator,
  174. salt=self.salt) for i in range(len(self.rcv_e))]
  175. LOGGER.info(f"Finish pubkey_ids_process")
  176. for i, guest_id in enumerate(pubkey_ids_process_list):
  177. mask_guest_id = guest_id.mapValues(lambda v: None)
  178. self.transfer_variable.guest_pubkey_ids.remote(mask_guest_id,
  179. role=consts.HOST,
  180. idx=i)
  181. LOGGER.info("Remote guest_pubkey_ids to Host {}".format(i))
  182. host_prvkey_ids_list = self.get_host_prvkey_ids()
  183. LOGGER.info("Get host_prvkey_ids")
  184. # Recv signed guest ids
  185. # table(r^e % n *hash(sid), guest_id_process)
  186. recv_host_sign_guest_ids_list = self.transfer_variable.host_sign_guest_ids.get(idx=-1)
  187. LOGGER.info("Get host_sign_guest_ids from Host")
  188. # table(r^e % n *hash(sid), sid, hash(guest_ids_process/r))
  189. # g[0]=(r^e % n *hash(sid), sid), g[1]=random bits r
  190. host_sign_guest_ids_list = [v.join(recv_host_sign_guest_ids_list[i],
  191. lambda g, r: (g[0], RsaIntersectionGuest.hash(gmpy2.divm(int(r),
  192. int(g[1]),
  193. self.rcv_n[i]),
  194. self.final_hash_operator,
  195. self.rsa_params.salt)))
  196. for i, v in enumerate(pubkey_ids_process_list)]
  197. # table(hash(guest_ids_process/r), sid))
  198. sid_host_sign_guest_ids_list = [g.map(lambda k, v: (v[1], v[0])) for g in host_sign_guest_ids_list]
  199. # intersect table(hash(guest_ids_process/r), sid)
  200. encrypt_intersect_ids_list = [v.join(host_prvkey_ids_list[i], lambda sid, h: sid) for i, v in
  201. enumerate(sid_host_sign_guest_ids_list)]
  202. intersect_ids = self.filter_intersect_ids(encrypt_intersect_ids_list, keep_encrypt_ids=True)
  203. if self.sync_intersect_ids:
  204. self.send_intersect_ids(encrypt_intersect_ids_list, intersect_ids)
  205. else:
  206. LOGGER.info("Skip sync intersect ids with Host(s).")
  207. return intersect_ids
  208. def get_intersect_key(self, party_id):
  209. idx = self.host_party_id_list.index(party_id)
  210. intersect_key = {"rcv_n": str(self.rcv_n[idx]),
  211. "rcv_e": str(self.rcv_e[idx])}
  212. return intersect_key
  213. def load_intersect_key(self, cache_meta):
  214. self.rcv_e, self.rcv_n = [], []
  215. for host_party in self.host_party_id_list:
  216. intersect_key = cache_meta[str(host_party)]["intersect_key"]
  217. self.rcv_e.append(int(intersect_key["rcv_e"]))
  218. self.rcv_n.append(int(intersect_key["rcv_n"]))
  219. def run_cardinality(self, data_instances):
  220. LOGGER.info(f"run cardinality_only with RSA")
  221. # receives public key e & n
  222. public_keys = self.transfer_variable.host_pubkey.get(idx=-1)
  223. LOGGER.info(f"Get RSA host_public_key from Host")
  224. self.rcv_e = [int(public_key["e"]) for public_key in public_keys]
  225. self.rcv_n = [int(public_key["n"]) for public_key in public_keys]
  226. pubkey_ids_process_list = [self.pubkey_id_process(data_instances,
  227. fraction=self.random_base_fraction,
  228. random_bit=self.random_bit,
  229. rsa_e=self.rcv_e[i],
  230. rsa_n=self.rcv_n[i],
  231. hash_operator=self.first_hash_operator,
  232. salt=self.salt) for i in range(len(self.rcv_e))]
  233. LOGGER.info(f"Finish pubkey_ids_process")
  234. for i, guest_id in enumerate(pubkey_ids_process_list):
  235. mask_guest_id = guest_id.mapValues(lambda v: None)
  236. self.transfer_variable.guest_pubkey_ids.remote(mask_guest_id,
  237. role=consts.HOST,
  238. idx=i)
  239. LOGGER.info("Remote guest_pubkey_ids to Host {}".format(i))
  240. host_filter_list = self.get_host_filter()
  241. LOGGER.info("Get host_filter_list")
  242. # Recv signed guest ids
  243. # table(r^e % n *hash(sid), guest_id_process)
  244. recv_host_sign_guest_ids_list = self.transfer_variable.host_sign_guest_ids.get(idx=-1)
  245. LOGGER.info("Get host_sign_guest_ids from Host")
  246. # table(r^e % n *hash(sid), sid, hash(guest_ids_process/r))
  247. # g[0]=(r^e % n *hash(sid), sid), g[1]=random bits r
  248. host_sign_guest_ids_list = [v.join(recv_host_sign_guest_ids_list[i],
  249. lambda g, r: (g[0], RsaIntersectionGuest.hash(gmpy2.divm(int(r),
  250. int(g[1]),
  251. self.rcv_n[i]),
  252. self.final_hash_operator,
  253. self.rsa_params.salt)))
  254. for i, v in enumerate(pubkey_ids_process_list)]
  255. # table(hash(guest_ids_process/r), sid))
  256. # sid_host_sign_guest_ids_list = [g.map(lambda k, v: (v[1], v[0])) for g in host_sign_guest_ids_list]
  257. # filter ids
  258. intersect_ids_list = [host_sign_guest_ids_list[i].filter(lambda k, v: host_filter_list[i].check(v[1]))
  259. for i in range(len(self.host_party_id_list))]
  260. intersect_ids_list = [ids.map(lambda k, v: (v[0], None)) for ids in intersect_ids_list]
  261. intersect_ids = self.get_common_intersection(intersect_ids_list)
  262. self.intersect_num = intersect_ids.count()
  263. if self.sync_cardinality:
  264. self.transfer_variable.cardinality.remote(self.intersect_num, role=consts.HOST, idx=-1)
  265. LOGGER.info("Sent intersect cardinality to host.")
  266. else:
  267. LOGGER.info("Skip sync intersect cardinality with host(s)")
  268. def generate_cache(self, data_instances):
  269. LOGGER.info("Run RSA intersect cache")
  270. # receives public key e & n
  271. public_keys = self.transfer_variable.host_pubkey.get(idx=-1)
  272. # LOGGER.debug(f"Get RSA host_public_key:{public_keys} from Host")
  273. LOGGER.info(f"Get RSA host_public_key from Host")
  274. self.rcv_e = [int(public_key["e"]) for public_key in public_keys]
  275. self.rcv_n = [int(public_key["n"]) for public_key in public_keys]
  276. cache_id_list = self.cache_transfer_variable.get(idx=-1)
  277. LOGGER.info(f"Get cache_id from all host")
  278. host_prvkey_ids_list = self.get_host_prvkey_ids()
  279. LOGGER.info("Get host_prvkey_ids")
  280. cache_data, cache_meta = {}, {}
  281. intersect_meta = self.get_intersect_method_meta()
  282. for i, party_id in enumerate(self.host_party_id_list):
  283. meta = {"cache_id": cache_id_list[i],
  284. "intersect_meta": intersect_meta,
  285. "intersect_key": self.get_intersect_key(party_id)
  286. }
  287. cache_meta[party_id] = meta
  288. cache_data[party_id] = host_prvkey_ids_list[i]
  289. return cache_data, cache_meta
  290. def cache_unified_calculation_process(self, data_instances, cache_data):
  291. LOGGER.info("RSA intersect using cache.")
  292. pubkey_ids_process_list = [self.pubkey_id_process(data_instances,
  293. fraction=self.random_base_fraction,
  294. random_bit=self.random_bit,
  295. rsa_e=self.rcv_e[i],
  296. rsa_n=self.rcv_n[i],
  297. hash_operator=self.first_hash_operator,
  298. salt=self.salt) for i in range(len(self.rcv_e))]
  299. LOGGER.info(f"Finish pubkey_ids_process")
  300. for i, guest_id in enumerate(pubkey_ids_process_list):
  301. mask_guest_id = guest_id.mapValues(lambda v: None)
  302. self.transfer_variable.guest_pubkey_ids.remote(mask_guest_id,
  303. role=consts.HOST,
  304. idx=i)
  305. LOGGER.info("Remote guest_pubkey_ids to Host {}".format(i))
  306. # Recv signed guest ids
  307. # table(r^e % n *hash(sid), guest_id_process)
  308. recv_host_sign_guest_ids_list = self.transfer_variable.host_sign_guest_ids.get(idx=-1)
  309. LOGGER.info("Get host_sign_guest_ids from Host")
  310. # table(r^e % n *hash(sid), sid, hash(guest_ids_process/r))
  311. # g[0]=(r^e % n *hash(sid), sid), g[1]=random bits r
  312. host_sign_guest_ids_list = [v.join(recv_host_sign_guest_ids_list[i],
  313. lambda g, r: (g[0], RsaIntersectionGuest.hash(gmpy2.divm(int(r),
  314. int(g[1]),
  315. self.rcv_n[i]),
  316. self.final_hash_operator,
  317. self.rsa_params.salt)))
  318. for i, v in enumerate(pubkey_ids_process_list)]
  319. # table(hash(guest_ids_process/r), sid))
  320. sid_host_sign_guest_ids_list = [g.map(lambda k, v: (v[1], v[0])) for g in host_sign_guest_ids_list]
  321. # intersect table(hash(guest_ids_process/r), sid)
  322. host_prvkey_ids_list = self.extract_cache_list(cache_data, self.host_party_id_list)
  323. encrypt_intersect_ids_list = [v.join(host_prvkey_ids_list[i], lambda sid, h: sid) for i, v in
  324. enumerate(sid_host_sign_guest_ids_list)]
  325. intersect_ids = self.filter_intersect_ids(encrypt_intersect_ids_list, keep_encrypt_ids=True)
  326. if self.sync_intersect_ids:
  327. self.send_intersect_ids(encrypt_intersect_ids_list, intersect_ids)
  328. else:
  329. LOGGER.info("Skip sync intersect ids with Host(s).")
  330. return intersect_ids