gh_packing_compressing_test.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import unittest
  17. import functools
  18. import math
  19. from fate_arch.session import computing_session as session
  20. from federatedml.ensemble.basic_algorithms.decision_tree.tree_core.g_h_optim import PackedGHCompressor, GHPacker, fix_point_precision
  21. from federatedml.secureprotol.encrypt import PaillierEncrypt
  22. from federatedml.ensemble.basic_algorithms.decision_tree.tree_core.splitter import SplitInfo
  23. from federatedml.util import consts
  24. import numpy as np
# Seed NumPy's global RNG at import time so the randomly generated
# gradients/hessians and sample selections below are reproducible across runs.
np.random.seed(114514)
  26. def generate_bin_gh(num):
  27. # (-1, 1)
  28. g = np.random.random(num)
  29. h = np.random.random(num)
  30. g = g * 2 - 1
  31. return g, h
  32. def generate_reg_gh(num, lower, upper):
  33. g = np.random.random(num)
  34. h = np.zeros(num) + 2
  35. g = g * (upper - lower) + lower
  36. return g, h
  37. def cmp(a, b):
  38. if a[0] > b[0]:
  39. return 1
  40. else:
  41. return -1
  42. def en_gh_list(g, h, en):
  43. en_g = [en.encrypt(i) for i in g]
  44. en_h = [en.encrypt(i) for i in h]
  45. return en_g, en_h
  46. def truncate(f, n=consts.TREE_DECIMAL_ROUND):
  47. return math.floor(f * 10 ** n) / 10 ** n
  48. def make_random_sum(collected_gh, g, h, en_g_l, en_h_l, max_sample_num):
  49. selected_sample_num = np.random.randint(max_sample_num) + 1 # at least 1 sample
  50. idx = np.random.random(selected_sample_num)
  51. idx = np.unique((idx * max_sample_num).astype(int))
  52. print('randomly select {} samples'.format(len(idx)))
  53. selected_g = g[idx]
  54. selected_h = h[idx]
  55. g_sum = selected_g.sum()
  56. h_sum = selected_h.sum()
  57. g_h_list = sorted(collected_gh, key=functools.cmp_to_key(cmp))
  58. sum_gh = 0
  59. en_g_sum = 0
  60. en_h_sum = 0
  61. for i in idx:
  62. gh = g_h_list[i][1][0]
  63. sum_gh += gh
  64. en_g_sum += en_g_l[i]
  65. en_h_sum += en_h_l[i]
  66. return g_sum, h_sum, sum_gh, en_g_sum, en_h_sum, len(idx)
  67. class TestFeatureHistogram(unittest.TestCase):
  68. @staticmethod
  69. def prepare_testing_data(g, h, en, max_sample_num, sample_id, task_type, g_min=None, g_max=None):
  70. packer = GHPacker(max_sample_num, encrypter=en, sync_para=False, task_type=task_type,
  71. g_min=g_min, g_max=g_max)
  72. en_g_l, en_h_l = en_gh_list(g, h, en)
  73. data_list = [(id_, (g_, h_)) for id_, g_, h_ in zip(sample_id, g, h)]
  74. data_table = session.parallelize(data_list, 4, include_key=True)
  75. en_table = packer.pack_and_encrypt(data_table)
  76. collected_gh = list(en_table.collect())
  77. return packer, en_g_l, en_h_l, en_table, collected_gh
  78. @classmethod
  79. def setUpClass(cls):
  80. session.init("test_gh_packing")
  81. cls.max_sample_num = 1000
  82. cls.test_num = 10
  83. cls.split_info_test_num = 200
  84. key_length = 1024
  85. sample_id = [i for i in range(cls.max_sample_num)]
  86. # classification data
  87. cls.g, cls.h = generate_bin_gh(cls.max_sample_num)
  88. cls.p_en = PaillierEncrypt()
  89. cls.p_en.generate_key(key_length)
  90. cls.p_packer, cls.p_en_g_l, cls.p_en_h_l, cls.p_en_table, cls.p_collected_gh = \
  91. cls.prepare_testing_data(cls.g, cls.h, cls.p_en, cls.max_sample_num, sample_id, consts.CLASSIFICATION)
  92. cls.compressor = PackedGHCompressor(sync_para=False)
  93. cls.compressor.compressor._padding_length, cls.compressor.compressor._capacity = \
  94. cls.p_packer.packer.cipher_compress_suggest()
  95. print('paillier compress para {}'.format(cls.p_packer.packer.cipher_compress_suggest()))
  96. # regression data
  97. cls.g_reg, cls.h_reg = generate_reg_gh(cls.max_sample_num, -1000, 1000)
  98. cls.reg_p_packer, cls.reg_p_en_g_l, cls.reg_p_en_h_l, cls.reg_p_en_table, cls.reg_p_collected_gh = \
  99. cls.prepare_testing_data(cls.g_reg, cls.h_reg, cls.p_en, cls.max_sample_num, sample_id, consts.REGRESSION,
  100. g_min=-1000, g_max=1000)
  101. cls.reg_compressor = PackedGHCompressor(sync_para=False)
  102. cls.reg_compressor.compressor._padding_length, cls.reg_compressor.compressor._capacity = \
  103. cls.reg_p_packer.packer.cipher_compress_suggest()
  104. print('paillier compress para {}'.format(cls.p_packer.packer.cipher_compress_suggest()))
  105. print('initialization done')
  106. def run_gh_accumulate_test(self, test_num, collected_gh, en_g_l, en_h_l, packer, en, g, h, check=True):
  107. print('{} test to run'.format(test_num))
  108. for i in range(test_num):
  109. print('executing test {}'.format(i))
  110. g_sum, h_sum, en_sum, en_g_sum, en_h_sum, sample_num = make_random_sum(collected_gh, g, h,
  111. en_g_l,
  112. en_h_l,
  113. self.max_sample_num)
  114. de_num = en.raw_decrypt(en_sum)
  115. unpack_num = packer.packer.unpack_an_int(de_num, packer.packer.bit_assignment[0])
  116. g_sum_ = unpack_num[0] / fix_point_precision - sample_num * packer.g_offset
  117. h_sum_ = unpack_num[1] / fix_point_precision
  118. g_sum_2 = en.decrypt(en_g_sum)
  119. h_sum_2 = en.decrypt(en_h_sum)
  120. print(g_sum, h_sum)
  121. print(g_sum_2, h_sum_2)
  122. print(g_sum_, h_sum_)
  123. g_sum, h_sum = truncate(g_sum), truncate(h_sum)
  124. g_sum_, h_sum_ = truncate(g_sum_), truncate(h_sum_)
  125. g_sum_2, h_sum_2 = truncate(g_sum_2), truncate(h_sum_2)
  126. print(g_sum, h_sum)
  127. print(g_sum_2, h_sum_2)
  128. print(g_sum_, h_sum_)
  129. if check:
  130. # make sure packing result close to plaintext sum
  131. self.assertTrue(g_sum_ == g_sum)
  132. self.assertTrue(h_sum_ == h_sum)
  133. print('passed')
  134. def test_pack_gh_accumulate(self):
  135. # test the correctness of gh packing(in comparision to plaintext)
  136. # Paillier
  137. self.run_gh_accumulate_test(self.test_num, self.p_collected_gh, self.p_en_g_l, self.p_en_h_l, self.p_packer,
  138. self.p_en, self.g, self.h)
  139. print('*' * 30)
  140. print('test paillier done')
  141. print('*' * 30)
  142. def test_split_info_cipher_compress(self):
  143. # test the correctness of cipher compressing
  144. print('testing binary')
  145. collected_gh = self.p_collected_gh
  146. en_g_l = self.p_en_g_l
  147. en_h_l = self.p_en_h_l
  148. packer = self.p_packer
  149. en = self.p_en
  150. sp_list = []
  151. g_sum_list, h_sum_list = [], []
  152. pack_en_list = []
  153. for i in range(self.split_info_test_num):
  154. g_sum, h_sum, en_sum, en_g_sum, en_h_sum, sample_num = make_random_sum(collected_gh, self.g, self.h,
  155. en_g_l,
  156. en_h_l,
  157. self.max_sample_num)
  158. sp = SplitInfo(sum_grad=en_sum, sum_hess=0, sample_count=sample_num)
  159. sp_list.append(sp)
  160. g_sum_list.append(g_sum)
  161. h_sum_list.append(h_sum)
  162. pack_en_list.append(en_sum)
  163. print('generating split-info done')
  164. packages = self.compressor.compress_split_info(sp_list[:-1], sp_list[-1])
  165. print('package length is {}'.format(len(packages)))
  166. unpack_rs = packer.decompress_and_unpack(packages)
  167. case_id = 0
  168. for s, g, h, en_gh in zip(unpack_rs, g_sum_list, h_sum_list, pack_en_list):
  169. print('*' * 10)
  170. print(case_id)
  171. case_id += 1
  172. de_num = en.raw_decrypt(en_gh)
  173. unpack_num = packer.packer.unpack_an_int(de_num, packer.packer.bit_assignment[0])
  174. g_sum_ = unpack_num[0] / fix_point_precision - s.sample_count * packer.g_offset
  175. h_sum_ = unpack_num[1] / fix_point_precision
  176. print(s.sample_count)
  177. print(s.sum_grad, g_sum_, g)
  178. print(s.sum_hess, h_sum_, h)
  179. # make sure cipher compress is correct
  180. self.assertTrue(truncate(s.sum_grad) == truncate(g_sum_))
  181. self.assertTrue(truncate(s.sum_hess) == truncate(h_sum_))
  182. print('check passed')
  183. def test_regression_cipher_compress(self):
  184. # test the correctness of cipher compressing
  185. print('testing regression')
  186. collected_gh = self.reg_p_collected_gh
  187. en_g_l = self.reg_p_en_g_l
  188. en_h_l = self.reg_p_en_h_l
  189. packer = self.reg_p_packer
  190. en = self.p_en
  191. sp_list = []
  192. g_sum_list, h_sum_list = [], []
  193. pack_en_list = []
  194. for i in range(self.split_info_test_num):
  195. g_sum, h_sum, en_sum, en_g_sum, en_h_sum, sample_num = make_random_sum(collected_gh, self.g_reg, self.h_reg,
  196. en_g_l,
  197. en_h_l,
  198. self.max_sample_num)
  199. sp = SplitInfo(sum_grad=en_sum, sum_hess=0, sample_count=sample_num)
  200. sp_list.append(sp)
  201. g_sum_list.append(g_sum)
  202. h_sum_list.append(h_sum)
  203. pack_en_list.append(en_sum)
  204. print('generating split-info done')
  205. packages = self.reg_compressor.compress_split_info(sp_list[:-1], sp_list[-1])
  206. print('package length is {}'.format(len(packages)))
  207. unpack_rs = packer.decompress_and_unpack(packages)
  208. case_id = 0
  209. for s, g, h, en_gh in zip(unpack_rs, g_sum_list, h_sum_list, pack_en_list):
  210. print('*' * 10)
  211. print(case_id)
  212. case_id += 1
  213. de_num = en.raw_decrypt(en_gh) # make sure packing result close to plaintext sum
  214. unpack_num = packer.packer.unpack_an_int(de_num, packer.packer.bit_assignment[0])
  215. g_sum_ = unpack_num[0] / fix_point_precision - s.sample_count * packer.g_offset
  216. h_sum_ = unpack_num[1] / fix_point_precision
  217. print(s.sample_count)
  218. print(s.sum_grad, g_sum_, g)
  219. print(s.sum_hess, h_sum_, h)
  220. # make sure cipher compress is correct
  221. self.assertTrue(truncate(s.sum_grad) == truncate(g_sum_))
  222. self.assertTrue(truncate(s.sum_hess) == truncate(h_sum_))
  223. print('check passed')
  224. def test_regression_gh_packing(self):
  225. # Paillier
  226. self.run_gh_accumulate_test(
  227. self.test_num,
  228. self.reg_p_collected_gh,
  229. self.reg_p_en_g_l,
  230. self.reg_p_en_h_l,
  231. self.reg_p_packer,
  232. self.p_en,
  233. self.g_reg,
  234. self.h_reg,
  235. check=False) # float error in regression is not controllable
  236. @classmethod
  237. def tearDownClass(self):
  238. session.stop()
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()