column_expand.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. import copy
  17. from federatedml.model_base import ModelBase
  18. from federatedml.param.column_expand_param import ColumnExpandParam
  19. from federatedml.protobuf.generated import column_expand_meta_pb2, column_expand_param_pb2
  20. from federatedml.util import consts, LOGGER, data_format_preprocess
  # Separator placed between the individual values of one data entry.
  DELIMITER = ","


  class FeatureGenerator(object):
      """Build the delimited string of fill values appended to every entry.

      Parameters
      ----------
      method : object
          Expansion method identifier; stored as-is, not interpreted here.
      append_header : sequence
          Names of the columns to append. Only its length is used, and only
          when a single fill value has to be repeated for every new column.
      fill_value : scalar or sequence
          Value(s) for the new columns. A bare scalar (number or string) is
          treated as a single fill value.
      """

      def __init__(self, method, append_header, fill_value):
          self.method = method
          self.append_header = append_header
          self.fill_value = fill_value
          self.append_value = self._get_append_value()
          self.generator = self._get_generator()

      def _get_append_value(self):
          """Return the delimited fill string, or None when there is no fill value.

          A single fill value is repeated once per name in ``append_header``;
          several fill values are joined in the order given (their count is
          not checked against ``append_header`` here).
          """
          values = self.fill_value
          if values is None:
              return None
          # A string has a length but is one value, not a sequence of
          # per-column values; numbers have no length at all. Wrap both so the
          # logic below always works on a sequence.
          if isinstance(values, (str, bytes)) or not hasattr(values, "__len__"):
              values = [values]

          if len(values) == 0:
              return None
          if len(values) == 1:
              return DELIMITER.join([str(values[0])] * len(self.append_header))
          return DELIMITER.join(str(v) for v in values)

      def _get_generator(self):
          """Yield the current ``append_value`` indefinitely."""
          while True:
              yield self.append_value

      def generate(self):
          """Return the next value to append (the same string on every call)."""
          return next(self.generator)
  class ColumnExpand(ModelBase):
      """Component that appends constant-valued columns to every data entry.

      Entries are handled as delimiter-joined strings; the fill values come
      from a ``FeatureGenerator`` built from the component parameters or from
      a previously exported model.
      """

      def __init__(self):
          super().__init__()
          self.model_param = ColumnExpandParam()
          self.need_run = None
          self.append_header = None
          self.method = None
          self.fill_value = None
          self.summary_obj = None
          self.header = None
          self.new_feature_generator = None
          self.model_param_name = 'ColumnExpandParam'
          self.model_meta_name = 'ColumnExpandMeta'

      def _init_model(self, params):
          """Copy the settings from ``params`` and build the value generator."""
          self.model_param = params
          self.need_run = params.need_run
          self.append_header = params.append_header
          self.method = params.method
          self.fill_value = params.fill_value
          self.new_feature_generator = FeatureGenerator(params.method,
                                                        params.append_header,
                                                        params.fill_value)

      @staticmethod
      def _append_feature(entry, append_value):
          """Return ``entry`` with ``append_value`` joined on by ``DELIMITER``.

          An entry that is None or empty is replaced by ``append_value`` alone,
          so no leading delimiter is produced.
          """
          # empty content
          if entry is None or len(entry) == 0:
              return append_value
          return entry + DELIMITER + append_value

      def _nothing_to_append(self):
          """Return True when the manual method is used with no columns to add."""
          return self.method == consts.MANUAL and len(self.append_header) == 0

      def _append_column(self, data):
          """Append the fill values to every entry and extend the schema.

          Returns
          -------
          tuple
              ``(new_data, new_header)``: the mapped data with its updated
              schema attached, and the ``"header"`` value of that schema.
          """
          append_value = self.new_feature_generator.generate()
          new_data = data.mapValues(lambda v: ColumnExpand._append_feature(v, append_value))

          new_schema = copy.deepcopy(data.schema)
          # Capture the header before it is extended: the check below is about
          # whether the input had any header at all.
          header = new_schema.get("header", "")
          new_schema = data_format_preprocess.DataFormatPreProcess.extend_header(new_schema, self.append_header)
          if len(header) == 0:
              if new_schema.get("sid", None) is not None:
                  new_schema["sid"] = new_schema.get("sid").strip()

          # NOTE(review): indentation was lost in this copy of the file; this
          # branch is assumed to sit at method level, not under the
          # empty-header check above -- confirm against the upstream source.
          if new_schema.get("meta"):
              anonymous_header = new_schema.get("anonymous_header", [])
              # anonymous_generator is not defined in this class; presumably it
              # is provided by ModelBase.
              new_anonymous_header = self.anonymous_generator.extend_columns(anonymous_header,
                                                                             self.append_header)
              new_schema["anonymous_header"] = new_anonymous_header

          new_data.schema = new_schema
          new_header = new_schema.get("header")
          return new_data, new_header

      def _get_meta(self):
          """Build the ``ColumnExpandMeta`` message from the current settings."""
          meta = column_expand_meta_pb2.ColumnExpandMeta(
              append_header=self.append_header,
              method=self.method,
              # fill values are stored in the meta as strings
              fill_value=[str(v) for v in self.fill_value],
              need_run=self.need_run
          )
          return meta

      def _get_param(self):
          """Build the ``ColumnExpandParam`` message holding the header."""
          param = column_expand_param_pb2.ColumnExpandParam(header=self.header)
          return param

      def export_model(self):
          """Return the meta/param messages keyed by model name.

          The same dict is also stored on ``self.model_output``.
          """
          meta_obj = self._get_meta()
          param_obj = self._get_param()
          result = {
              self.model_meta_name: meta_obj,
              self.model_param_name: param_obj
          }
          self.model_output = result
          return result

      def load_model(self, model_dict):
          """Restore settings and the value generator from an exported model.

          Reads the first entry of ``model_dict['model']`` and takes the meta
          and param messages from it.
          """
          # Look the payload up once; both messages live in the same entry.
          model_content = list(model_dict.get('model').values())[0]
          meta_obj = model_content.get(self.model_meta_name)
          param_obj = model_content.get(self.model_param_name)

          self.append_header = list(meta_obj.append_header)
          self.method = meta_obj.method
          self.fill_value = list(meta_obj.fill_value)
          self.need_run = meta_obj.need_run
          self.new_feature_generator = FeatureGenerator(self.method,
                                                        self.append_header,
                                                        self.fill_value)
          self.header = param_obj.header

      def fit(self, data):
          """Append the configured columns to ``data`` and record the header.

          With the manual method and no columns to append, ``data`` is
          returned unchanged and its existing schema header is recorded.
          """
          LOGGER.info("Enter Column Expand fit")
          # return original value if no append header provided
          if self._nothing_to_append():
              LOGGER.info("Finish Column Expand fit. Original data returned.")
              self.header = data.schema["header"]
              return data
          new_data, self.header = self._append_column(data)
          LOGGER.info("Finish Column Expand fit")
          return new_data

      def transform(self, data):
          """Append the configured columns to ``data``.

          With the manual method and no columns to append, ``data`` is
          returned unchanged and ``self.header`` is left as it was.
          """
          LOGGER.info("Enter Column Expand transform")
          if self._nothing_to_append():
              LOGGER.info("Finish Column Expand transform. Original data returned.")
              return data
          new_data, self.header = self._append_column(data)
          LOGGER.info("Finish Column Expand transform")
          return new_data