converter_base.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. from abc import ABC, abstractmethod
  18. from typing import Dict, Tuple
  19. from federatedml.util.anonymous_generator_util import Anonymous
  20. from federatedml.util import consts
  21. class AutoReplace(object):
  22. def __init__(self, guest_mapping, host_mapping, arbiter_mapping):
  23. self._mapping = {
  24. consts.GUEST: guest_mapping,
  25. consts.HOST: host_mapping,
  26. consts.ARBITER: arbiter_mapping
  27. }
  28. self._anonymous_generator = Anonymous(migrate_mapping=self._mapping)
  29. def get_mapping(self, role: str):
  30. if role not in self._mapping:
  31. raise ValueError('this role contains no site name {}'.format(role))
  32. return self._mapping[role]
  33. def party_tuple_format(self, string: str):
  34. """({role},{party_id})"""
  35. role, party_id = string.strip("()").split(",")
  36. return f"({role}, {self._mapping[role][int(party_id)]})"
  37. def colon_format(self, string: str):
  38. """{role}:{party_id}"""
  39. role, party_id = string.split(':')
  40. mapping = self.get_mapping(role)
  41. new_party_id = mapping[int(party_id)]
  42. return role + ':' + str(new_party_id)
  43. def maybe_anonymous_format(self, string: str):
  44. if self._anonymous_generator.is_anonymous(string):
  45. return self.migrate_anonymous_header([string])[0]
  46. else:
  47. return string
  48. def plain_replace(self, old_party_id, role):
  49. old_party_id = int(old_party_id)
  50. mapping = self._mapping[role]
  51. if old_party_id in mapping:
  52. return str(mapping[int(old_party_id)])
  53. return str(old_party_id)
  54. def migrate_anonymous_header(self, anonymous_header):
  55. if isinstance(anonymous_header, list):
  56. return self._anonymous_generator.migrate_anonymous(anonymous_header)
  57. else:
  58. return self._anonymous_generator.migrate_anonymous([anonymous_header])[0]
  59. def replace(self, string):
  60. if ':' in string:
  61. return self.colon_format(string)
  62. else:
  63. # nothing to replace
  64. return string
  65. class ProtoConverterBase(ABC):
  66. @abstractmethod
  67. def convert(self, param, meta,
  68. guest_id_mapping: Dict,
  69. host_id_mapping: Dict,
  70. arbiter_id_mapping: Dict
  71. ) -> Tuple:
  72. raise NotImplementedError('this interface is not implemented')