federation.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from federatedml.transfer_variable.transfer_class.secret_share_transfer_variable import SecretShareTransferVariable
  17. class Communicator(object):
  18. def __init__(self, local_party=None, all_parties=None):
  19. self._transfer_variable = SecretShareTransferVariable()
  20. self._share_variable = self._transfer_variable.share.disable_auto_clean()
  21. self._rescontruct_variable = self._transfer_variable.rescontruct.set_preserve_num(3)
  22. self._mul_triplets_encrypted_variable = self._transfer_variable.multiply_triplets_encrypted.set_preserve_num(3)
  23. self._mul_triplets_cross_variable = self._transfer_variable.multiply_triplets_cross.set_preserve_num(3)
  24. self._q_field_variable = self._transfer_variable.q_field.disable_auto_clean()
  25. self._local_party = self._transfer_variable.local_party() if local_party is None else local_party
  26. self._all_parties = self._transfer_variable.all_parties() if all_parties is None else all_parties
  27. self._party_idx = self._all_parties.index(self._local_party)
  28. self._other_parties = self._all_parties[:self._party_idx] + self._all_parties[(self._party_idx + 1):]
  29. @property
  30. def party(self):
  31. return self._local_party
  32. @property
  33. def parties(self):
  34. return self._all_parties
  35. @property
  36. def other_parties(self):
  37. return self._other_parties
  38. @property
  39. def party_idx(self):
  40. return self._party_idx
  41. def remote_q_field(self, q_field, party):
  42. return self._q_field_variable.remote_parties(q_field, party, suffix=("q_field",))
  43. def get_q_field(self, party):
  44. return self._q_field_variable.get_parties(party, suffix=("q_field",))
  45. def get_rescontruct_shares(self, tensor_name):
  46. return self._rescontruct_variable.get_parties(self._other_parties, suffix=(tensor_name,))
  47. def broadcast_rescontruct_share(self, share, tensor_name):
  48. return self._rescontruct_variable.remote_parties(share, self._other_parties, suffix=(tensor_name,))
  49. def remote_share(self, share, tensor_name, party):
  50. return self._share_variable.remote_parties(share, party, suffix=(tensor_name,))
  51. def get_share(self, tensor_name, party):
  52. return self._share_variable.get_parties(party, suffix=(tensor_name,))
  53. def remote_encrypted_tensor(self, encrypted, tag):
  54. return self._mul_triplets_encrypted_variable.remote_parties(encrypted, parties=self._other_parties, suffix=tag)
  55. def remote_encrypted_cross_tensor(self, encrypted, parties, tag):
  56. return self._mul_triplets_cross_variable.remote_parties(encrypted, parties=parties, suffix=tag)
  57. def get_encrypted_tensors(self, tag):
  58. return (self._other_parties,
  59. self._mul_triplets_encrypted_variable.get_parties(parties=self._other_parties, suffix=tag))
  60. def get_encrypted_cross_tensors(self, tag):
  61. return self._mul_triplets_cross_variable.get_parties(parties=self._other_parties, suffix=tag)
  62. def clean(self):
  63. self._rescontruct_variable.clean()
  64. self._share_variable.clean()
  65. self._rescontruct_variable.clean()
  66. self._mul_triplets_encrypted_variable.clean()
  67. self._mul_triplets_cross_variable.clean()
  68. self._q_field_variable.clean()
  69. def set_flowid(self, flowid):
  70. self._transfer_variable.set_flowid(flowid)