tree_adapter.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import numpy as np
  2. from federatedml.feature.feature_selection.model_adapter import isometric_model
  3. from federatedml.feature.feature_selection.model_adapter.adapter_base import BaseAdapter
  4. from federatedml.util import consts
  def feature_importance_converter(model_meta, model_param):
      """Build an IsometricModel holding per-feature importance from a tree model param.

      Every feature named in ``model_param.feature_name_fid_mapping`` appears in
      the result. Features that have an entry in ``model_param.feature_importances``
      come first, in entry order, carrying that entry's importance. The remaining
      mapped features follow, in mapping order, with importance 0.

      :param model_meta: not used here; accepted so this converter shares the
          ``(model_meta, model_param)`` signature of the other converter.
      :param model_param: object exposing ``feature_name_fid_mapping``
          (fid -> feature name) and ``feature_importances`` (entries with
          ``fid`` and ``importance`` attributes).
      :return: IsometricModel with a single ``consts.FEATURE_IMPORTANCE`` metric.
      :raises KeyError: if an importance entry refers to a fid absent from the mapping.
      """
      fid_mapping = dict(model_param.feature_name_fid_mapping)

      col_names, importance_vals = [], []
      for feat_importance in model_param.feature_importances:
          col_names.append(fid_mapping[feat_importance.fid])
          importance_vals.append(feat_importance.importance)

      # Pad features that received no importance entry with 0. `seen` mirrors
      # `col_names` so each membership test is O(1) instead of a list scan.
      seen = set(col_names)
      for feature_name in fid_mapping.values():
          if feature_name not in seen:
              seen.add(feature_name)
              col_names.append(feature_name)
              importance_vals.append(0)

      single_info = isometric_model.SingleMetricInfo(
          values=np.array(importance_vals),
          col_names=col_names
      )
      result = isometric_model.IsometricModel()
      result.add_metric_value(metric_name=consts.FEATURE_IMPORTANCE, metric_info=single_info)
      return result
  def feature_importance_with_anonymous_converter(model_meta, model_param):
      """Build an IsometricModel of local feature importance, skipping host-prefixed entries.

      An entry of ``model_param.feature_importances`` is kept when its
      ``sitename`` equals ``consts.HOST_LOCAL``. It is also kept when the part
      of ``sitename`` before the first ':' is not ``consts.HOST``. All other
      entries are skipped. Kept entries come first, in entry order. Every
      feature in ``feature_name_fid_mapping`` that received no kept entry is
      then appended, in mapping order, with importance 0.

      NOTE(review): despite the function name, no anonymous host columns are
      produced; host-prefixed entries are simply discarded.

      :param model_meta: not used here; accepted so this converter shares the
          ``(model_meta, model_param)`` signature of the other converter.
      :param model_param: object exposing ``feature_name_fid_mapping``
          (fid -> feature name) and ``feature_importances`` (entries with
          ``fid``, ``importance`` and ``sitename`` attributes).
      :return: IsometricModel with a single ``consts.FEATURE_IMPORTANCE`` metric.
      :raises KeyError: if a kept entry refers to a fid absent from the mapping.
      """
      fid_mapping = dict(model_param.feature_name_fid_mapping)

      local_cols, local_vals = [], []
      for feat_importance in model_param.feature_importances:
          site_name = feat_importance.sitename
          # Skip entries whose sitename role prefix is HOST. Presumably these
          # describe host-party features whose fids are not in the local
          # mapping; this is an assumption to confirm against the producer.
          if site_name != consts.HOST_LOCAL and site_name.split(':')[0] == consts.HOST:
              continue
          local_cols.append(fid_mapping[feat_importance.fid])
          local_vals.append(feat_importance.importance)

      # Pad local features that received no importance entry with 0. `seen`
      # mirrors `local_cols` so each membership test is O(1).
      seen = set(local_cols)
      for feature_name in fid_mapping.values():
          if feature_name not in seen:
              seen.add(feature_name)
              local_cols.append(feature_name)
              local_vals.append(0)

      single_info = isometric_model.SingleMetricInfo(
          values=np.array(local_vals),
          col_names=local_cols
      )
      result = isometric_model.IsometricModel()
      result.add_metric_value(metric_name=consts.FEATURE_IMPORTANCE, metric_info=single_info)
      return result
  class HomoSBTAdapter(BaseAdapter):
      """Adapter for the HomoSBT model param; all importance entries are treated as local."""

      def convert(self, model_meta, model_param):
          """Return the IsometricModel built by ``feature_importance_converter``."""
          converted = feature_importance_converter(model_meta=model_meta, model_param=model_param)
          return converted
  class HeteroSBTAdapter(BaseAdapter):
      """Adapter for the HeteroSBT model param; host-prefixed importance entries are dropped."""

      def convert(self, model_meta, model_param):
          """Return the IsometricModel built by ``feature_importance_with_anonymous_converter``."""
          converted = feature_importance_with_anonymous_converter(
              model_meta=model_meta, model_param=model_param
          )
          return converted
  class HeteroFastSBTAdapter(BaseAdapter):
      """Adapter for HeteroFastSBT model params (layered or mix variants)."""

      def convert(self, model_meta, model_param):
          """Convert a supported HeteroFastSBT param into an IsometricModel.

          :raises ValueError: if ``model_param.model_name`` is neither
              ``consts.HETERO_FAST_SBT_LAYERED`` nor ``consts.HETERO_FAST_SBT_MIX``.
          """
          model_name = model_param.model_name
          # Both supported variants are handled by the same converter.
          supported_names = (consts.HETERO_FAST_SBT_LAYERED, consts.HETERO_FAST_SBT_MIX)
          if model_name not in supported_names:
              raise ValueError('model name {} is illegal'.format(model_name))
          return feature_importance_with_anonymous_converter(model_meta, model_param)