ecdh_intersect_host.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. #
  2. # Copyright 2021 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import uuid
  17. from federatedml.statistic.intersect.ecdh_intersect.ecdh_intersect_base import EcdhIntersect
  18. from federatedml.util import consts, LOGGER
  19. class EcdhIntersectionHost(EcdhIntersect):
  20. def __init__(self):
  21. super().__init__()
  22. self.role = consts.HOST
  23. self.id_local_first = None
  24. def _exchange_id(self, id, replace_val=True):
  25. if replace_val:
  26. id_only = id.mapValues(lambda v: None)
  27. else:
  28. id_only = id
  29. self.transfer_variable.id_ciphertext_exchange_h2g.remote(id_only,
  30. role=consts.GUEST,
  31. idx=0)
  32. LOGGER.info("sent id 1st ciphertext list to guest")
  33. id_guest = self.transfer_variable.id_ciphertext_exchange_g2h.get(idx=0)
  34. LOGGER.info("got id 1st ciphertext list from guest")
  35. return id_guest
  36. def _sync_doubly_encrypted_id(self, id):
  37. self.transfer_variable.doubly_encrypted_id.remote(id,
  38. role=consts.GUEST,
  39. idx=0)
  40. LOGGER.info("sent doubly encrypted id list to guest")
  41. def get_intersect_ids(self):
  42. first_cipher_intersect_ids = self.transfer_variable.intersect_ids.get(idx=0)
  43. LOGGER.info(f"obtained cipher intersect ids from guest")
  44. intersect_ids = self.map_encrypt_id_to_raw_id(first_cipher_intersect_ids,
  45. self.id_local_first,
  46. keep_encrypt_id=False)
  47. return intersect_ids
  48. def get_intersect_doubly_encrypted_id(self, data_instances, keep_key=True):
  49. self.init_curve()
  50. LOGGER.info(f"curve instance obtained")
  51. # 1st ID encrypt: (Eh, (h, Instance))
  52. self.id_local_first = self._encrypt_id(data_instances,
  53. self.curve_instance,
  54. reserve_original_key=keep_key,
  55. hash_operator=self.hash_operator,
  56. salt=self.salt,
  57. reserve_original_value=keep_key)
  58. LOGGER.info("encrypted local id for the 1st time")
  59. # send (Eh, -1), get (Eg, -1)
  60. id_remote_first = self._exchange_id(self.id_local_first, keep_key)
  61. # 2nd ID encrypt & send doubly encrypted guest ID list to guest
  62. id_remote_second = self._sign_id(id_remote_first,
  63. self.curve_instance,
  64. reserve_original_key=keep_key) # (EEg, Eg)
  65. LOGGER.info("encrypted guest id for the 2nd time")
  66. self._sync_doubly_encrypted_id(id_remote_second)
  67. def decrypt_intersect_doubly_encrypted_id(self, id_intersect_cipher_cipher=None):
  68. intersect_ids = None
  69. if self.sync_intersect_ids:
  70. intersect_ids = self.get_intersect_ids()
  71. return intersect_ids
  72. def get_intersect_key(self, party_id=None):
  73. intersect_key = {"curve_key": self.curve_instance.get_curve_key().decode("latin1")}
  74. return intersect_key
  75. def load_intersect_key(self, cache_meta):
  76. intersect_key = cache_meta[str(self.guest_party_id)]["intersect_key"]
  77. curve_key = intersect_key["curve_key"].encode("latin1")
  78. self.init_curve(curve_key)
  79. def generate_cache(self, data_instances):
  80. self.init_curve()
  81. LOGGER.info(f"curve instance obtained")
  82. cache_id = str(uuid.uuid4())
  83. self.cache_id = {self.guest_party_id: cache_id}
  84. self.cache_transfer_variable.remote(cache_id, role=consts.GUEST, idx=0)
  85. LOGGER.info(f"remote cache_id to guest")
  86. # 1st ID encrypt: (Eh, (h, Instance))
  87. id_local_first = self._encrypt_id(data_instances,
  88. self.curve_instance,
  89. reserve_original_key=True,
  90. hash_operator=self.hash_operator,
  91. salt=self.salt,
  92. reserve_original_value=True)
  93. LOGGER.info("encrypted local id for the 1st time")
  94. id_only = id_local_first.mapValues(lambda v: None)
  95. self.transfer_variable.id_ciphertext_exchange_h2g.remote(id_only,
  96. role=consts.GUEST,
  97. idx=0)
  98. LOGGER.info("sent id 1st ciphertext list to guest")
  99. cache_data = {self.guest_party_id: id_local_first}
  100. cache_meta = {self.guest_party_id: {"cache_id": cache_id,
  101. "intersect_meta": self.get_intersect_method_meta(),
  102. "intersect_key": self.get_intersect_key()}}
  103. return cache_data, cache_meta
  104. def get_intersect_doubly_encrypted_id_from_cache(self, data_instances, cache_data):
  105. id_remote_first = self.transfer_variable.id_ciphertext_exchange_g2h.get(idx=0)
  106. LOGGER.info("got id 1st ciphertext from guest")
  107. # 2nd ID encrypt & send doubly encrypted guest ID to guest
  108. id_remote_second = self._sign_id(id_remote_first,
  109. self.curve_instance,
  110. reserve_original_key=True) # (EEg, Eg)
  111. LOGGER.info("encrypted guest id for the 2nd time")
  112. self.id_local_first = self.extract_cache_list(cache_data, self.guest_party_id)[0]
  113. self._sync_doubly_encrypted_id(id_remote_second)
  114. def run_cardinality(self, data_instances):
  115. LOGGER.info(f"run exact_cardinality with DH")
  116. self.get_intersect_doubly_encrypted_id(data_instances, keep_key=True)
  117. if self.sync_cardinality:
  118. self.intersect_num = self.transfer_variable.cardinality.get(idx=0)
  119. LOGGER.info("Got intersect cardinality from guest.")