binning_adapter.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import numpy as np
  18. import operator
  19. from federatedml.feature.feature_selection.model_adapter import isometric_model
  20. from federatedml.feature.feature_selection.model_adapter.adapter_base import BaseAdapter
  21. from federatedml.util import LOGGER
  22. from federatedml.util import consts
  class BinningAdapter(BaseAdapter):
      """Adapt a binning model's IV results into an ``IsometricModel``.

      Every local binning result is turned into one ``"iv"`` metric value that
      pairs the local columns' IV values with the IV values reported by the
      corresponding host results.
      """

      def _load_one_class(self, local_result, remote_results):
          """Build a ``SingleMetricInfo`` holding IV values for one class.

          Args:
              local_result: Local binning result. Its ``binning_result`` is
                  converted with ``dict()`` into column name -> object exposing
                  an ``iv`` attribute.
              remote_results: Sequence of host binning results. Each exposes a
                  ``party_id`` (cast to ``int``) and a ``binning_result`` shaped
                  like the local one.

          Returns:
              isometric_model.SingleMetricInfo with the local IVs as a numpy
              array (ordered by column name), the matching local column names,
              and per-host party ids, IV arrays and column-name lists.
          """
          # Sort local columns by name so the values and names are emitted in a
          # deterministic order and stay index-aligned.
          local_items = sorted(dict(local_result.binning_result).items(),
                               key=operator.itemgetter(0))
          col_names = [name for name, _ in local_items]
          values = [res.iv for _, res in local_items]

          host_party_ids = [int(host.party_id) for host in remote_results]
          host_values = []
          host_col_names = []
          for host in remote_results:
              # Host columns keep the mapping's own iteration order (they are
              # not sorted like the local ones); names and values stay aligned.
              host_items = list(dict(host.binning_result).items())
              host_col_names.append([name for name, _ in host_items])
              host_values.append(np.array([res.iv for _, res in host_items]))

          LOGGER.debug(f"host_party_ids: {host_party_ids}")

          return isometric_model.SingleMetricInfo(
              values=np.array(values),
              col_names=col_names,
              host_party_ids=host_party_ids,
              host_values=host_values,
              host_col_names=host_col_names
          )

      def convert(self, model_meta, model_param):
          """Convert binning model parameters into an ``IsometricModel`` of IVs.

          Args:
              model_meta: Not read by this adapter (presumably kept to satisfy
                  the ``BaseAdapter.convert`` signature).
              model_param: Binning model parameters. Its ``multi_class_result``
                  is read for ``has_host_result``, ``labels``, ``results`` and
                  ``host_results``.

          Returns:
              isometric_model.IsometricModel with one ``"iv"`` metric value added
              per local result.
          """
          multi_class_result = model_param.multi_class_result
          has_remote_result = multi_class_result.has_host_result
          label_counts = len(list(multi_class_result.labels))
          local_results = list(multi_class_result.results)
          host_results = list(multi_class_result.host_results)

          result = isometric_model.IsometricModel()
          for idx, local_result in enumerate(local_results):
              if label_counts == 2:
                  # Exactly two labels: every local result is paired with all
                  # host results.
                  remote_results = host_results
              elif has_remote_result:
                  # NOTE(review): assumes host_results stores one entry per label
                  # for each host, in label order, so entries whose index is
                  # congruent to idx (mod label_counts) belong to this class —
                  # confirm against the binning model writer.
                  remote_results = [host_result
                                    for i, host_result in enumerate(host_results)
                                    if i % label_counts == idx]
              else:
                  remote_results = []
              result.add_metric_value(
                  metric_name="iv",
                  metric_info=self._load_one_class(local_result, remote_results))
          return result