123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from collections import defaultdict
- import functools
- from federatedml.feature.instance import Instance
- from federatedml.transfer_variable.transfer_class.match_id_intersect_transfer_variable import \
- MatchIDIntersectTransferVariable
- from federatedml.util import consts
- from federatedml.util import LOGGER
class MatchIDIntersect(object):
    """
    This will support repeated ID intersection using ID expanding.

    The party whose role equals ``sample_id_generator`` (the "owner") builds
    an ``id_map`` table of ``match_id -> [sample_id, ...]`` in ``recover``.
    Data is then re-keyed by match ID for the intersection step, and
    ``expand`` turns the intersected match IDs back into one record per
    sample ID on both sides.

    Tables handled here are expected to provide the table methods used below
    (``first``, ``map``, ``join``, ``flatMap``, ``filter``, ``union``,
    ``subtractByKey``, ``mapReducePartitions``) and a ``schema`` attribute.
    """

    def __init__(self, sample_id_generator: str, role: str):
        """
        Parameters
        ----------
        sample_id_generator : str
            Role of the party that owns the repeated IDs and generates the
            id map (compared against ``consts.HOST`` in ``expand``).
        role : str
            Role of the local party.
        """
        self.sample_id_generator = sample_id_generator
        self.transfer_variable = MatchIDIntersectTransferVariable()
        self.role = role
        # match_id -> [sample_id, ...]; built in recover() on the owner,
        # received from the owner in expand() on the partner.
        self.id_map = None
        self.version = None
        # Original (un-rekeyed) table kept by the owner for expand().
        self.owner_src_data = None
        # Cached value type of the tables: ``Instance`` or ``list``.
        self.data_type = None
        # When True, Instance.inst_id is used as the match ID instead of
        # the first feature column.
        self.with_sample_id = False

    def __get_data_type(self, data):
        """
        Return (and cache) the value type stored in ``data``.

        The type is decided once from the first record of the first table
        seen: ``Instance`` if that value is an ``Instance``, else ``list``.
        """
        if self.data_type is None:
            one_feature = data.first()
            if isinstance(one_feature[1], Instance):
                self.data_type = Instance
            else:
                self.data_type = list
        return self.data_type

    @staticmethod
    def __to_id_map(data):
        """
        Group the keys of ``(key, value)`` pairs by the value's first field.

        The grouping field is ``value.features[0]`` for an ``Instance`` and
        ``value[0]`` otherwise. Returns a list of ``(field, [key, ...])``.
        """
        id_map = defaultdict(list)
        for d in data:
            idx = d[1].features[0] if isinstance(d[1], Instance) else d[1][0]
            id_map[idx].append(d[0])
        return list(id_map.items())

    @staticmethod
    def __reduce_id_map(x1, x2):
        """Merge two key lists collected for the same match ID."""
        return x1 + x2

    @staticmethod
    def __to_sample_id_map(data):
        """
        Group the keys of ``(key, value)`` pairs by ``value.inst_id``.

        Returns a list of ``(inst_id, [key, ...])``.
        """
        id_map = defaultdict(list)
        for d in data:
            id_map[d[1].inst_id].append(d[0])
        return list(id_map.items())

    def __generate_id_map(self, data):
        """
        Build the ``match_id -> [sample_id, ...]`` table from ``data``.

        Returns ``None`` (after logging a warning) when the local role is not
        the sample id generator. Without sample-id mode only match IDs that
        occur at least twice are kept; in sample-id mode every ID is kept.
        """
        if self.role != self.sample_id_generator:
            LOGGER.warning("Not a repeated id owner, will not generate id map")
            return None

        if not self.with_sample_id:
            all_id_map = data.mapReducePartitions(self.__to_id_map, self.__reduce_id_map)
            id_map = all_id_map.filter(lambda k, v: len(v) >= 2)
        else:
            id_map = data.mapReducePartitions(self.__to_sample_id_map, self.__reduce_id_map)
        return id_map

    @staticmethod
    def __func_restructure_id(k, id_map: list):
        """Invert one id-map entry into ``[(sample_id, match_id), ...]``."""
        return [(new_id, k) for new_id in id_map]

    @staticmethod
    def __func_restructure_id_for_partner(k, v):
        """
        Fan one joined record out to every sample ID mapped to its key.

        ``v`` is a ``(data, [sample_id, ...])`` pair; the result is
        ``[(sample_id, data), ...]``.
        """
        data, id_map = v[0], v[1]
        return [(new_id, data) for new_id in id_map]

    @staticmethod
    def __func_restructure_sample_id_for_partner(k, v):
        """Same fan-out as ``__func_restructure_id_for_partner``."""
        data, id_map = v[0], v[1]
        return [(new_id, data) for new_id in id_map]

    @staticmethod
    def __func_restructure_instance(v):
        """Drop the first feature (the match ID column) from ``v`` in place."""
        v.features = v.features[1:]
        return v

    def __restructure_owner_sample_ids(self, data, id_map):
        """
        Owner side: select the original records that survive intersection.

        ``id_map`` is inverted to a table keyed by sample ID and joined with
        the table saved in ``recover``. Without sample-id mode the leading
        match-ID column is also removed from each value and from the schema
        header.
        """
        rids = id_map.flatMap(functools.partial(self.__func_restructure_id))
        if not self.with_sample_id:
            _data = data.union(rids, lambda dv, rv: dv)
            if self.__get_data_type(self.owner_src_data) == Instance:
                # Bound to a local so the join lambda does not close over self.
                restructure_instance = self.__func_restructure_instance
                r_data = self.owner_src_data.join(
                    _data, lambda ov, dv: restructure_instance(ov))
            else:
                r_data = self.owner_src_data.join(_data, lambda ov, dv: ov[1:])
            r_data.schema = self.owner_src_data.schema
            # NOTE(review): if the schema setter stores the same dict object,
            # the header edit below also changes owner_src_data.schema; confirm.
            if r_data.schema.get('header') is not None:
                r_data.schema['header'] = r_data.schema['header'][1:]
        else:
            r_data = self.owner_src_data.join(rids, lambda ov, dv: ov)
            r_data.schema = self.owner_src_data.schema
        return r_data

    def __restructure_partner_sample_ids(self, data, id_map, match_data=None):
        """
        Partner side: expand intersected records to the owner's sample IDs.

        Parameters
        ----------
        data : table
            Intersection result.
        id_map : table
            ``match_id -> [sample_id, ...]`` received from the owner.
        match_data : table, optional
            When given, ``data`` values are replaced by the values of
            ``match_data`` for the same keys and the result takes
            ``match_data``'s schema. When ``None`` the values of ``data``
            are used as they are.

        Returns
        -------
        table
            In sample-id mode, only the fanned-out records; otherwise the
            records not present in ``id_map`` plus the fanned-out records.
        """
        # match_data defaults to None, so the join must be conditional;
        # joining against None would fail.
        if match_data is not None:
            data = data.join(match_data, lambda dv, mv: mv)

        _data = data.join(id_map, lambda dv, iv: (dv, iv))
        repeated_ids = _data.flatMap(functools.partial(self.__func_restructure_id_for_partner))

        if not self.with_sample_id:
            sub_data = data.subtractByKey(id_map)
            expand_data = sub_data.union(repeated_ids, lambda sv, rv: sv)
        else:
            expand_data = repeated_ids

        expand_data.schema = data.schema
        if match_data is not None:
            expand_data.schema = match_data.schema
        return expand_data

    def __restructure_sample_ids(self, data, id_map, match_data=None):
        """Dispatch to the owner or partner restructuring by local role."""
        if self.role == self.sample_id_generator:
            return self.__restructure_owner_sample_ids(data, id_map)
        return self.__restructure_partner_sample_ids(data, id_map, match_data)

    def generate_intersect_data(self, data):
        """
        Re-key ``data`` by match ID for the intersection step.

        * ``Instance`` values, no sample-id mode: ``(features[0], 1)``
        * ``Instance`` values, sample-id mode: ``(inst_id, instance)``
        * list values: ``(value[0], 1)``

        The input schema is copied onto the returned table.
        """
        if self.__get_data_type(data) == Instance:
            if not self.with_sample_id:
                _data = data.map(lambda k, v: (v.features[0], 1))
            else:
                _data = data.map(lambda k, v: (v.inst_id, v))
        else:
            # The callback takes (key, value) and returns a new (key, value)
            # pair, which is the contract of map(); mapValues() passes only
            # the value and would fail on this two-argument lambda.
            _data = data.map(lambda k, v: (v[0], 1))
        _data.schema = data.schema
        LOGGER.info("Finish recover real ids")
        return _data

    def use_sample_id(self):
        """Switch to sample-id mode (match on ``Instance.inst_id``)."""
        self.with_sample_id = True

    def recover(self, data):
        """
        Prepare ``data`` for intersection.

        On the owner, the id map is generated and ``data`` is remembered for
        ``expand``. On the partner, ``data`` is returned unchanged unless
        sample-id mode is on. In every other case the table re-keyed by
        ``generate_intersect_data`` is returned.
        """
        LOGGER.info("Start repeated id processing.")
        if self.role == self.sample_id_generator:
            LOGGER.info("Start to generate id_map")
            self.id_map = self.__generate_id_map(data)
            self.owner_src_data = data
        else:
            if not self.with_sample_id:
                LOGGER.info("Not sample_id_generator, return!")
                return data
        return self.generate_intersect_data(data)

    def expand(self, data, owner_only=False, match_data=None):
        """
        Expand the intersection result back to per-sample records.

        Parameters
        ----------
        data : table
            Intersection result.
        owner_only : bool
            If True the owner does not send its id map, and the partner
            returns ``data`` unchanged.
        match_data : table, optional
            Forwarded to the partner-side restructuring.

        Returns
        -------
        table
            The restructured table, or ``data`` itself on the partner when
            ``owner_only`` is True.
        """
        if self.sample_id_generator == consts.HOST:
            id_map_federation = self.transfer_variable.id_map_from_host
            partner_role = consts.GUEST
        else:
            id_map_federation = self.transfer_variable.id_map_from_guest
            partner_role = consts.HOST

        if self.sample_id_generator == self.role:
            # Keep only id-map entries whose match ID is in the intersection.
            self.id_map = self.id_map.join(data, lambda i, d: i)
            LOGGER.info("Find repeated id_map from intersection ids")
            if not owner_only:
                id_map_federation.remote(self.id_map,
                                         role=partner_role,
                                         idx=-1)
                LOGGER.info("Remote id_map to partner")
        else:
            if owner_only:
                return data
            self.id_map = id_map_federation.get(idx=0)
            LOGGER.info("Get id_map from owner.")

        return self.__restructure_sample_ids(data, self.id_map, match_data)
|