union.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from federatedml.model_base import Metric, MetricMeta
  17. from federatedml.feature.instance import Instance
  18. from federatedml.model_base import ModelBase
  19. from federatedml.param.union_param import UnionParam
  20. from federatedml.statistic import data_overview
  21. from federatedml.util import LOGGER
  22. from federatedml.util.schema_check import assert_schema_consistent
  23. class Union(ModelBase):
  24. def __init__(self):
  25. super().__init__()
  26. self.model_param = UnionParam()
  27. self.metric_name = "union"
  28. self.metric_namespace = "train"
  29. self.metric_type = "UNION"
  30. self.repeated_ids = None
  31. self.key = None
  32. def _init_model(self, params):
  33. self.model_param = params
  34. self.allow_missing = params.allow_missing
  35. self.keep_duplicate = params.keep_duplicate
  36. self.feature_count = 0
  37. self.is_data_instance = None
  38. self.is_empty_feature = False
  39. @staticmethod
  40. def _keep_first(v1, v2):
  41. return v1
  42. def _renew_id(self, k, v):
  43. result = []
  44. if k in self.repeated_ids:
  45. new_k = f"{k}_{self.key}"
  46. result.append((new_k, v))
  47. else:
  48. result.append((k, v))
  49. return result
  50. def check_id(self, local_table, combined_table):
  51. local_schema, combined_schema = local_table.schema, combined_table.schema
  52. local_sid_name = local_schema.get("sid")
  53. combined_sid_name = combined_schema.get("sid")
  54. if local_sid_name != combined_sid_name:
  55. raise ValueError(f"Id names {local_sid_name} and {combined_sid_name} do not match! "
  56. f"Please check id column names.")
  57. def check_label_name(self, local_table, combined_table):
  58. if not self.is_data_instance:
  59. return
  60. local_schema, combined_schema = local_table.schema, combined_table.schema
  61. local_label_name = local_schema.get("label_name")
  62. combined_label_name = combined_schema.get("label_name")
  63. if local_label_name is None and combined_label_name is None:
  64. return
  65. if local_label_name is None or combined_label_name is None:
  66. raise ValueError("Union try to combine a labeled data set with an unlabelled one."
  67. "Please check labels.")
  68. if local_label_name != combined_label_name:
  69. raise ValueError("Label names do not match. "
  70. "Please check label column names.")
  71. def check_header(self, local_table, combined_table):
  72. local_schema, combined_schema = local_table.schema, combined_table.schema
  73. local_header = local_schema.get("header")
  74. combined_header = combined_schema.get("header")
  75. if local_header != combined_header:
  76. raise ValueError("Table headers do not match! Please check header.")
  77. def check_feature_length(self, data_instance):
  78. if not self.is_data_instance or self.allow_missing:
  79. return
  80. if len(data_instance.features) != self.feature_count:
  81. raise ValueError(f"Feature length {len(data_instance.features)} "
  82. f"mismatch with header length {self.feature_count}")
  83. @staticmethod
  84. def check_is_data_instance(table):
  85. entry = table.first()
  86. is_data_instance = isinstance(entry[1], Instance)
  87. return is_data_instance
  88. @assert_schema_consistent
  89. def fit(self, data):
  90. # LOGGER.debug(f"fit receives data is {data}")
  91. if not isinstance(data, dict) or len(data) <= 1:
  92. raise ValueError("Union module must receive more than one table as input.")
  93. empty_count = 0
  94. combined_table = None
  95. combined_schema = None
  96. metrics = []
  97. for (key, local_table) in data.items():
  98. LOGGER.debug("table to combine name: {}".format(key))
  99. num_data = local_table.count()
  100. LOGGER.debug("table count: {}".format(num_data))
  101. metrics.append(Metric(key, num_data))
  102. self.add_summary(key, num_data)
  103. if num_data == 0:
  104. LOGGER.warning("Table {} is empty.".format(key))
  105. if combined_table is None:
  106. combined_table = local_table
  107. combined_schema = local_table.schema
  108. empty_count += 1
  109. continue
  110. local_is_data_instance = self.check_is_data_instance(local_table)
  111. if self.is_data_instance is None or combined_table is None:
  112. self.is_data_instance = local_is_data_instance
  113. LOGGER.debug(f"self.is_data_instance is {self.is_data_instance}, "
  114. f"local_is_data_instance is {local_is_data_instance}")
  115. if self.is_data_instance != local_is_data_instance:
  116. raise ValueError(f"Cannot combine DataInstance and non-DataInstance object. Union aborted.")
  117. if self.is_data_instance:
  118. self.is_empty_feature = data_overview.is_empty_feature(local_table)
  119. if self.is_empty_feature:
  120. LOGGER.warning("Table {} has empty feature.".format(key))
  121. else:
  122. self.check_schema_content(local_table.schema)
  123. if combined_table is None or combined_table.count() == 0:
  124. # first non-empty table to combine
  125. combined_table = local_table
  126. combined_schema = local_table.schema
  127. if self.keep_duplicate:
  128. combined_table = combined_table.map(lambda k, v: (f"{k}_{key}", v))
  129. combined_table.schema = combined_schema
  130. else:
  131. self.check_id(local_table, combined_table)
  132. self.check_label_name(local_table, combined_table)
  133. self.check_header(local_table, combined_table)
  134. if self.keep_duplicate:
  135. local_table = local_table.map(lambda k, v: (f"{k}_{key}", v))
  136. combined_table = combined_table.union(local_table, self._keep_first)
  137. combined_table.schema = combined_schema
  138. # only check feature length if not empty
  139. if self.is_data_instance and not self.is_empty_feature:
  140. self.feature_count = len(combined_schema.get("header"))
  141. # LOGGER.debug(f"feature count: {self.feature_count}")
  142. combined_table.mapValues(self.check_feature_length)
  143. if combined_table is None:
  144. LOGGER.warning("All tables provided are empty or have empty features.")
  145. first_table = list(data.values())[0]
  146. combined_table = first_table.join(first_table)
  147. num_data = combined_table.count()
  148. metrics.append(Metric("Total", num_data))
  149. self.add_summary("Total", num_data)
  150. LOGGER.info(f"Result total data entry count: {num_data}")
  151. self.callback_metric(metric_name=self.metric_name,
  152. metric_namespace=self.metric_namespace,
  153. metric_data=metrics)
  154. self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
  155. metric_name=self.metric_name,
  156. metric_meta=MetricMeta(name=self.metric_name, metric_type=self.metric_type))
  157. LOGGER.info("Union operation finished. Total {} empty tables encountered.".format(empty_count))
  158. return combined_table
  159. def check_consistency(self):
  160. pass
  161. def obtain_data(self, data_list):
  162. return data_list