bucket_binning.py
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. from federatedml.feature.binning.base_binning import BaseBinning
  18. from federatedml.statistic.statics import MultivariateStatisticalSummary
  19. from federatedml.statistic import data_overview
  20. class BucketBinning(BaseBinning):
  21. """
  22. For bucket binning, the length of each bin is the same which is:
  23. L = [max(x) - min(x)] / n
  24. The split points are min(x) + L * k
  25. where k is the index of a bin.
  26. """
  27. def fit_split_points(self, data_instances):
  28. """
  29. Apply the binning method
  30. Parameters
  31. ----------
  32. data_instances : Table
  33. The input data
  34. Returns
  35. -------
  36. split_points : dict.
  37. Each value represent for the split points for a feature. The element in each row represent for
  38. the corresponding split point.
  39. e.g.
  40. split_points = {'x1': [0.1, 0.2, 0.3, 0.4 ...], # The first feature
  41. 'x2': [1, 2, 3, 4, ...], # The second feature
  42. ...] # Other features
  43. """
  44. header = data_overview.get_header(data_instances)
  45. anonymous_header = data_overview.get_anonymous_header(data_instances)
  46. self._default_setting(header, anonymous_header)
  47. # is_sparse = data_overview.is_sparse_data(data_instances)
  48. # if is_sparse:
  49. # raise RuntimeError("Bucket Binning method has not supported sparse data yet.")
  50. # self._init_cols(data_instances)
  51. statistics = MultivariateStatisticalSummary(data_instances,
  52. self.bin_inner_param.bin_indexes,
  53. abnormal_list=self.abnormal_list)
  54. max_dict = statistics.get_max()
  55. min_dict = statistics.get_min()
  56. for col_name, max_value in max_dict.items():
  57. min_value = min_dict.get(col_name)
  58. split_points = []
  59. L = (max_value - min_value) / self.bin_num
  60. for k in range(self.bin_num - 1):
  61. s_p = min_value + (k + 1) * L
  62. split_points.append(s_p)
  63. split_points.append(max_value)
  64. # final_split_points[col_name] = split_point
  65. self.bin_results.put_col_split_points(col_name, split_points)
  66. self.fit_category_features(data_instances)
  67. return self.bin_results.all_split_points