123456789101112131415161718192021222324252627282930313233343536373839 |
- from federatedml.transfer_variable.transfer_class.homo_label_encoder_transfer_variable \
- import HomoLabelEncoderTransferVariable
- from federatedml.util import consts
- from federatedml.util import LOGGER
- class HomoLabelEncoderClient(object):
- def __init__(self):
- self.transvar = HomoLabelEncoderTransferVariable()
- def label_alignment(self, class_set):
- LOGGER.info('start homo label alignments')
- self.transvar.local_labels.remote(class_set, role=consts.ARBITER, suffix=('label_align',))
- new_label_mapping = self.transvar.label_mapping.get(idx=0, suffix=('label_mapping',))
- reverse_mapping = {v: k for k, v in new_label_mapping.items()}
- new_classes_index = [new_label_mapping[k] for k in new_label_mapping]
- new_classes_index = sorted(new_classes_index)
- aligned_labels = [reverse_mapping[i] for i in new_classes_index]
- return aligned_labels, new_label_mapping
- class HomoLabelEncoderArbiter(object):
- def __init__(self):
- self.transvar = HomoLabelEncoderTransferVariable()
- def label_alignment(self):
- LOGGER.info('start homo label alignments')
- labels = self.transvar.local_labels.get(idx=-1, suffix=('label_align', ))
- label_set = set()
- for local_label in labels:
- label_set.update(local_label)
- global_label = list(label_set)
- global_label = sorted(global_label)
- label_mapping = {v: k for k, v in enumerate(global_label)}
- self.transvar.label_mapping.remote(label_mapping, idx=-1, suffix=('label_mapping',))
- return label_mapping
|