hetero_lr_gradient_test.py 4.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import unittest
  17. import numpy as np
  18. from fate_arch.session import computing_session as session
  19. from federatedml.feature.instance import Instance
  20. from federatedml.feature.sparse_vector import SparseVector
  21. from federatedml.optim.gradient import hetero_linear_model_gradient
  22. from federatedml.optim.gradient import hetero_lr_gradient_and_loss
  23. from federatedml.secureprotol import PaillierEncrypt
  class TestHeteroLogisticGradient(unittest.TestCase):
      """Tests the gradient computation of ``HeteroGradientBase`` on encrypted data.

      Fixture layout (built in ``setUp``):
        * ten instances, every feature of every instance equal to 1;
        * labels alternating +1 (even rows) / -1 (odd rows);
        * the encrypted ``wx`` of row ``i`` is ``encrypt(i)``.

      The test feeds the encrypted fore-gradient ``0.25 * wx - 0.5 * label`` into
      ``compute_gradient``. It expects every entry (and the intercept entry when
      ``fit_intercept`` is True) to equal 1.125, which is the mean of
      ``fore_gradient_local``.
      """

      # Absolute tolerance for comparing decrypted gradients. The values come out
      # of arithmetic on ciphertexts plus float scaling, so exact float equality
      # is not a safe assumption.
      GRADIENT_ATOL = 1e-6

      def setUp(self):
          self.paillier_encrypt = PaillierEncrypt()
          self.paillier_encrypt.generate_key()
          # NOTE(review): not referenced by any test in this class.
          self.hetero_lr_gradient = hetero_lr_gradient_and_loss.Guest()
          size = 10

          # Encrypted wx values 0..9. include_key=False: presumably the session
          # assigns matching positional keys to every parallelized table, which
          # the join in test_compute_partition_gradient relies on -- confirm.
          self.en_wx = session.parallelize(
              [self.paillier_encrypt.encrypt(i) for i in range(size)],
              partition=48,
              include_key=False,
          )
          # NOTE(review): not referenced by any test in this class.
          self.en_sum_wx_square = session.parallelize(
              [self.paillier_encrypt.encrypt(np.square(i)) for i in range(size)],
              partition=48,
              include_key=False,
          )
          # Plaintext counterparts; not referenced by any test in this class.
          self.wx = np.array([i for i in range(size)])
          self.w = self.wx / np.array([1] * size)

          # Every feature is 1; labels alternate +1, -1, +1, ...
          self.data_inst = session.parallelize(
              [Instance(features=np.array([1] * size), label=pow(-1, i % 2)) for i in range(size)],
              partition=48,
              include_key=False,
          )

          # Expected plaintext fore-gradient per row: 0.25 * i - 0.5 * label_i.
          self.fore_gradient_local = [-0.5, 0.75, 0, 1.25, 0.5, 1.75, 1, 2.25, 1.5, 2.75]
          # Expected gradient: sum(fore_gradient_local) / 10 = 1.125 for each of
          # the ten features, plus one more entry for the intercept.
          self.gradient = [1.125] * size
          self.gradient_fit_intercept = [1.125] * (size + 1)
          # Expected loss value; not referenced by any test in this class.
          self.loss = 4.505647

      def test_compute_partition_gradient(self):
          # Encrypted fore-gradient per row, computed homomorphically.
          fore_gradient = self.en_wx.join(self.data_inst, lambda wx, d: 0.25 * wx - 0.5 * d.label)
          sparse_data = self._make_sparse_data()
          gradient_computer = hetero_linear_model_gradient.HeteroGradientBase()

          for fit_intercept in [True, False]:
              with self.subTest(fit_intercept=fit_intercept):
                  expected = self.gradient_fit_intercept if fit_intercept else self.gradient

                  dense_result = self._decrypt_all(
                      gradient_computer.compute_gradient(self.data_inst, fore_gradient, fit_intercept))
                  self._assert_gradients_close(dense_result, expected)

                  # The sparse representation of the same data must yield the
                  # same gradient as the dense one.
                  sparse_result = self._decrypt_all(
                      gradient_computer.compute_gradient(sparse_data, fore_gradient, fit_intercept))
                  self._assert_gradients_close(sparse_result, dense_result)

      def _decrypt_all(self, encrypted_values):
          """Decrypt every element of ``encrypted_values`` into a plain list."""
          return [self.paillier_encrypt.decrypt(value) for value in encrypted_values]

      def _assert_gradients_close(self, actual, expected):
          """Assert two gradient vectors have the same shape and element-wise close values."""
          np.testing.assert_allclose(
              np.asarray(actual, dtype=float),
              np.asarray(expected, dtype=float),
              rtol=0,
              atol=self.GRADIENT_ATOL,
          )

      def _make_sparse_data(self):
          """Return ``self.data_inst`` with each dense feature vector converted to a SparseVector."""
          def trans_sparse(instance):
              dense_features = instance.features
              # Every position is stored explicitly, so the sparse vector holds
              # exactly the same values as the dense one.
              indices = list(range(len(dense_features)))
              sparse_features = SparseVector(indices=indices, data=dense_features, shape=len(dense_features))
              return Instance(inst_id=None,
                              features=sparse_features,
                              label=instance.label)

          return self.data_inst.mapValues(trans_sparse)
  if __name__ == "__main__":
      session.init("1111")
      try:
          unittest.main()
      finally:
          # unittest.main() exits via SystemExit by default. Without this
          # finally block, session.stop() would never run.
          session.stop()