sample_weight.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import copy
  18. import numpy as np
  19. from federatedml.model_base import Metric, MetricMeta
  20. from federatedml.model_base import ModelBase
  21. from federatedml.statistic import data_overview
  22. from federatedml.param.sample_weight_param import SampleWeightParam
  23. from federatedml.protobuf.generated.sample_weight_model_meta_pb2 import SampleWeightModelMeta
  24. from federatedml.protobuf.generated.sample_weight_model_param_pb2 import SampleWeightModelParam
  25. from federatedml.statistic.data_overview import get_label_count, check_negative_sample_weight
  26. from federatedml.util import consts, LOGGER
  27. class SampleWeight(ModelBase):
  28. def __init__(self):
  29. super().__init__()
  30. self.model_param = SampleWeightParam()
  31. self.metric_name = "sample_weight"
  32. self.metric_namespace = "train"
  33. self.metric_type = "SAMPLE_WEIGHT"
  34. self.model_meta_name = "SampleWeightModelMeta"
  35. self.model_param_name = "SampleWeightModelParam"
  36. self.weight_mode = None
  37. self.header = None
  38. self.class_weight_dict = None
  39. def _init_model(self, params):
  40. self.model_param = params
  41. self.class_weight = params.class_weight
  42. self.sample_weight_name = params.sample_weight_name
  43. self.normalize = params.normalize
  44. self.need_run = params.need_run
  45. @staticmethod
  46. def get_class_weight(data_instances):
  47. class_weight = get_label_count(data_instances)
  48. n_samples = data_instances.count()
  49. n_classes = len(class_weight.keys())
  50. res_class_weight = {str(k): n_samples / (n_classes * v) for k, v in class_weight.items()}
  51. return res_class_weight
  52. @staticmethod
  53. def replace_weight(data_instance, class_weight, weight_loc=None, weight_base=None):
  54. weighted_data_instance = copy.copy(data_instance)
  55. original_features = weighted_data_instance.features
  56. if weight_loc is not None:
  57. if weight_base is not None:
  58. inst_weight = original_features[weight_loc] / weight_base
  59. else:
  60. inst_weight = original_features[weight_loc]
  61. weighted_data_instance.set_weight(inst_weight)
  62. weighted_data_instance.features = original_features[np.arange(original_features.shape[0]) != weight_loc]
  63. else:
  64. weighted_data_instance.set_weight(class_weight.get(str(data_instance.label), 1))
  65. return weighted_data_instance
  66. @staticmethod
  67. def assign_sample_weight(data_instances, class_weight, weight_loc, normalize):
  68. weight_base = None
  69. if weight_loc is not None and normalize:
  70. def sum_sample_weight(kv_iterator):
  71. sample_weight = 0
  72. for _, inst in kv_iterator:
  73. sample_weight += inst.features[weight_loc]
  74. return sample_weight
  75. weight_sum = data_instances.mapPartitions(sum_sample_weight).reduce(lambda x, y: x + y)
  76. # LOGGER.debug(f"weight_sum is {weight_sum}")
  77. weight_base = weight_sum / data_instances.count()
  78. # LOGGER.debug(f"weight_base is {weight_base}")
  79. return data_instances.mapValues(lambda v: SampleWeight.replace_weight(v, class_weight, weight_loc, weight_base))
  80. @staticmethod
  81. def get_weight_loc(data_instances, sample_weight_name):
  82. weight_loc = None
  83. if sample_weight_name:
  84. try:
  85. weight_loc = data_instances.schema["header"].index(sample_weight_name)
  86. except ValueError:
  87. return
  88. return weight_loc
  89. def transform_weighted_instance(self, data_instances, weight_loc):
  90. if self.class_weight and self.class_weight == 'balanced':
  91. self.class_weight_dict = SampleWeight.get_class_weight(data_instances)
  92. else:
  93. if self.class_weight_dict is None:
  94. self.class_weight_dict = self.class_weight
  95. return SampleWeight.assign_sample_weight(data_instances, self.class_weight_dict, weight_loc, self.normalize)
  96. def callback_info(self):
  97. class_weight = None
  98. classes = None
  99. if self.class_weight_dict:
  100. class_weight = {str(k): v for k, v in self.class_weight_dict.items()}
  101. classes = sorted([str(k) for k in self.class_weight_dict.keys()])
  102. # LOGGER.debug(f"callback class weight is: {class_weight}")
  103. metric_meta = MetricMeta(name='train',
  104. metric_type=self.metric_type,
  105. extra_metas={
  106. "weight_mode": self.weight_mode,
  107. "class_weight": class_weight,
  108. "classes": classes,
  109. "sample_weight_name": self.sample_weight_name
  110. })
  111. self.callback_metric(metric_name=self.metric_name,
  112. metric_namespace=self.metric_namespace,
  113. metric_data=[Metric(self.metric_name, 0)])
  114. self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
  115. metric_name=self.metric_name,
  116. metric_meta=metric_meta)
  117. def export_model(self):
  118. meta_obj = SampleWeightModelMeta(sample_weight_name=self.sample_weight_name,
  119. normalize=self.normalize,
  120. need_run=self.need_run)
  121. param_obj = SampleWeightModelParam(header=self.header,
  122. weight_mode=self.weight_mode,
  123. class_weight=self.class_weight_dict)
  124. result = {
  125. self.model_meta_name: meta_obj,
  126. self.model_param_name: param_obj
  127. }
  128. return result
  129. def load_model(self, model_dict):
  130. param_obj = list(model_dict.get('model').values())[0].get(self.model_param_name)
  131. meta_obj = list(model_dict.get('model').values())[0].get(self.model_meta_name)
  132. self.header = list(param_obj.header)
  133. self.need_run = meta_obj.need_run
  134. self.weight_mode = param_obj.weight_mode
  135. if self.weight_mode == "class weight":
  136. self.class_weight_dict = {k: v for k, v in param_obj.class_weight.items()}
  137. elif self.weight_mode == "sample weight name":
  138. self.sample_weight_name = meta_obj.sample_weight_name
  139. self.normalize = meta_obj.normalize
  140. else:
  141. raise ValueError(f"Unknown weight mode {self.weight_mode} loaded. "
  142. f"Only support 'class weight' and 'sample weight name'")
  143. def transform(self, data_instances):
  144. LOGGER.info(f"Enter Sample Weight Transform")
  145. new_schema = copy.deepcopy(data_instances.schema)
  146. new_schema["sample_weight"] = "weight"
  147. weight_loc = None
  148. if self.weight_mode == "sample weight name":
  149. weight_loc = SampleWeight.get_weight_loc(data_instances, self.sample_weight_name)
  150. if weight_loc is not None:
  151. new_schema["header"].pop(weight_loc)
  152. else:
  153. LOGGER.warning(f"Cannot find weight column of given sample_weight_name '{self.sample_weight_name}'."
  154. f"Original input data returned")
  155. return data_instances
  156. result_instances = self.transform_weighted_instance(data_instances, weight_loc)
  157. result_instances.schema = new_schema
  158. self.callback_info()
  159. if result_instances.mapPartitions(check_negative_sample_weight).reduce(lambda x, y: x or y):
  160. LOGGER.warning(f"Negative weight found in weighted instances.")
  161. return result_instances
  162. def fit(self, data_instances):
  163. if self.sample_weight_name is None and self.class_weight is None:
  164. return data_instances
  165. self.header = data_overview.get_header(data_instances)
  166. if self.class_weight:
  167. self.weight_mode = "class weight"
  168. if self.sample_weight_name and self.class_weight:
  169. LOGGER.warning(f"Both 'sample_weight_name' and 'class_weight' provided. "
  170. f"Only weight from 'sample_weight_name' is used.")
  171. new_schema = copy.deepcopy(data_instances.schema)
  172. new_schema["sample_weight"] = "weight"
  173. weight_loc = None
  174. if self.sample_weight_name:
  175. self.weight_mode = "sample weight name"
  176. weight_loc = SampleWeight.get_weight_loc(data_instances, self.sample_weight_name)
  177. if weight_loc is not None:
  178. new_schema["header"].pop(weight_loc)
  179. else:
  180. raise ValueError(f"Cannot find weight column of given sample_weight_name '{self.sample_weight_name}'.")
  181. result_instances = self.transform_weighted_instance(data_instances, weight_loc)
  182. result_instances.schema = new_schema
  183. self.callback_info()
  184. if result_instances.mapPartitions(check_negative_sample_weight).reduce(lambda x, y: x or y):
  185. LOGGER.warning(f"Negative weight found in weighted instances.")
  186. return result_instances