sqn_sync.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import numpy as np
  18. from federatedml.util import consts
  19. class SqnSyncBase(object):
  20. def __init__(self):
  21. self.batch_data_index_transfer = None
  22. self.host_forwards_transfer = None
  23. self.forward_hess = None
  24. self.forward_hess_transfer = None
  25. class Guest(SqnSyncBase):
  26. def __init__(self):
  27. super().__init__()
  28. self.guest_hess_vector = None
  29. def register_transfer_variable(self, transfer_variable):
  30. self.batch_data_index_transfer = transfer_variable.sqn_sample_index
  31. self.guest_hess_vector = transfer_variable.guest_hess_vector
  32. self.host_forwards_transfer = transfer_variable.host_sqn_forwards
  33. self.forward_hess_transfer = transfer_variable.forward_hess
  34. def sync_sample_data(self, data_instances, sample_size, random_seed, suffix=tuple()):
  35. n = data_instances.count()
  36. if sample_size >= n:
  37. sample_rate = 1.0
  38. else:
  39. sample_rate = sample_size / n
  40. sampled_data = data_instances.sample(fraction=sample_rate, seed=random_seed)
  41. batch_index = sampled_data.mapValues(lambda x: None)
  42. self.batch_data_index_transfer.remote(obj=batch_index,
  43. role=consts.HOST,
  44. suffix=suffix)
  45. return sampled_data
  46. def get_host_forwards(self, suffix=tuple()):
  47. host_forwards = self.host_forwards_transfer.get(idx=-1,
  48. suffix=suffix)
  49. return host_forwards
  50. def remote_forward_hess(self, forward_hess, suffix=tuple()):
  51. self.forward_hess_transfer.remote(obj=forward_hess,
  52. role=consts.HOST,
  53. suffix=suffix)
  54. def sync_hess_vector(self, hess_vector, suffix):
  55. self.guest_hess_vector.remote(obj=hess_vector,
  56. role=consts.ARBITER,
  57. suffix=suffix)
  58. class Host(SqnSyncBase):
  59. def __init__(self):
  60. super().__init__()
  61. self.host_hess_vector = None
  62. def register_transfer_variable(self, transfer_variable):
  63. self.batch_data_index_transfer = transfer_variable.sqn_sample_index
  64. self.host_forwards_transfer = transfer_variable.host_sqn_forwards
  65. self.host_hess_vector = transfer_variable.host_hess_vector
  66. self.forward_hess_transfer = transfer_variable.forward_hess
  67. def sync_sample_data(self, data_instances, suffix=tuple()):
  68. batch_index = self.batch_data_index_transfer.get(idx=0,
  69. suffix=suffix)
  70. sample_data = data_instances.join(batch_index, lambda x, y: x)
  71. return sample_data
  72. def remote_host_forwards(self, host_forwards, suffix=tuple()):
  73. self.host_forwards_transfer.remote(obj=host_forwards,
  74. role=consts.GUEST,
  75. suffix=suffix)
  76. def get_forward_hess(self, suffix=tuple()):
  77. forward_hess = self.forward_hess_transfer.get(idx=0,
  78. suffix=suffix)
  79. return forward_hess
  80. def sync_hess_vector(self, hess_vector, suffix):
  81. self.host_hess_vector.remote(obj=hess_vector,
  82. role=consts.ARBITER,
  83. suffix=suffix)
  84. class Arbiter(object):
  85. def __init__(self):
  86. super().__init__()
  87. self.guest_hess_vector = None
  88. self.host_hess_vector = None
  89. def register_transfer_variable(self, transfer_variable):
  90. self.guest_hess_vector = transfer_variable.guest_hess_vector
  91. self.host_hess_vector = transfer_variable.host_hess_vector
  92. def sync_hess_vector(self, suffix):
  93. guest_hess_vector = self.guest_hess_vector.get(idx=0,
  94. suffix=suffix)
  95. host_hess_vectors = self.host_hess_vector.get(idx=-1,
  96. suffix=suffix)
  97. host_hess_vectors = [x.reshape(-1) for x in host_hess_vectors]
  98. hess_vectors = np.hstack((h for h in host_hess_vectors))
  99. hess_vectors = np.hstack((hess_vectors, guest_hess_vector))
  100. return hess_vectors