base_feature_selection.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import copy
  18. import functools
  19. import random
  20. from federatedml.feature.feature_selection import filter_factory
  21. from federatedml.feature.feature_selection.model_adapter.adapter_factory import adapter_factory
  22. from federatedml.feature.feature_selection.selection_properties import SelectionProperties, CompletedSelectionResults
  23. from federatedml.model_base import ModelBase
  24. from federatedml.param.feature_selection_param import FeatureSelectionParam
  25. from federatedml.protobuf.generated import feature_selection_param_pb2, feature_selection_meta_pb2
  26. from federatedml.statistic.data_overview import get_header, \
  27. get_anonymous_header, look_up_names_from_header, header_alignment
  28. from federatedml.transfer_variable.transfer_class.hetero_feature_selection_transfer_variable import \
  29. HeteroFeatureSelectionTransferVariable
  30. from federatedml.util import LOGGER
  31. from federatedml.util import abnormal_detection
  32. from federatedml.util import consts
  33. from federatedml.util.io_check import assert_io_num_rows_equal
  34. from federatedml.util.schema_check import assert_schema_consistent
# Keys under which export_model() stores the param / meta protobuf objects, and
# under which _load_selection_model() looks them up again in a loaded model.
MODEL_PARAM_NAME = 'FeatureSelectionParam'
MODEL_META_NAME = 'FeatureSelectionMeta'
# Not referenced in the visible part of this module.
# NOTE(review): presumably the model name used by callers/subclasses - confirm.
MODEL_NAME = 'HeteroFeatureSelection'
  38. class BaseHeteroFeatureSelection(ModelBase):
  39. def __init__(self):
  40. super(BaseHeteroFeatureSelection, self).__init__()
  41. self.transfer_variable = HeteroFeatureSelectionTransferVariable()
  42. self.curt_select_properties = SelectionProperties()
  43. self.completed_selection_result = CompletedSelectionResults()
  44. self.loaded_local_select_properties = dict()
  45. self.loaded_host_filter_results = dict()
  46. self.schema = None
  47. self.header = None
  48. self.anonymous_header = None
  49. self.party_name = 'Base'
  50. # Possible previous model
  51. self.binning_model = None
  52. self.static_obj = None
  53. self.model_param = FeatureSelectionParam()
  54. # self.meta_dicts = {}
  55. self.meta_list = []
  56. self.isometric_models = {}
  57. def _init_model(self, params):
  58. self.model_param = params
  59. # self.cols_index = params.select_cols
  60. self.filter_methods = params.filter_methods
  61. # self.local_only = params.local_only
    def _init_select_params(self, data_instances):
        """Prepare header and selection state for ``data_instances``.

        Two paths exist:

        * A model was loaded before (``self.header`` is already set). Nothing
          is re-initialised from the data. Only when the loaded model carried
          no anonymous header is the anonymous header taken from the data and
          the per-filter results recorded at load time are registered in
          ``completed_selection_result`` (host and guest exchange the
          old-format -> current anonymous name mapping for that).
        * No model was loaded. Header, anonymous header and the columns to
          select are initialised from the data and ``self.model_param``.

        :param data_instances: input table; must expose ``schema`` and be
            accepted by ``get_header`` / ``get_anonymous_header``.
        """
        if self.schema is None:
            self.schema = data_instances.schema
        if self.header is not None:
            # load current data anonymous header for prediction with model of version < 1.9.0
            # if len(self.completed_selection_result.anonymous_header) == 0:
            if self.anonymous_header is None:
                data_anonymous_header = get_anonymous_header(data_instances)
                # LOGGER.info(f"data_anonymous_header: {data_anonymous_header}")
                self.anonymous_header = data_anonymous_header
                self.completed_selection_result.set_anonymous_header(data_anonymous_header)
                if self.role == consts.HOST:
                    # Host: tell the guest how old-format anonymous names map
                    # onto the anonymous names of the current data.
                    anonymous_header_in_old_format = self.anonymous_generator. \
                        generated_compatible_anonymous_header_with_old_version(data_anonymous_header)
                    anonymous_dict = dict(zip(anonymous_header_in_old_format, data_anonymous_header))
                    self.transfer_variable.host_anonymous_header_dict.remote(anonymous_dict,
                                                                             role=consts.GUEST,
                                                                             idx=0)
                    # Host only registers its own (local) filter results.
                    for filter_name, select_properties in self.loaded_local_select_properties.items():
                        self.completed_selection_result.add_filter_results(filter_name, select_properties)
                else:
                    # Other roles: fetch the mapping dicts (idx=-1, presumably
                    # one per host - confirm) and rebuild each host's
                    # selection properties under the new anonymous names.
                    host_anonymous_dict_list = self.transfer_variable.host_anonymous_header_dict.get(idx=-1)
                    for filter_name, cur_select_properties in self.loaded_local_select_properties.items():
                        cur_host_select_properties_list = []
                        host_feature_values_obj_list, host_left_cols_obj_list = self.loaded_host_filter_results[
                            filter_name]
                        for i, host_left_cols_obj in enumerate(host_left_cols_obj_list):
                            cur_host_select_properties = SelectionProperties()
                            # Keys of the received dict are the old-format names.
                            old_host_header = list(host_anonymous_dict_list[i].keys())
                            host_feature_values = host_feature_values_obj_list[i].feature_values
                            cur_host_select_properties.load_properties_with_new_header(old_host_header,
                                                                                       host_feature_values,
                                                                                       host_left_cols_obj,
                                                                                       host_anonymous_dict_list[i])
                            cur_host_select_properties_list.append(cur_host_select_properties)
                        self.completed_selection_result.add_filter_results(filter_name,
                                                                           cur_select_properties,
                                                                           cur_host_select_properties_list)
            # A loaded model defines header and selection; do not re-init from data.
            return
        self.schema = data_instances.schema
        header = get_header(data_instances)
        anonymous_header = get_anonymous_header(data_instances)
        self.header = header
        self.anonymous_header = anonymous_header
        self.curt_select_properties.set_header(header)
        # use anonymous header of input data
        self.curt_select_properties.set_anonymous_header(anonymous_header)
        # Before any filter has run, every column counts as "left".
        self.curt_select_properties.set_last_left_col_indexes([x for x in range(len(header))])
        # -1 is the marker for "select all columns".
        if self.model_param.select_col_indexes == -1:
            self.curt_select_properties.set_select_all_cols()
        else:
            self.curt_select_properties.add_select_col_indexes(self.model_param.select_col_indexes)
        if self.model_param.use_anonymous:
            # select_names are passed through look_up_names_from_header with
            # both headers - presumably translating anonymous names to real
            # column names (helper not visible here; confirm).
            select_names = look_up_names_from_header(self.model_param.select_names, anonymous_header, header)
            # LOGGER.debug(f"use_anonymous is true, select names: {select_names}")
        else:
            select_names = self.model_param.select_names
        self.curt_select_properties.add_select_col_names(select_names)
        # Mirror the initial state into the accumulated result object.
        self.completed_selection_result.set_header(header)
        self.completed_selection_result.set_anonymous_header(anonymous_header)
        self.completed_selection_result.set_select_col_names(self.curt_select_properties.select_col_names)
        self.completed_selection_result.set_all_left_col_indexes(self.curt_select_properties.all_left_col_indexes)
  124. def _get_meta(self):
  125. meta_dicts = {'filter_methods': self.filter_methods,
  126. 'cols': self.completed_selection_result.get_select_col_names(),
  127. 'need_run': self.need_run,
  128. "filter_metas": self.meta_list}
  129. meta_protobuf_obj = feature_selection_meta_pb2.FeatureSelectionMeta(**meta_dicts)
  130. return meta_protobuf_obj
  131. def _get_param(self):
  132. # LOGGER.debug("curt_select_properties.left_col_name: {}, completed_selection_result: {}".format(
  133. # self.curt_select_properties.left_col_names, self.completed_selection_result.all_left_col_names
  134. # ))
  135. # LOGGER.debug("Length of left cols: {}".format(len(self.completed_selection_result.all_left_col_names)))
  136. # left_cols = {x: True for x in self.curt_select_properties.left_col_names}
  137. left_cols = {x: True for x in self.completed_selection_result.all_left_col_names}
  138. final_left_cols = feature_selection_param_pb2.LeftCols(
  139. original_cols=self.completed_selection_result.get_select_col_names(),
  140. left_cols=left_cols
  141. )
  142. host_col_names = []
  143. if self.role == consts.GUEST:
  144. for host_id, this_host_name in enumerate(self.completed_selection_result.get_host_sorted_col_names()):
  145. party_id = self.component_properties.host_party_idlist[host_id]
  146. # LOGGER.debug("In _get_param, this_host_name: {}, party_id: {}".format(this_host_name, party_id))
  147. host_col_names.append(feature_selection_param_pb2.HostColNames(col_names=this_host_name,
  148. party_id=str(party_id)))
  149. else:
  150. party_id = self.component_properties.local_partyid
  151. # if self.anonymous_header:
  152. # anonymous_names = self.anonymous_header
  153. """else:
  154. anonymous_names = [anonymous_generator.generate_anonymous(fid, model=self)
  155. for fid in range(len(self.header))]
  156. """
  157. host_col_names.append(feature_selection_param_pb2.HostColNames(col_names=self.anonymous_header,
  158. party_id=str(party_id)))
  159. col_name_to_anonym_dict = None
  160. if self.header and self.anonymous_header:
  161. col_name_to_anonym_dict = dict(zip(self.header, self.anonymous_header))
  162. result_obj = feature_selection_param_pb2.FeatureSelectionParam(
  163. results=self.completed_selection_result.filter_results,
  164. final_left_cols=final_left_cols,
  165. col_names=self.completed_selection_result.get_sorted_col_names(),
  166. host_col_names=host_col_names,
  167. header=self.curt_select_properties.header,
  168. col_name_to_anonym_dict=col_name_to_anonym_dict
  169. )
  170. return result_obj
  171. def save_data(self):
  172. return self.data_output
  173. def export_model(self):
  174. # LOGGER.debug("Model output is : {}".format(self.model_output))
  175. """
  176. if self.model_output is not None:
  177. LOGGER.debug("model output already exists, return directly")
  178. return self.model_output
  179. """
  180. meta_obj = self._get_meta()
  181. param_obj = self._get_param()
  182. result = {
  183. MODEL_META_NAME: meta_obj,
  184. MODEL_PARAM_NAME: param_obj
  185. }
  186. self.model_output = result
  187. return result
  188. def _load_selection_model(self, model_dict):
  189. LOGGER.debug("Feature selection need run: {}".format(self.need_run))
  190. if not self.need_run:
  191. return
  192. model_param = list(model_dict.get('model').values())[0].get(MODEL_PARAM_NAME)
  193. model_meta = list(model_dict.get('model').values())[0].get(MODEL_META_NAME)
  194. self.model_output = {
  195. MODEL_META_NAME: model_meta,
  196. MODEL_PARAM_NAME: model_param
  197. }
  198. header = list(model_param.header)
  199. # LOGGER.info(f"col_name_to_anonym_dict: {model_param.col_name_to_anonym_dict}")
  200. self.header = header
  201. self.curt_select_properties.set_header(header)
  202. self.completed_selection_result.set_header(header)
  203. self.curt_select_properties.set_last_left_col_indexes([x for x in range(len(header))])
  204. self.curt_select_properties.add_select_col_names(header)
  205. # for model ver >= 1.9.0
  206. if model_param.col_name_to_anonym_dict:
  207. col_name_to_anonym_dict = dict(model_param.col_name_to_anonym_dict)
  208. self.anonymous_header = [col_name_to_anonym_dict[x] for x in header]
  209. self.completed_selection_result.set_anonymous_header(self.anonymous_header)
  210. host_col_names_list = model_param.host_col_names
  211. for result in model_param.results:
  212. cur_select_properties = copy.deepcopy(self.curt_select_properties)
  213. feature_values, left_cols_obj = dict(result.feature_values), result.left_cols
  214. cur_select_properties.load_properties(header, feature_values, left_cols_obj)
  215. cur_host_select_properties_list = []
  216. host_feature_values_obj_list = list(result.host_feature_values)
  217. host_left_cols_obj_list = list(result.host_left_cols)
  218. for i, host_left_cols_obj in enumerate(host_left_cols_obj_list):
  219. cur_host_select_properties = SelectionProperties()
  220. host_col_names_obj = host_col_names_list[i]
  221. host_header = list(host_col_names_obj.col_names)
  222. host_feature_values = host_feature_values_obj_list[i].feature_values
  223. cur_host_select_properties.load_properties(host_header, host_feature_values, host_left_cols_obj)
  224. cur_host_select_properties_list.append(cur_host_select_properties)
  225. self.completed_selection_result.add_filter_results(result.filter_name,
  226. cur_select_properties,
  227. cur_host_select_properties_list)
  228. # for model ver 1.8.x
  229. else:
  230. LOGGER.warning(f"Anonymous column name dictionary not found in given model."
  231. f"Will infer host(s)' anonymous names.")
  232. """
  233. self.loaded_host_col_names_list = [list(host_col_names_obj.col_names)
  234. for host_col_names_obj in model_param.host_col_names]
  235. """
  236. for result in model_param.results:
  237. cur_select_properties = copy.deepcopy(self.curt_select_properties)
  238. feature_values, left_cols_obj = dict(result.feature_values), result.left_cols
  239. cur_select_properties.load_properties(header, feature_values, left_cols_obj)
  240. # record local select properties
  241. self.loaded_local_select_properties[result.filter_name] = cur_select_properties
  242. host_feature_values_obj_list = list(result.host_feature_values)
  243. host_left_cols_obj_list = list(result.host_left_cols)
  244. self.loaded_host_filter_results[result.filter_name] = (host_feature_values_obj_list,
  245. host_left_cols_obj_list)
  246. final_left_cols_names = dict(model_param.final_left_cols.left_cols)
  247. # LOGGER.debug("final_left_cols_names: {}".format(final_left_cols_names))
  248. for col_name, _ in final_left_cols_names.items():
  249. self.curt_select_properties.add_left_col_name(col_name)
  250. self.completed_selection_result.add_filter_results(filter_name='conclusion',
  251. select_properties=self.curt_select_properties)
  252. self.update_curt_select_param()
  253. def _load_isometric_model(self, iso_model):
  254. LOGGER.debug(f"When loading isometric_model, iso_model names are:"
  255. f" {iso_model.keys()}")
  256. for cpn_name, model_dict in iso_model.items():
  257. model_param = None
  258. model_meta = None
  259. for name, model_pb in model_dict.items():
  260. if name.endswith("Param"):
  261. model_param = model_pb
  262. else:
  263. model_meta = model_pb
  264. model_name = model_param.model_name
  265. if model_name in self.isometric_models:
  266. raise ValueError("Should not load two same type isometric models"
  267. " in feature selection")
  268. adapter = adapter_factory(model_name)
  269. this_iso_model = adapter.convert(model_meta, model_param)
  270. self.isometric_models[model_name] = this_iso_model
  271. def load_model(self, model_dict):
  272. LOGGER.debug(f"Start to load model")
  273. if 'model' in model_dict:
  274. LOGGER.debug("Loading selection model")
  275. self._load_selection_model(model_dict)
  276. if 'isometric_model' in model_dict:
  277. LOGGER.debug("Loading isometric_model")
  278. self._load_isometric_model(model_dict['isometric_model'])
  279. @staticmethod
  280. def select_cols(instance, left_col_idx):
  281. instance.features = instance.features[left_col_idx]
  282. return instance
  283. def _transfer_data(self, data_instances):
  284. f = functools.partial(self.select_cols,
  285. left_col_idx=self.completed_selection_result.all_left_col_indexes)
  286. new_data = data_instances.mapValues(f)
  287. # LOGGER.debug("When transfering, all left_col_names: {}".format(
  288. # self.completed_selection_result.all_left_col_names
  289. # ))
  290. new_data = self.set_schema(new_data,
  291. self.completed_selection_result.all_left_col_names,
  292. self.completed_selection_result.all_left_anonymous_col_names)
  293. # one_data = new_data.first()[1]
  294. # LOGGER.debug(
  295. # "In feature selection transform, Before transform: {}, length: {} After transform: {}, length: {}".format(
  296. # before_one_data[1].features, len(before_one_data[1].features),
  297. # one_data.features, len(one_data.features)))
  298. return new_data
  299. def _abnormal_detection(self, data_instances):
  300. """
  301. Make sure input data_instances is valid.
  302. """
  303. abnormal_detection.empty_table_detection(data_instances)
  304. abnormal_detection.empty_feature_detection(data_instances)
  305. self.check_schema_content(data_instances.schema)
  306. def set_schema(self, data_instance, header=None, anonymous_header=None):
  307. if header is None:
  308. self.schema["header"] = self.curt_select_properties.header
  309. self.schema["anonymous_header"] = self.curt_select_properties.anonymous_header
  310. else:
  311. self.schema["header"] = header
  312. self.schema["anonymous_header"] = anonymous_header
  313. data_instance.schema = self.schema
  314. return data_instance
  315. def update_curt_select_param(self):
  316. new_select_properties = SelectionProperties()
  317. # all select properties must have the same header
  318. new_select_properties.set_header(self.curt_select_properties.header)
  319. new_select_properties.set_anonymous_header(self.curt_select_properties.anonymous_header)
  320. new_select_properties.set_last_left_col_indexes(self.curt_select_properties.all_left_col_indexes)
  321. new_select_properties.add_select_col_names(self.curt_select_properties.left_col_names)
  322. self.curt_select_properties = new_select_properties
  323. def _filter(self, data_instances, method, suffix, idx=0):
  324. this_filter = filter_factory.get_filter(filter_name=method, model_param=self.model_param,
  325. role=self.role, model=self, idx=idx)
  326. if method == consts.STATISTIC_FILTER:
  327. method = self.model_param.statistic_param.metrics[idx]
  328. elif method == consts.IV_FILTER:
  329. metric = self.model_param.iv_param.metrics[idx]
  330. f_type = self.model_param.iv_param.filter_type[idx]
  331. method = f"{metric}_{f_type}"
  332. elif method == consts.PSI_FILTER:
  333. metric = self.model_param.psi_param.metrics[idx]
  334. f_type = self.model_param.psi_param.filter_type[idx]
  335. method = f"{metric}_{f_type}"
  336. this_filter.set_selection_properties(self.curt_select_properties)
  337. this_filter.set_transfer_variable(self.transfer_variable)
  338. # .info(f"this_filter type: {this_filter.filter_type}, method: {method}, filter obj: {this_filter}")
  339. self.curt_select_properties = this_filter.fit(data_instances, suffix).selection_properties
  340. # LOGGER.info(f"filter.fit called")
  341. host_select_properties = getattr(this_filter, 'host_selection_properties', None)
  342. # if host_select_properties is not None:
  343. # LOGGER.debug("method: {}, host_select_properties: {}".format(
  344. # method, host_select_properties[0].all_left_col_names))
  345. self.completed_selection_result.add_filter_results(filter_name=method,
  346. select_properties=self.curt_select_properties,
  347. host_select_properties=host_select_properties)
  348. last_col_nums = len(self.curt_select_properties.last_left_col_names)
  349. left_col_names = self.curt_select_properties.left_col_names
  350. self.add_summary(method, {
  351. "last_col_nums": last_col_nums,
  352. "left_col_nums": len(left_col_names),
  353. "left_col_names": left_col_names
  354. })
  355. # LOGGER.debug("method: {}, selection_cols: {}, left_cols: {}".format(
  356. # method, self.curt_select_properties.select_col_names, self.curt_select_properties.left_col_names))
  357. self.update_curt_select_param()
  358. # LOGGER.debug("After updated, method: {}, selection_cols: {}".format(
  359. # method, self.curt_select_properties.select_col_names))
  360. self.meta_list.append(this_filter.get_meta_obj())
  361. def fit(self, data_instances):
  362. LOGGER.info("Start Hetero Selection Fit and transform.")
  363. self._abnormal_detection(data_instances)
  364. self._init_select_params(data_instances)
  365. original_col_nums = len(self.curt_select_properties.last_left_col_names)
  366. empty_cols = False
  367. if len(self.curt_select_properties.select_col_indexes) == 0:
  368. LOGGER.warning("None of columns has been set to select, "
  369. "will randomly select one column to participate in fitting filter(s). "
  370. "All columns will be kept, "
  371. "but be aware that this may lead to unexpected behavior.")
  372. header = data_instances.schema.get("header")
  373. select_idx = random.choice(range(len(header)))
  374. self.curt_select_properties.select_col_indexes = [select_idx]
  375. self.curt_select_properties.select_col_names = [header[select_idx]]
  376. empty_cols = True
  377. suffix = self.filter_methods
  378. if self.role == consts.HOST:
  379. self.transfer_variable.host_empty_cols.remote(empty_cols, role=consts.GUEST, idx=0, suffix=suffix)
  380. else:
  381. host_empty_cols_list = self.transfer_variable.host_empty_cols.get(idx=-1, suffix=suffix)
  382. host_list = self.component_properties.host_party_idlist
  383. for idx, res in enumerate(host_empty_cols_list):
  384. if res:
  385. LOGGER.warning(f"Host {host_list[idx]}'s select columns are empty;"
  386. f"host {host_list[idx]} will randomly select one "
  387. f"column to participate in fitting filter(s). "
  388. f"All columns from this host will be kept, "
  389. f"but be aware that this may lead to unexpected behavior.")
  390. for filter_idx, method in enumerate(self.filter_methods):
  391. if method in [consts.STATISTIC_FILTER, consts.IV_FILTER, consts.PSI_FILTER,
  392. consts.HETERO_SBT_FILTER, consts.HOMO_SBT_FILTER, consts.HETERO_FAST_SBT_FILTER,
  393. consts.VIF_FILTER]:
  394. if method == consts.STATISTIC_FILTER:
  395. metrics = self.model_param.statistic_param.metrics
  396. elif method == consts.IV_FILTER:
  397. metrics = self.model_param.iv_param.metrics
  398. elif method == consts.PSI_FILTER:
  399. metrics = self.model_param.psi_param.metrics
  400. elif method in [consts.HETERO_SBT_FILTER, consts.HOMO_SBT_FILTER, consts.HETERO_FAST_SBT_FILTER]:
  401. metrics = self.model_param.sbt_param.metrics
  402. elif method == consts.VIF_FILTER:
  403. metrics = self.model_param.vif_param.metrics
  404. else:
  405. raise ValueError(f"method: {method} is not supported")
  406. for idx, _ in enumerate(metrics):
  407. self._filter(data_instances, method,
  408. suffix=(str(filter_idx), str(idx)), idx=idx)
  409. else:
  410. self._filter(data_instances, method, suffix=str(filter_idx))
  411. last_col_nums = self.curt_select_properties.last_left_col_names
  412. self.add_summary("all", {
  413. "last_col_nums": original_col_nums,
  414. "left_col_nums": len(last_col_nums),
  415. "left_col_names": last_col_nums
  416. })
  417. new_data = self._transfer_data(data_instances)
  418. # LOGGER.debug(f"Final summary: {self.summary()}")
  419. LOGGER.info("Finish Hetero Selection Fit and transform.")
  420. return new_data
  421. @assert_io_num_rows_equal
  422. @assert_schema_consistent
  423. def transform(self, data_instances):
  424. self._abnormal_detection(data_instances)
  425. self._init_select_params(data_instances)
  426. # align data instance to model header & anonymous header
  427. data_instances = header_alignment(data_instances, self.header, self.anonymous_header)
  428. new_data = self._transfer_data(data_instances)
  429. return new_data