# match_id_process.py
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. from collections import defaultdict
  17. import functools
  18. from federatedml.feature.instance import Instance
  19. from federatedml.transfer_variable.transfer_class.match_id_intersect_transfer_variable import \
  20. MatchIDIntersectTransferVariable
  21. from federatedml.util import consts
  22. from federatedml.util import LOGGER
  23. class MatchIDIntersect(object):
  24. """
  25. This will support repeated ID intersection using ID expanding.
  26. """
  27. def __init__(self, sample_id_generator: str, role: str):
  28. self.sample_id_generator = sample_id_generator
  29. self.transfer_variable = MatchIDIntersectTransferVariable()
  30. self.role = role
  31. self.id_map = None
  32. self.version = None
  33. self.owner_src_data = None
  34. self.data_type = None
  35. self.with_sample_id = False
  36. def __get_data_type(self, data):
  37. if self.data_type is None:
  38. one_feature = data.first()
  39. if isinstance(one_feature[1], Instance):
  40. self.data_type = Instance
  41. else:
  42. self.data_type = list
  43. return self.data_type
  44. @staticmethod
  45. def __to_id_map(data):
  46. id_map = defaultdict(list)
  47. for d in data:
  48. idx = d[1].features[0] if isinstance(d[1], Instance) else d[1][0]
  49. id_map[idx].append(d[0])
  50. return [(k, v) for k, v in id_map.items()]
  51. @staticmethod
  52. def __reduce_id_map(x1, x2):
  53. return x1 + x2
  54. @staticmethod
  55. def __to_sample_id_map(data):
  56. id_map = defaultdict(list)
  57. for d in data:
  58. id_map[d[1].inst_id].append(d[0])
  59. return [(k, v) for k, v in id_map.items()]
  60. def __generate_id_map(self, data):
  61. if self.role != self.sample_id_generator:
  62. LOGGER.warning("Not a repeated id owner, will not generate id map")
  63. return
  64. if not self.with_sample_id:
  65. all_id_map = data.mapReducePartitions(self.__to_id_map, self.__reduce_id_map)
  66. id_map = all_id_map.filter(lambda k, v: len(v) >= 2)
  67. else:
  68. id_map = data.mapReducePartitions(self.__to_sample_id_map, self.__reduce_id_map)
  69. return id_map
  70. @staticmethod
  71. def __func_restructure_id(k, id_map: list):
  72. return [(new_id, k) for new_id in id_map]
  73. @staticmethod
  74. def __func_restructure_id_for_partner(k, v):
  75. data, id_map = v[0], v[1]
  76. return [(new_id, data) for new_id in id_map]
  77. @staticmethod
  78. def __func_restructure_sample_id_for_partner(k, v):
  79. data, id_map = v[0], v[1]
  80. return [(new_id, data) for new_id in id_map]
  81. @staticmethod
  82. def __func_restructure_instance(v):
  83. v.features = v.features[1:]
  84. return v
  85. def __restructure_owner_sample_ids(self, data, id_map):
  86. rids = id_map.flatMap(functools.partial(self.__func_restructure_id))
  87. if not self.with_sample_id:
  88. _data = data.union(rids, lambda dv, rv: dv)
  89. if self.__get_data_type(self.owner_src_data) == Instance:
  90. r_data = self.owner_src_data.join(_data, lambda ov, dv: self.__func_restructure_instance(ov))
  91. else:
  92. r_data = self.owner_src_data.join(_data, lambda ov, dv: ov[1:])
  93. r_data.schema = self.owner_src_data.schema
  94. if r_data.schema.get('header') is not None:
  95. r_data.schema['header'] = r_data.schema['header'][1:]
  96. else:
  97. r_data = self.owner_src_data.join(rids, lambda ov, dv: ov)
  98. r_data.schema = self.owner_src_data.schema
  99. return r_data
  100. def __restructure_partner_sample_ids(self, data, id_map, match_data=None):
  101. data = data.join(match_data, lambda k, v: v)
  102. _data = data.join(id_map, lambda dv, iv: (dv, iv))
  103. # LOGGER.debug(f"_data is: {_data.first()}")
  104. repeated_ids = _data.flatMap(functools.partial(self.__func_restructure_id_for_partner))
  105. # LOGGER.debug(f"restructure id for partner called, result is: {repeated_ids.first()}")
  106. if not self.with_sample_id:
  107. sub_data = data.subtractByKey(id_map)
  108. expand_data = sub_data.union(repeated_ids, lambda sv, rv: sv)
  109. else:
  110. expand_data = repeated_ids
  111. expand_data.schema = data.schema
  112. if match_data:
  113. expand_data.schema = match_data.schema
  114. return expand_data
  115. def __restructure_sample_ids(self, data, id_map, match_data=None):
  116. # LOGGER.debug(f"id map is: {self.id_map.first()}")
  117. if self.role == self.sample_id_generator:
  118. return self.__restructure_owner_sample_ids(data, id_map)
  119. else:
  120. return self.__restructure_partner_sample_ids(data, id_map, match_data)
  121. def generate_intersect_data(self, data):
  122. if self.__get_data_type(data) == Instance:
  123. if not self.with_sample_id:
  124. _data = data.map(
  125. lambda k, v: (v.features[0], 1))
  126. else:
  127. _data = data.map(lambda k, v: (v.inst_id, v))
  128. else:
  129. _data = data.mapValues(lambda k, v: (v[0], 1))
  130. _data.schema = data.schema
  131. LOGGER.info("Finish recover real ids")
  132. return _data
  133. def use_sample_id(self):
  134. self.with_sample_id = True
  135. def recover(self, data):
  136. LOGGER.info("Start repeated id processing.")
  137. if self.role == self.sample_id_generator:
  138. LOGGER.info("Start to generate id_map")
  139. self.id_map = self.__generate_id_map(data)
  140. self.owner_src_data = data
  141. else:
  142. if not self.with_sample_id:
  143. LOGGER.info("Not sample_id_generator, return!")
  144. return data
  145. return self.generate_intersect_data(data)
  146. def expand(self, data, owner_only=False, match_data=None):
  147. if self.sample_id_generator == consts.HOST:
  148. id_map_federation = self.transfer_variable.id_map_from_host
  149. partner_role = consts.GUEST
  150. else:
  151. id_map_federation = self.transfer_variable.id_map_from_guest
  152. partner_role = consts.HOST
  153. if self.sample_id_generator == self.role:
  154. self.id_map = self.id_map.join(data, lambda i, d: i)
  155. LOGGER.info("Find repeated id_map from intersection ids")
  156. if not owner_only:
  157. id_map_federation.remote(self.id_map,
  158. role=partner_role,
  159. idx=-1)
  160. LOGGER.info("Remote id_map to partner")
  161. else:
  162. if owner_only:
  163. return data
  164. self.id_map = id_map_federation.get(idx=0)
  165. LOGGER.info("Get id_map from owner.")
  166. return self.__restructure_sample_ids(data, self.id_map, match_data)