model_migrate.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. from typing import List
  2. from federatedml.protobuf.model_migrate.converter_factory import converter_factory
  3. from federatedml.model_base import serialize_models
  4. import copy
  5. def generate_id_mapping(old_id, new_id):
  6. if old_id is None and new_id is None:
  7. return {}
  8. elif not (isinstance(old_id, list) and isinstance(new_id, list)):
  9. raise ValueError('illegal input format: id lists type should be list, however got: \n'
  10. 'content: {}/ type: {} \n'
  11. 'content: {}/ type: {}'.format(old_id, type(old_id), new_id, type(new_id)))
  12. if len(old_id) != len(new_id):
  13. raise ValueError('id lists length does not match: len({}) != len({})'.format(old_id, new_id))
  14. mapping = {}
  15. for id0, id1 in zip(old_id, new_id):
  16. if not isinstance(id0, int) or not isinstance(id1, int):
  17. raise ValueError('party id must be an integer, got {}:{} and {}:{}'.format(id0, type(id0),
  18. id1, type(id1)))
  19. mapping[id0] = id1
  20. return mapping
  21. def model_migration(model_contents: dict,
  22. module_name,
  23. old_guest_list: List[int],
  24. new_guest_list: List[int],
  25. old_host_list: List[int],
  26. new_host_list: List[int],
  27. old_arbiter_list=None,
  28. new_arbiter_list=None,
  29. ):
  30. converter = converter_factory(module_name)
  31. if converter is None:
  32. # no supported converter, return
  33. return serialize_models(model_contents)
  34. # replace old id with new id using converter
  35. guest_mapping_dict = generate_id_mapping(old_guest_list, new_guest_list)
  36. host_mapping_dict = generate_id_mapping(old_host_list, new_host_list)
  37. arbiter_mapping_dict = generate_id_mapping(old_arbiter_list, new_arbiter_list)
  38. model_contents_cpy = copy.deepcopy(model_contents)
  39. keys = model_contents.keys()
  40. param, meta = None, None
  41. param_key, meta_key = None, None
  42. for key in keys:
  43. if 'Param' in key:
  44. param_key = key
  45. param = model_contents_cpy[key]
  46. if 'Meta' in key:
  47. meta_key = key
  48. meta = model_contents_cpy[key]
  49. if param is None or meta is None:
  50. raise ValueError('param or meta is None')
  51. converted_param, converted_meta = converter.convert(param, meta, guest_mapping_dict,
  52. host_mapping_dict, arbiter_mapping_dict)
  53. return serialize_models({param_key: converted_param, meta_key: converted_meta})