homo_split_points.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import numpy as np
  18. from federatedml.feature.binning.quantile_binning import QuantileBinning
  19. from federatedml.framework.weights import DictWeights
  20. from federatedml.param.feature_binning_param import FeatureBinningParam
  21. from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
  22. from federatedml.util import abnormal_detection
  23. from federatedml.util import consts
  24. class HomoFeatureBinningServer(object):
  25. def __init__(self):
  26. self.aggregator = SecureAggregatorServer(secure_aggregate=True, communicate_match_suffix='homo_feature_binning')
  27. self.suffix = tuple()
  28. def set_suffix(self, suffix):
  29. self.suffix = suffix
  30. def average_run(self, data_instances=None, bin_param: FeatureBinningParam = None, bin_num=10, abnormal_list=None):
  31. agg_split_points = self.aggregator.aggregate_model(suffix=self.suffix)
  32. self.aggregator.broadcast_model(agg_split_points, suffix=self.suffix)
  33. def fit(self, *args, **kwargs):
  34. pass
  35. def query_quantile_points(self, data_instances, quantile_points):
  36. suffix = tuple(list(self.suffix) + [str(quantile_points)])
  37. agg_quantile_points = self.aggregator.aggregate_model(suffix=suffix)
  38. self.aggregator.broadcast_model(agg_quantile_points, suffix=suffix)
  39. class HomoFeatureBinningClient(object):
  40. def __init__(self, bin_method=consts.QUANTILE):
  41. self.aggregator = SecureAggregatorClient(
  42. secure_aggregate=True,
  43. aggregate_type='mean',
  44. communicate_match_suffix='homo_feature_binning')
  45. self.suffix = tuple()
  46. self.bin_method = bin_method
  47. self.bin_obj: QuantileBinning = None
  48. self.bin_param = None
  49. self.abnormal_list = None
  50. def set_suffix(self, suffix):
  51. self.suffix = suffix
  52. def average_run(self, data_instances, bin_num=10, abnormal_list=None):
  53. if self.bin_param is None:
  54. bin_param = FeatureBinningParam(bin_num=bin_num)
  55. self.bin_param = bin_param
  56. else:
  57. bin_param = self.bin_param
  58. if self.bin_method == consts.QUANTILE:
  59. bin_obj = QuantileBinning(params=bin_param, abnormal_list=abnormal_list, allow_duplicate=True)
  60. else:
  61. raise ValueError("Homo Split Point do not accept bin_method: {}".format(self.bin_method))
  62. abnormal_detection.empty_table_detection(data_instances)
  63. abnormal_detection.empty_feature_detection(data_instances)
  64. split_points = bin_obj.fit_split_points(data_instances)
  65. split_points = {k: np.array(v) for k, v in split_points.items()}
  66. split_points_weights = DictWeights(d=split_points)
  67. self.aggregator.send_model(split_points_weights, self.suffix)
  68. dict_split_points = self.aggregator.get_aggregated_model(self.suffix)
  69. split_points = {k: list(v) for k, v in dict_split_points.unboxed.items()}
  70. self.bin_obj = bin_obj
  71. return split_points
  72. def convert_feature_to_bin(self, data_instances, split_points=None):
  73. if self.bin_obj is None:
  74. return None, None, None
  75. return self.bin_obj.convert_feature_to_bin(data_instances, split_points)
  76. def set_bin_param(self, bin_param: FeatureBinningParam):
  77. if self.bin_param is not None:
  78. raise RuntimeError("Bin param has been set and it's immutable")
  79. self.bin_param = bin_param
  80. return self
  81. def set_abnormal_list(self, abnormal_list):
  82. self.abnormal_list = abnormal_list
  83. return self
  84. def fit(self, data_instances):
  85. if self.bin_obj is not None:
  86. return self
  87. if self.bin_param is None:
  88. self.bin_param = FeatureBinningParam()
  89. self.bin_obj = QuantileBinning(params=self.bin_param, abnormal_list=self.abnormal_list,
  90. allow_duplicate=True)
  91. self.bin_obj.fit_split_points(data_instances)
  92. return self
  93. def query_quantile_points(self, data_instances, quantile_points):
  94. if self.bin_obj is None:
  95. self.fit(data_instances)
  96. # bin_col_names = self.bin_obj.bin_inner_param.bin_names
  97. query_result = self.bin_obj.query_quantile_point(quantile_points)
  98. query_points = DictWeights(d=query_result)
  99. suffix = tuple(list(self.suffix) + [str(quantile_points)])
  100. self.aggregator.send_model(query_points, suffix)
  101. query_points = self.aggregator.get_aggregated_model(suffix)
  102. query_points = {k: v for k, v in query_points.unboxed.items()}
  103. return query_points