homo_label_encoder.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. from federatedml.transfer_variable.transfer_class.homo_label_encoder_transfer_variable \
  2. import HomoLabelEncoderTransferVariable
  3. from federatedml.util import consts
  4. from federatedml.util import LOGGER
  5. class HomoLabelEncoderClient(object):
  6. def __init__(self):
  7. self.transvar = HomoLabelEncoderTransferVariable()
  8. def label_alignment(self, class_set):
  9. LOGGER.info('start homo label alignments')
  10. self.transvar.local_labels.remote(class_set, role=consts.ARBITER, suffix=('label_align',))
  11. new_label_mapping = self.transvar.label_mapping.get(idx=0, suffix=('label_mapping',))
  12. reverse_mapping = {v: k for k, v in new_label_mapping.items()}
  13. new_classes_index = [new_label_mapping[k] for k in new_label_mapping]
  14. new_classes_index = sorted(new_classes_index)
  15. aligned_labels = [reverse_mapping[i] for i in new_classes_index]
  16. return aligned_labels, new_label_mapping
  17. class HomoLabelEncoderArbiter(object):
  18. def __init__(self):
  19. self.transvar = HomoLabelEncoderTransferVariable()
  20. def label_alignment(self):
  21. LOGGER.info('start homo label alignments')
  22. labels = self.transvar.local_labels.get(idx=-1, suffix=('label_align', ))
  23. label_set = set()
  24. for local_label in labels:
  25. label_set.update(local_label)
  26. global_label = list(label_set)
  27. global_label = sorted(global_label)
  28. label_mapping = {v: k for k, v in enumerate(global_label)}
  29. self.transvar.label_mapping.remote(label_mapping, idx=-1, suffix=('label_mapping',))
  30. return label_mapping