label_transform.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2021 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import copy
  18. import numpy as np
  19. from federatedml.model_base import Metric, MetricMeta
  20. from federatedml.model_base import ModelBase
  21. from federatedml.param.label_transform_param import LabelTransformParam
  22. from federatedml.protobuf.generated import label_transform_meta_pb2, label_transform_param_pb2
  23. from federatedml.statistic.data_overview import get_label_count, get_predict_result_labels, \
  24. predict_detail_dict_to_str, predict_detail_str_to_dict
  25. from federatedml.util import LOGGER
  26. class LabelTransformer(ModelBase):
  27. def __init__(self):
  28. super().__init__()
  29. self.model_param = LabelTransformParam()
  30. self.metric_name = "label_transform"
  31. self.metric_namespace = "train"
  32. self.metric_type = "LABEL_TRANSFORM"
  33. self.model_param_name = 'LabelTransformParam'
  34. self.model_meta_name = 'LabelTransformMeta'
  35. self.weight_mode = None
  36. self.encoder_key_type = None
  37. self.encoder_value_type = None
  38. self.label_encoder = None
  39. self.label_list = None
  40. def _init_model(self, params):
  41. self.model_param = params
  42. self.label_encoder = params.label_encoder
  43. self.label_list = params.label_list
  44. self.need_run = params.need_run
  45. def update_label_encoder(self, data):
  46. if self.label_encoder is not None:
  47. LOGGER.info(f"label encoder provided")
  48. LOGGER.info("count labels in data.")
  49. data_type = data.schema.get("content_type")
  50. if data_type is None:
  51. label_count = get_label_count(data)
  52. labels = sorted(label_count.keys())
  53. # predict result
  54. else:
  55. labels = sorted(get_predict_result_labels(data))
  56. if self.label_list is not None:
  57. LOGGER.info(f"label list provided")
  58. self.encoder_key_type = {str(v): type(v).__name__ for v in self.label_list}
  59. else:
  60. self.encoder_key_type = {str(v): type(v).__name__ for v in labels}
  61. if len(labels) != len(self.label_encoder):
  62. missing_values = [k for k in labels if str(k) not in self.label_encoder]
  63. LOGGER.warning(f"labels: {missing_values} found in input data "
  64. f"but are not matched in provided label_encoder. "
  65. f"Note that unmatched labels will not be transformed.")
  66. self.label_encoder.update(zip([str(k) for k in missing_values],
  67. missing_values))
  68. self.encoder_key_type.update(zip([str(k) for k in missing_values],
  69. [type(v).__name__ for v in missing_values]))
  70. else:
  71. data_type = data.schema.get("content_type")
  72. if data_type is None:
  73. label_count = get_label_count(data)
  74. labels = sorted(label_count.keys())
  75. # predict result
  76. else:
  77. labels = sorted(get_predict_result_labels(data))
  78. self.label_encoder = dict(zip(labels, range(len(labels))))
  79. if self.encoder_key_type is None:
  80. self.encoder_key_type = {str(k): type(k).__name__ for k in self.label_encoder.keys()}
  81. self.encoder_value_type = {str(k): type(v).__name__ for k, v in self.label_encoder.items()}
  82. self.label_encoder = {load_value_to_type(k,
  83. self.encoder_key_type.get(str(k), None)): v for k,
  84. v in self.label_encoder.items()}
  85. for k, v in self.label_encoder.items():
  86. if v is None:
  87. raise ValueError(f"given encoder key {k} not found in data or provided label list, please check.")
  88. def _get_meta(self):
  89. meta = label_transform_meta_pb2.LabelTransformMeta(
  90. need_run=self.need_run
  91. )
  92. return meta
  93. def _get_param(self):
  94. label_encoder = self.label_encoder
  95. if self.label_encoder is not None:
  96. label_encoder = {str(k): str(v) for k, v in self.label_encoder.items()}
  97. param = label_transform_param_pb2.LabelTransformParam(
  98. label_encoder=label_encoder,
  99. encoder_key_type=self.encoder_key_type,
  100. encoder_value_type=self.encoder_value_type)
  101. return param
  102. def export_model(self):
  103. meta_obj = self._get_meta()
  104. param_obj = self._get_param()
  105. result = {
  106. self.model_meta_name: meta_obj,
  107. self.model_param_name: param_obj
  108. }
  109. self.model_output = result
  110. return result
  111. def load_model(self, model_dict):
  112. meta_obj = list(model_dict.get('model').values())[0].get(self.model_meta_name)
  113. param_obj = list(model_dict.get('model').values())[0].get(self.model_param_name)
  114. self.need_run = meta_obj.need_run
  115. self.encoder_key_type = param_obj.encoder_key_type
  116. self.encoder_value_type = param_obj.encoder_value_type
  117. self.label_encoder = {
  118. load_value_to_type(k, self.encoder_key_type[k]): load_value_to_type(v, self.encoder_value_type[k])
  119. for k, v in param_obj.label_encoder.items()
  120. }
  121. return
  122. def callback_info(self):
  123. metric_meta = MetricMeta(name='train',
  124. metric_type=self.metric_type,
  125. extra_metas={
  126. "label_encoder": self.label_encoder
  127. })
  128. self.callback_metric(metric_name=self.metric_name,
  129. metric_namespace=self.metric_namespace,
  130. metric_data=[Metric(self.metric_name, 0)])
  131. self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
  132. metric_name=self.metric_name,
  133. metric_meta=metric_meta)
  134. @staticmethod
  135. def replace_instance_label(instance, label_encoder):
  136. new_instance = copy.deepcopy(instance)
  137. label_replace_val = label_encoder.get(instance.label)
  138. if label_replace_val is None:
  139. raise ValueError(f"{instance.label} not found in given label encoder")
  140. new_instance.label = label_replace_val
  141. return new_instance
  142. @staticmethod
  143. def replace_predict_label(predict_inst, label_encoder):
  144. transform_predict_inst = copy.deepcopy(predict_inst)
  145. true_label, predict_label, predict_score, predict_detail_str, result_type = transform_predict_inst.features
  146. predict_detail = predict_detail_str_to_dict(predict_detail_str)
  147. true_label_replace_val, predict_label_replace_val = label_encoder.get(
  148. true_label), label_encoder.get(predict_label)
  149. if true_label_replace_val is None:
  150. raise ValueError(f"{true_label_replace_val} not found in given label encoder")
  151. if predict_label_replace_val is None:
  152. raise ValueError(f"{predict_label_replace_val} not found in given label encoder")
  153. label_encoder_detail = {str(k): v for k, v in label_encoder.items()}
  154. predict_detail_dict = {label_encoder_detail[label]: score for label, score in predict_detail.items()}
  155. predict_detail = predict_detail_dict_to_str(predict_detail_dict)
  156. transform_predict_inst.features = [true_label_replace_val, predict_label_replace_val, predict_score,
  157. predict_detail, result_type]
  158. return transform_predict_inst
  159. @staticmethod
  160. def replace_predict_label_cluster(predict_inst, label_encoder):
  161. transform_predict_inst = copy.deepcopy(predict_inst)
  162. true_label, predict_label = transform_predict_inst.features[0], transform_predict_inst.features[1]
  163. true_label, predict_label = label_encoder[true_label], label_encoder[predict_label]
  164. transform_predict_inst.features = [true_label, predict_label]
  165. return transform_predict_inst
  166. @staticmethod
  167. def transform_data_label(data, label_encoder):
  168. data_type = data.schema.get("content_type")
  169. if data_type == "cluster_result":
  170. return data.mapValues(lambda v: LabelTransformer.replace_predict_label_cluster(v, label_encoder))
  171. elif data_type == "predict_result":
  172. predict_detail = data.first()[1].features[3]
  173. if predict_detail == 1 and list(predict_detail.keys())[0] == "label":
  174. LOGGER.info(f"Regression prediction result provided. Original data returned.")
  175. return data
  176. return data.mapValues(lambda v: LabelTransformer.replace_predict_label(v, label_encoder))
  177. elif data_type is None:
  178. return data.mapValues(lambda v: LabelTransformer.replace_instance_label(v, label_encoder))
  179. else:
  180. raise ValueError(f"unknown data type: {data_type} encountered. Label transform aborted.")
  181. def transform(self, data):
  182. LOGGER.info(f"Enter Label Transformer Transform")
  183. if self.label_encoder is None:
  184. raise ValueError(f"Input Label Encoder is None. Label Transform aborted.")
  185. label_encoder = self.label_encoder
  186. data_type = data.schema.get("content_type")
  187. # revert label encoding if predict result
  188. if data_type is not None:
  189. label_encoder = dict(zip(self.label_encoder.values(), self.label_encoder.keys()))
  190. result_data = LabelTransformer.transform_data_label(data, label_encoder)
  191. result_data.schema = data.schema
  192. self.callback_info()
  193. return result_data
  194. def fit(self, data):
  195. LOGGER.info(f"Enter Label Transform Fit")
  196. self.update_label_encoder(data)
  197. result_data = LabelTransformer.transform_data_label(data, self.label_encoder)
  198. result_data.schema = data.schema
  199. self.callback_info()
  200. return result_data
  201. # also used in feature imputation, to be moved to common util
  202. def load_value_to_type(value, value_type):
  203. if value is None:
  204. loaded_value = None
  205. elif value_type in ["int", "int64", "long", "float", "float64", "double"]:
  206. loaded_value = getattr(np, value_type)(value)
  207. elif value_type in ["str", "_str"]:
  208. loaded_value = str(value)
  209. else:
  210. raise ValueError(f"unknown value type: {value_type}")
  211. return loaded_value