packer.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import functools
  2. from federatedml.util import LOGGER
  3. from federatedml.secureprotol import PaillierEncrypt, IpclPaillierEncrypt
  4. from federatedml.cipher_compressor.compressor import get_homo_encryption_max_int
  5. from federatedml.secureprotol.encrypt_mode import EncryptModeCalculator
  6. from federatedml.cipher_compressor.compressor import PackingCipherTensor
  7. from federatedml.cipher_compressor.compressor import CipherPackage
  8. from federatedml.transfer_variable.transfer_class.cipher_compressor_transfer_variable \
  9. import CipherCompressorTransferVariable
  10. from federatedml.util import consts
  11. from typing import Union
  12. def cipher_list_to_cipher_tensor(cipher_list: list):
  13. cipher_tensor = PackingCipherTensor(ciphers=cipher_list)
  14. return cipher_tensor
  15. class GuestIntegerPacker(object):
  16. def __init__(self, pack_num: int, pack_num_range: list, encrypter: Union[PaillierEncrypt, IpclPaillierEncrypt],
  17. sync_para=True):
  18. """
  19. max_int: max int allowed for packing result
  20. pack_num: number of int to pack, they must be POSITIVE integer
  21. pack_num_range: list of integer, it gives range of every integer to pack
  22. need_cipher_compress: if dont need cipher compress, related parameter will be set to 1
  23. """
  24. self._pack_num = pack_num
  25. assert len(pack_num_range) == self._pack_num, 'list len must equal to pack_num'
  26. self._pack_num_range = pack_num_range
  27. self._pack_num_bit = [i.bit_length() for i in pack_num_range]
  28. self.encrypter = encrypter
  29. max_pos_int, _ = get_homo_encryption_max_int(self.encrypter)
  30. self._max_int = max_pos_int
  31. self._max_bit = self._max_int.bit_length() - 1 # reserve 1 bit, in case overflow
  32. # sometimes max_int is not able to hold all num need to be packed, so we
  33. # use more than one large integer to pack them all
  34. self.bit_assignment = []
  35. tmp_list = []
  36. bit_count = 0
  37. for bit_len in self._pack_num_bit:
  38. if bit_count + bit_len >= self._max_bit:
  39. if bit_count == 0:
  40. raise ValueError('unable to pack this num using in current int capacity')
  41. self.bit_assignment.append(tmp_list)
  42. tmp_list = []
  43. bit_count = 0
  44. bit_count += bit_len
  45. tmp_list.append(bit_len)
  46. if len(tmp_list) != 0:
  47. self.bit_assignment.append(tmp_list)
  48. self._pack_int_needed = len(self.bit_assignment)
  49. # transfer variable
  50. compress_parameter = self.cipher_compress_suggest()
  51. if sync_para:
  52. self.trans_var = CipherCompressorTransferVariable()
  53. self.trans_var.compress_para.remote(compress_parameter, role=consts.HOST, idx=-1)
  54. LOGGER.debug('int packer init done, bit assign is {}, compress para is {}'.format(self.bit_assignment,
  55. compress_parameter))
  56. def cipher_compress_suggest(self):
  57. compressible = self.bit_assignment[-1]
  58. total_bit_count = sum(compressible)
  59. compress_num = self._max_bit // total_bit_count
  60. padding_bit = total_bit_count
  61. return padding_bit, compress_num
  62. def pack_int_list(self, int_list: list):
  63. assert len(int_list) == self._pack_num, 'list length is not equal to pack_num'
  64. start_idx = 0
  65. rs = []
  66. for bit_assign_of_one_int in self.bit_assignment:
  67. to_pack = int_list[start_idx: start_idx + len(bit_assign_of_one_int)]
  68. packing_rs = self._pack_fix_len_int_list(to_pack, bit_assign_of_one_int)
  69. rs.append(packing_rs)
  70. start_idx += len(bit_assign_of_one_int)
  71. return rs
  72. def _pack_fix_len_int_list(self, int_list: list, bit_assign: list):
  73. result = int_list[0]
  74. for i, offset in zip(int_list[1:], bit_assign[1:]):
  75. result = result << offset
  76. result += i
  77. return result
  78. def unpack_an_int(self, integer: int, bit_assign_list: list):
  79. rs_list = []
  80. for bit_assign in reversed(bit_assign_list[1:]):
  81. mask_int = (2**bit_assign) - 1
  82. unpack_int = integer & mask_int
  83. rs_list.append(unpack_int)
  84. integer = integer >> bit_assign
  85. rs_list.append(integer)
  86. return list(reversed(rs_list))
  87. def pack(self, data_table):
  88. packing_data_table = data_table.mapValues(self.pack_int_list)
  89. return packing_data_table
  90. def pack_and_encrypt(self, data_table, post_process_func=cipher_list_to_cipher_tensor):
  91. packing_data_table = self.pack(data_table)
  92. en_packing_data_table = self.encrypter.distribute_raw_encrypt(packing_data_table)
  93. if post_process_func:
  94. en_packing_data_table = en_packing_data_table.mapValues(post_process_func)
  95. return en_packing_data_table
  96. def unpack_result(self, decrypted_result_list: list, post_func=None):
  97. final_rs = []
  98. for l_ in decrypted_result_list:
  99. rs_list = self.unpack_an_int_list(l_, post_func)
  100. final_rs.append(rs_list)
  101. return final_rs
  102. def unpack_an_int_list(self, int_list, post_func=None):
  103. assert len(int_list) == len(self.bit_assignment), 'length of integer list is not equal to bit_assignment'
  104. rs_list = []
  105. for idx, integer in enumerate(int_list):
  106. unpack_list = self.unpack_an_int(integer, self.bit_assignment[idx])
  107. if post_func:
  108. unpack_list = post_func(unpack_list)
  109. rs_list.extend(unpack_list)
  110. return rs_list
  111. def decrypt_cipher_packages(self, content):
  112. if isinstance(content, list):
  113. assert issubclass(type(content[0]), CipherPackage), 'content is not CipherPackages'
  114. decrypt_rs = []
  115. for i in content:
  116. unpack_ = i.unpack(self.encrypter)
  117. decrypt_rs += unpack_
  118. return decrypt_rs
  119. else:
  120. raise ValueError('illegal input type')
  121. def decrypt_cipher_package_and_unpack(self, data_table):
  122. de_func = functools.partial(self.decrypt_cipher_packages)
  123. de_table = data_table.mapValues(de_func)
  124. unpack_table = de_table.mapValues(self.unpack_result)
  125. return unpack_table