one_hot_encoder.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2019 The FATE Authors. All Rights Reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import copy
  17. import functools
  18. import math
  19. import numpy as np
  20. from federatedml.model_base import ModelBase
  21. from federatedml.param.onehot_encoder_param import OneHotEncoderParam
  22. from federatedml.protobuf.generated import onehot_param_pb2, onehot_meta_pb2
  23. from federatedml.statistic.data_overview import get_header
  24. from federatedml.util import LOGGER
  25. from federatedml.util import abnormal_detection
  26. from federatedml.util import consts
  27. from federatedml.util.io_check import assert_io_num_rows_equal
  28. MODEL_PARAM_NAME = 'OneHotParam'
  29. MODEL_META_NAME = 'OneHotMeta'
  30. MODEL_NAME = 'OneHotEncoder'
  31. class OneHotInnerParam(object):
  32. def __init__(self):
  33. self.col_name_maps = {}
  34. self.header = []
  35. self.transform_indexes = []
  36. self.transform_names = []
  37. self.result_header = []
  38. self.transform_index_set = set()
  39. def set_header(self, header):
  40. self.header = header
  41. for idx, col_name in enumerate(self.header):
  42. self.col_name_maps[col_name] = idx
  43. def set_result_header(self, result_header: list or tuple):
  44. self.result_header = result_header.copy()
  45. def set_transform_all(self):
  46. self.transform_indexes = [i for i in range(len(self.header))]
  47. self.transform_names = self.header
  48. self.transform_index_set = set(self.transform_indexes)
  49. def add_transform_indexes(self, transform_indexes):
  50. if transform_indexes is None:
  51. return
  52. for idx in transform_indexes:
  53. if idx >= len(self.header):
  54. LOGGER.warning("Adding a index that out of header's bound")
  55. continue
  56. if idx not in self.transform_index_set:
  57. self.transform_indexes.append(idx)
  58. self.transform_index_set.add(idx)
  59. self.transform_names.append(self.header[idx])
  60. def add_transform_names(self, transform_names):
  61. if transform_names is None:
  62. return
  63. for col_name in transform_names:
  64. idx = self.col_name_maps.get(col_name)
  65. if idx is None:
  66. LOGGER.warning("Adding a col_name that is not exist in header")
  67. continue
  68. if idx not in self.transform_index_set:
  69. self.transform_indexes.append(idx)
  70. self.transform_index_set.add(idx)
  71. self.transform_names.append(self.header[idx])
  72. class TransferPair(object):
  73. def __init__(self, name):
  74. self.name = name
  75. self._values = set()
  76. self._transformed_headers = {}
  77. self._ordered_header = None
  78. def add_value(self, value):
  79. if value in self._values:
  80. return
  81. self._values.add(value)
  82. if len(self._values) > consts.ONE_HOT_LIMIT:
  83. raise ValueError(f"Input data should not have more than {consts.ONE_HOT_LIMIT} "
  84. f"possible value when doing one-hot encode")
  85. # self._transformed_headers[value] = self.__encode_new_header(value)
  86. # LOGGER.debug(f"transformed_header: {self._transformed_headers}")
  87. @property
  88. def values(self):
  89. if self._ordered_header is None:
  90. return list(self._values)
  91. if len(self._ordered_header) != len(self._values):
  92. raise ValueError("Indicated order header length is not equal to value set,"
  93. f" ordered_header: {self._ordered_header}, values: {self._values}")
  94. return self._ordered_header
  95. def set_ordered_header(self, ordered_header):
  96. self._ordered_header = ordered_header
  97. @property
  98. def transformed_headers(self):
  99. return [self._transformed_headers[x] for x in self.values]
  100. def query_name_by_value(self, value):
  101. return self._transformed_headers.get(value, None)
  102. def encode_new_headers(self):
  103. for value in self._values:
  104. self._transformed_headers[value] = "_".join(map(str, [self.name, value]))
  105. def __encode_new_header(self, value):
  106. return '_'.join([str(x) for x in [self.name, value]])
  107. class OneHotEncoder(ModelBase):
  108. def __init__(self):
  109. super(OneHotEncoder, self).__init__()
  110. self.col_maps = {}
  111. self.schema = {}
  112. self.output_data = None
  113. self.model_param = OneHotEncoderParam()
  114. self.inner_param: OneHotInnerParam = None
  115. def _init_model(self, model_param):
  116. self.model_param = model_param
  117. # self.cols_index = model_param.cols
  118. def _abnormal_detection(self, data_instances):
  119. """
  120. Make sure input data_instances is valid.
  121. """
  122. abnormal_detection.empty_table_detection(data_instances)
  123. abnormal_detection.empty_feature_detection(data_instances)
  124. self.check_schema_content(data_instances.schema)
  125. def fit(self, data_instances):
  126. self._init_params(data_instances)
  127. self._abnormal_detection(data_instances)
  128. f1 = functools.partial(self.record_new_header,
  129. inner_param=self.inner_param)
  130. self.col_maps = data_instances.applyPartitions(f1).reduce(self.merge_col_maps)
  131. LOGGER.debug("Before set_schema in fit, schema is : {}, header: {}".format(self.schema,
  132. self.inner_param.header))
  133. for col_name, pair_obj in self.col_maps.items():
  134. pair_obj.encode_new_headers()
  135. self._transform_schema()
  136. data_instances = self.transform(data_instances)
  137. LOGGER.debug("After transform in fit, schema is : {}, header: {}".format(self.schema,
  138. self.inner_param.header))
  139. return data_instances
  140. @assert_io_num_rows_equal
  141. def transform(self, data_instances):
  142. self._init_params(data_instances)
  143. LOGGER.debug("In OneHot transform, ori_header: {}, transfered_header: {}".format(
  144. self.inner_param.header, self.inner_param.result_header
  145. ))
  146. # one_data = data_instances.first()[1].features
  147. # LOGGER.debug("Before transform, data is : {}".format(one_data))
  148. f = functools.partial(self.transfer_one_instance,
  149. col_maps=self.col_maps,
  150. header=self.inner_param.header,
  151. result_header=self.inner_param.result_header,
  152. result_header_index_mapping=dict(zip(self.inner_param.result_header,
  153. range(len(self.inner_param.result_header)))))
  154. new_data = data_instances.mapValues(f)
  155. self.set_schema(new_data)
  156. self.add_summary('transferred_dimension', len(self.inner_param.result_header))
  157. LOGGER.debug(f"Final summary: {self.summary()}")
  158. # one_data = new_data.first()[1].features
  159. # LOGGER.debug("transfered data is : {}".format(one_data))
  160. return new_data
  161. def _transform_schema(self):
  162. header = self.inner_param.header.copy()
  163. LOGGER.debug("[Result][OneHotEncoder]Before one-hot, "
  164. "data_instances schema is : {}".format(self.inner_param.header))
  165. result_header = []
  166. for col_name in header:
  167. if col_name not in self.col_maps:
  168. result_header.append(col_name)
  169. continue
  170. pair_obj = self.col_maps[col_name]
  171. new_headers = pair_obj.transformed_headers
  172. result_header.extend(new_headers)
  173. self.inner_param.set_result_header(result_header)
  174. LOGGER.debug("[Result][OneHotEncoder]After one-hot, data_instances schema is :"
  175. " {}".format(header))
  176. def _init_params(self, data_instances):
  177. if len(self.schema) == 0:
  178. self.schema = data_instances.schema
  179. if self.inner_param is not None:
  180. return
  181. self.inner_param = OneHotInnerParam()
  182. # self.schema = data_instances.schema
  183. LOGGER.debug("In _init_params, schema is : {}".format(self.schema))
  184. header = get_header(data_instances)
  185. self.add_summary("original_dimension", len(header))
  186. self.inner_param.set_header(header)
  187. if self.model_param.transform_col_indexes == -1:
  188. self.inner_param.set_transform_all()
  189. else:
  190. self.inner_param.add_transform_indexes(self.model_param.transform_col_indexes)
  191. self.inner_param.add_transform_names(self.model_param.transform_col_names)
  192. @staticmethod
  193. def record_new_header(data, inner_param: OneHotInnerParam):
  194. """
  195. Generate a new schema based on data value. Each new value will generate a new header.
  196. Returns
  197. -------
  198. col_maps: a dict in which keys are original header, values are dicts. The dicts in value
  199. e.g.
  200. cols_map = {"x1": {1 : "x1_1"},
  201. ...}
  202. """
  203. col_maps = {}
  204. for col_name in inner_param.transform_names:
  205. col_maps[col_name] = TransferPair(col_name)
  206. # col_idx_name_pairs = list(zip(inner_param.transform_indexes, inner_param.transform_names))
  207. for _, instance in data:
  208. feature = instance.features
  209. for col_idx, col_name in zip(inner_param.transform_indexes, inner_param.transform_names):
  210. pair_obj = col_maps.get(col_name)
  211. feature_value = feature[col_idx]
  212. if not isinstance(feature_value, str):
  213. feature_value = math.ceil(feature_value)
  214. if feature_value != feature[col_idx]:
  215. raise ValueError("Onehot input data support integer or string only")
  216. pair_obj.add_value(feature_value)
  217. return col_maps
  218. @staticmethod
  219. def encode_new_header(col_name, feature_value):
  220. return '_'.join([str(x) for x in [col_name, feature_value]])
  221. @staticmethod
  222. def merge_col_maps(col_map1, col_map2):
  223. if col_map1 is None and col_map2 is None:
  224. return None
  225. if col_map1 is None:
  226. return col_map2
  227. if col_map2 is None:
  228. return col_map1
  229. for col_name, pair_obj in col_map2.items():
  230. if col_name not in col_map1:
  231. col_map1[col_name] = pair_obj
  232. continue
  233. else:
  234. col_1_obj = col_map1[col_name]
  235. for value in pair_obj.values:
  236. col_1_obj.add_value(value)
  237. return col_map1
  238. @staticmethod
  239. def transfer_one_instance(instance, col_maps, header, result_header, result_header_index_mapping):
  240. new_inst = instance.copy(exclusive_attr={"features"})
  241. feature = instance.features
  242. # _transformed_value = {}
  243. new_feature = [0] * len(result_header)
  244. for idx, col_name in enumerate(header):
  245. value = feature[idx]
  246. if col_name in result_header_index_mapping:
  247. result_idx = result_header_index_mapping.get(col_name)
  248. new_feature[result_idx] = value
  249. # _transformed_value[col_name] = value
  250. else:
  251. pair_obj = col_maps.get(col_name, None)
  252. if not pair_obj:
  253. continue
  254. new_col_name = pair_obj.query_name_by_value(value)
  255. if new_col_name is None:
  256. continue
  257. result_idx = result_header_index_mapping.get(new_col_name)
  258. new_feature[result_idx] = 1
  259. # _transformed_value[new_col_name] = 1
  260. feature_array = np.array(new_feature)
  261. new_inst.features = feature_array
  262. return new_inst
  263. def set_schema(self, data_instance):
  264. derived_header = dict()
  265. for col_name, pair_obj in self.col_maps.items():
  266. derived_header[col_name] = pair_obj.transformed_headers
  267. self.schema["anonymous_header"] = self.anonymous_generator.generate_derived_header(
  268. self.schema["header"],
  269. self.schema["anonymous_header"],
  270. derived_header)
  271. self.schema['header'] = self.inner_param.result_header
  272. data_instance.schema = self.schema
  273. def _get_meta(self):
  274. meta_protobuf_obj = onehot_meta_pb2.OneHotMeta(transform_col_names=self.inner_param.transform_names,
  275. header=self.inner_param.header,
  276. need_run=self.need_run)
  277. return meta_protobuf_obj
  278. def _get_param(self):
  279. pb_dict = {}
  280. for col_name, pair_obj in self.col_maps.items():
  281. values = [str(x) for x in pair_obj.values]
  282. value_dict_obj = onehot_param_pb2.ColsMap(values=values,
  283. transformed_headers=pair_obj.transformed_headers)
  284. pb_dict[col_name] = value_dict_obj
  285. result_obj = onehot_param_pb2.OneHotParam(col_map=pb_dict,
  286. result_header=self.inner_param.result_header)
  287. return result_obj
  288. def export_model(self):
  289. if self.model_output is not None:
  290. LOGGER.debug("Model output is : {}".format(self.model_output))
  291. return self.model_output
  292. if self.inner_param is None:
  293. self.inner_param = OneHotInnerParam()
  294. meta_obj = self._get_meta()
  295. param_obj = self._get_param()
  296. result = {
  297. MODEL_META_NAME: meta_obj,
  298. MODEL_PARAM_NAME: param_obj
  299. }
  300. return result
  301. def load_model(self, model_dict):
  302. self._parse_need_run(model_dict, MODEL_META_NAME)
  303. model_param = list(model_dict.get('model').values())[0].get(MODEL_PARAM_NAME)
  304. model_meta = list(model_dict.get('model').values())[0].get(MODEL_META_NAME)
  305. self.model_output = {
  306. MODEL_META_NAME: model_meta,
  307. MODEL_PARAM_NAME: model_param
  308. }
  309. self.inner_param = OneHotInnerParam()
  310. self.inner_param.set_header(list(model_meta.header))
  311. self.inner_param.add_transform_names(list(model_meta.transform_col_names))
  312. col_maps = dict(model_param.col_map)
  313. self.col_maps = {}
  314. for col_name, cols_map_obj in col_maps.items():
  315. if col_name not in self.col_maps:
  316. self.col_maps[col_name] = TransferPair(col_name)
  317. pair_obj = self.col_maps[col_name]
  318. for feature_value in list(cols_map_obj.values):
  319. try:
  320. feature_value = eval(feature_value)
  321. except NameError:
  322. pass
  323. pair_obj.add_value(feature_value)
  324. pair_obj.encode_new_headers()
  325. self.inner_param.set_result_header(list(model_param.result_header))