12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- from typing import List
- from federatedml.protobuf.model_migrate.converter_factory import converter_factory
- from federatedml.model_base import serialize_models
- import copy
def generate_id_mapping(old_id, new_id):
    """Build a dict that maps each old party id to its new party id.

    The two lists are paired positionally: ``old_id[i]`` maps to ``new_id[i]``.

    Args:
        old_id: list of int party ids to be replaced, or None.
        new_id: list of int replacement party ids aligned with ``old_id``,
            or None.

    Returns:
        dict: ``{old: new}`` for every aligned pair, or ``{}`` when both
        arguments are None.

    Raises:
        ValueError: if exactly one argument is None or either is not a list,
            if the lists differ in length, if any id is not an int, or if the
            same old id is paired with two different new ids.
    """
    if old_id is None and new_id is None:
        return {}
    if not (isinstance(old_id, list) and isinstance(new_id, list)):
        raise ValueError('illegal input format: id lists type should be list, however got: \n'
                         'content: {}/ type: {} \n'
                         'content: {}/ type: {}'.format(old_id, type(old_id), new_id, type(new_id)))
    if len(old_id) != len(new_id):
        raise ValueError('id lists length does not match: len({}) != len({})'.format(old_id, new_id))

    mapping = {}
    for id0, id1 in zip(old_id, new_id):
        if not isinstance(id0, int) or not isinstance(id1, int):
            raise ValueError('party id must be an integer, got {}:{} and {}:{}'.format(id0, type(id0),
                                                                                      id1, type(id1)))
        # A repeated old id is only acceptable if it always maps to the same
        # new id; otherwise the earlier pair would be silently overwritten.
        if id0 in mapping and mapping[id0] != id1:
            raise ValueError('party id {} is mapped to conflicting new ids {} and {}'.format(
                id0, mapping[id0], id1))
        mapping[id0] = id1
    return mapping
def model_migration(model_contents: dict,
                    module_name,
                    old_guest_list: List[int],
                    new_guest_list: List[int],
                    old_host_list: List[int],
                    new_host_list: List[int],
                    old_arbiter_list=None,
                    new_arbiter_list=None,
                    ):
    """Rewrite party ids in a module's Param/Meta models and serialize them.

    Args:
        model_contents: dict of model name -> model object. The Param entry is
            the key containing the substring 'Param'; the Meta entry is the key
            containing 'Meta'. If several keys match, the last one iterated wins.
        module_name: module identifier passed to ``converter_factory``.
        old_guest_list / new_guest_list: aligned guest party id lists.
        old_host_list / new_host_list: aligned host party id lists.
        old_arbiter_list / new_arbiter_list: aligned arbiter party id lists;
            both default to None, which yields an empty mapping.

    Returns:
        Whatever ``serialize_models`` returns. When no converter exists for
        ``module_name``, ``model_contents`` is serialized unchanged. Otherwise
        only the converted Param and Meta entries are serialized.
        NOTE(review): any other entries are dropped on the converted path —
        confirm that is intended.

    Raises:
        ValueError: from ``generate_id_mapping`` on malformed id lists, or if
            the Param or Meta model is missing (or its value is None).
    """
    converter = converter_factory(module_name)
    if converter is None:
        # No converter registered for this module: nothing to rewrite.
        return serialize_models(model_contents)

    # Positional order (guest, host, arbiter) is what converter.convert expects.
    id_mappings = (
        generate_id_mapping(old_guest_list, new_guest_list),
        generate_id_mapping(old_host_list, new_host_list),
        generate_id_mapping(old_arbiter_list, new_arbiter_list),
    )

    # The converter works on a deep copy, so the caller's dict is never touched.
    contents_copy = copy.deepcopy(model_contents)

    located = {}
    for name in model_contents:
        for tag in ('Param', 'Meta'):
            if tag in name:
                located[tag] = name
    param_key = located.get('Param')
    meta_key = located.get('Meta')
    param = None if param_key is None else contents_copy[param_key]
    meta = None if meta_key is None else contents_copy[meta_key]

    if param is None or meta is None:
        raise ValueError('param or meta is None')

    converted_param, converted_meta = converter.convert(param, meta, *id_mappings)
    return serialize_models({param_key: converted_param, meta_key: converted_meta})
|