compressor.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. import math
  2. from abc import ABC
  3. from abc import abstractmethod
  4. from federatedml.util import LOGGER
  5. from federatedml.secureprotol import PaillierEncrypt, IpclPaillierEncrypt
  6. from federatedml.transfer_variable.transfer_class.cipher_compressor_transfer_variable \
  7. import CipherCompressorTransferVariable
  8. def get_homo_encryption_max_int(encrypter):
  9. if isinstance(encrypter, (PaillierEncrypt, IpclPaillierEncrypt)):
  10. max_pos_int = encrypter.public_key.max_int
  11. min_neg_int = -max_pos_int
  12. else:
  13. raise ValueError('unknown encryption type')
  14. return max_pos_int, min_neg_int
  15. def cipher_compress_advisor(encrypter, plaintext_bit_len):
  16. max_pos_int, min_neg_int = get_homo_encryption_max_int(encrypter)
  17. max_bit_len = max_pos_int.bit_length()
  18. capacity = max_bit_len // plaintext_bit_len
  19. return capacity
  20. class CipherPackage(ABC):
  21. @abstractmethod
  22. def add(self, obj):
  23. pass
  24. @abstractmethod
  25. def unpack(self, decrypter):
  26. pass
  27. @abstractmethod
  28. def has_space(self):
  29. pass
  30. class PackingCipherTensor(object):
  31. """
  32. A naive realization of cipher tensor
  33. """
  34. def __init__(self, ciphers):
  35. if isinstance(ciphers, list):
  36. if len(ciphers) == 1:
  37. self.ciphers = ciphers[0]
  38. else:
  39. self.ciphers = ciphers
  40. self.dim = len(ciphers)
  41. else:
  42. self.ciphers = ciphers
  43. self.dim = 1
  44. def __add__(self, other):
  45. new_cipher_list = []
  46. if isinstance(other, PackingCipherTensor):
  47. assert self.dim == other.dim
  48. if self.dim == 1:
  49. return PackingCipherTensor(self.ciphers + other.ciphers)
  50. for c1, c2 in zip(self.ciphers, other.ciphers):
  51. new_cipher_list.append(c1 + c2)
  52. return PackingCipherTensor(ciphers=new_cipher_list)
  53. else:
  54. # scalar / single en num
  55. if self.dim == 1:
  56. return PackingCipherTensor(self.ciphers + other)
  57. for c in self.ciphers:
  58. new_cipher_list.append(c + other)
  59. return PackingCipherTensor(ciphers=new_cipher_list)
  60. def __radd__(self, other):
  61. return self.__add__(other)
  62. def __sub__(self, other):
  63. return self + other * -1
  64. def __rsub__(self, other):
  65. return other + (self * -1)
  66. def __mul__(self, other):
  67. if self.dim == 1:
  68. return PackingCipherTensor(self.ciphers * other)
  69. new_cipher_list = []
  70. for c in self.ciphers:
  71. new_cipher_list.append(c * other)
  72. return PackingCipherTensor(new_cipher_list)
  73. def __rmul__(self, other):
  74. return self.__mul__(other)
  75. def __truediv__(self, other):
  76. return self.__mul__(1 / other)
  77. def __repr__(self):
  78. return "[" + self.ciphers.__repr__() + "], dim {}".format(self.dim)
  79. class NormalCipherPackage(CipherPackage):
  80. def __init__(self, padding_length, max_capacity):
  81. self._padding_num = 2 ** padding_length
  82. self.max_capacity = max_capacity
  83. self._cipher_text = None
  84. self._capacity_left = max_capacity
  85. self._has_space = True
  86. def add(self, cipher_text):
  87. if self._capacity_left == 0:
  88. raise ValueError('cipher number exceeds package max capacity')
  89. if self._cipher_text is None:
  90. self._cipher_text = cipher_text
  91. else:
  92. self._cipher_text = self._cipher_text * self._padding_num
  93. self._cipher_text = self._cipher_text + cipher_text
  94. self._capacity_left -= 1
  95. if self._capacity_left == 0:
  96. self._has_space = False
  97. def unpack(self, decrypter):
  98. if isinstance(decrypter, (PaillierEncrypt, IpclPaillierEncrypt)):
  99. compressed_plain_text = decrypter.raw_decrypt(self._cipher_text)
  100. else:
  101. raise ValueError('unknown decrypter: {}'.format(type(decrypter)))
  102. if self.cur_cipher_contained() == 1:
  103. return [compressed_plain_text]
  104. unpack_result = []
  105. bit_len = (self._padding_num - 1).bit_length()
  106. for i in range(self.cur_cipher_contained()):
  107. num = (compressed_plain_text & (self._padding_num - 1))
  108. compressed_plain_text = compressed_plain_text >> bit_len
  109. unpack_result.insert(0, num)
  110. return unpack_result
  111. def has_space(self):
  112. return self._has_space
  113. def cur_cipher_contained(self):
  114. return self.max_capacity - self._capacity_left
  115. def retrieve(self):
  116. return self._cipher_text
  117. class PackingCipherTensorPackage(CipherPackage):
  118. """
  119. A naive realization of compressible tensor(only compress last dimension because previous ciphers have
  120. no space for compressing)
  121. """
  122. def __init__(self, padding_length, max_capcity):
  123. self.cached_list = []
  124. self.compressed_cipher = []
  125. self.compressed_dim = -1
  126. self.not_compress_len = None
  127. self.normal_package = NormalCipherPackage(padding_length, max_capcity)
  128. def add(self, obj: PackingCipherTensor):
  129. if self.normal_package.has_space():
  130. if obj.dim == 1:
  131. self.normal_package.add(obj.ciphers)
  132. else:
  133. self.cached_list.extend(obj.ciphers[:-1])
  134. self.not_compress_len = len(obj.ciphers[:-1])
  135. self.normal_package.add(obj.ciphers[-1])
  136. else:
  137. raise ValueError('have no space for compressing')
  138. def unpack(self, decrypter):
  139. compressed_part = self.normal_package.unpack(decrypter)
  140. de_rs = []
  141. if len(self.cached_list) != 0:
  142. de_rs = decrypter.recursive_raw_decrypt(self.cached_list)
  143. if len(de_rs) == 0:
  144. return [[i] for i in compressed_part]
  145. else:
  146. rs = []
  147. idx_0, idx_1 = 0, 0
  148. while idx_0 < len(self.cached_list):
  149. rs.append(de_rs[idx_0: idx_0 + self.not_compress_len] + [compressed_part[idx_1]])
  150. idx_0 += self.not_compress_len
  151. idx_1 += 1
  152. return rs
  153. def has_space(self):
  154. return self.normal_package.has_space()
  155. class CipherCompressorHost(object):
  156. def __init__(self, package_class=PackingCipherTensorPackage, sync_para=True):
  157. """
  158. Parameters
  159. ----------
  160. package_class type of compressed packages
  161. """
  162. self._package_class = package_class
  163. self._padding_length, self._capacity = None, None
  164. if sync_para:
  165. self.transfer_var = CipherCompressorTransferVariable()
  166. # received from host
  167. self._padding_length, self._capacity = self.transfer_var.compress_para.get(idx=0)
  168. LOGGER.debug("received parameter from guest is {} {}".format(self._padding_length, self._capacity))
  169. def compress(self, encrypted_obj_list):
  170. rs = []
  171. encrypted_obj_list = list(encrypted_obj_list)
  172. cur_package = self._package_class(self._padding_length, self._capacity)
  173. for c in encrypted_obj_list:
  174. if not cur_package.has_space():
  175. rs.append(cur_package)
  176. cur_package = self._package_class(self._padding_length, self._capacity)
  177. cur_package.add(c)
  178. rs.append(cur_package)
  179. return rs
  180. def compress_dtable(self, table):
  181. rs = table.mapValues(self.compress)
  182. return rs
  183. if __name__ == '__main__':
  184. a = PackingCipherTensor([1, 2, 3, 4])
  185. b = PackingCipherTensor([2, 3, 4, 5])
  186. c = PackingCipherTensor(124)
  187. d = PackingCipherTensor([114514])