# publish_model.py
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import os
  17. import grpc
  18. from fate_arch.common.base_utils import json_loads
  19. from fate_arch.protobuf.python import model_service_pb2, model_service_pb2_grpc
  20. from fate_flow.model.sync_model import SyncModel
  21. from fate_flow.pipelined_model import pipelined_model
  22. from fate_flow.pipelined_model.homo_model_deployer.model_deploy import model_deploy
  23. from fate_flow.settings import (
  24. ENABLE_MODEL_STORE, FATE_FLOW_MODEL_TRANSFER_ENDPOINT,
  25. GRPC_OPTIONS, HOST, HTTP_PORT, USE_REGISTRY, stat_logger,
  26. )
  27. from fate_flow.utils import model_utils
  28. def generate_publish_model_info(config_data):
  29. model_id = config_data['job_parameters']['model_id']
  30. model_version = config_data['job_parameters']['model_version']
  31. config_data['model'] = {}
  32. for role, role_party in config_data.get("role").items():
  33. config_data['model'][role] = {}
  34. for party_id in role_party:
  35. config_data['model'][role][party_id] = {
  36. 'model_id': model_utils.gen_party_model_id(model_id, role, party_id),
  37. 'model_version': model_version
  38. }
  39. def load_model(config_data):
  40. stat_logger.info(config_data)
  41. if not config_data.get('servings'):
  42. return 100, 'Please configure servings address'
  43. for serving in config_data['servings']:
  44. with grpc.insecure_channel(serving, GRPC_OPTIONS) as channel:
  45. stub = model_service_pb2_grpc.ModelServiceStub(channel)
  46. load_model_request = model_service_pb2.PublishRequest()
  47. for role_name, role_partys in config_data.get("role", {}).items():
  48. for _party_id in role_partys:
  49. load_model_request.role[role_name].partyId.append(str(_party_id))
  50. for role_name, role_model_config in config_data.get("model", {}).items():
  51. for _party_id, role_party_model_config in role_model_config.items():
  52. load_model_request.model[role_name].roleModelInfo[str(_party_id)].tableName = \
  53. role_party_model_config['model_version']
  54. load_model_request.model[role_name].roleModelInfo[str(_party_id)].namespace = \
  55. role_party_model_config['model_id']
  56. stat_logger.info('request serving: {} load model'.format(serving))
  57. load_model_request.local.role = config_data.get('local', {}).get('role', '')
  58. load_model_request.local.partyId = str(config_data.get('local', {}).get('party_id', ''))
  59. load_model_request.loadType = config_data['job_parameters'].get("load_type", "FATEFLOW")
  60. # make use of 'model.transfer.url' in serving server
  61. use_serving_url = config_data['job_parameters'].get('use_transfer_url_on_serving', False)
  62. if not USE_REGISTRY and not use_serving_url:
  63. load_model_request.filePath = f"http://{HOST}:{HTTP_PORT}{FATE_FLOW_MODEL_TRANSFER_ENDPOINT}"
  64. else:
  65. load_model_request.filePath = config_data['job_parameters'].get("file_path", "")
  66. stat_logger.info(load_model_request)
  67. response = stub.publishLoad(load_model_request)
  68. stat_logger.info(
  69. '{} {} load model status: {}'.format(load_model_request.local.role, load_model_request.local.partyId,
  70. response.statusCode))
  71. if response.statusCode != 0:
  72. return response.statusCode, '{} {}'.format(response.message, response.error)
  73. return 0, 'success'
  74. def bind_model_service(config_data):
  75. if not config_data.get('servings'):
  76. return 100, 'Please configure servings address'
  77. service_id = str(config_data.get('service_id', ''))
  78. initiator_role = config_data['initiator']['role']
  79. initiator_party_id = str(config_data['initiator']['party_id'])
  80. model_id = config_data['job_parameters']['model_id']
  81. model_version = config_data['job_parameters']['model_version']
  82. for serving in config_data['servings']:
  83. with grpc.insecure_channel(serving, GRPC_OPTIONS) as channel:
  84. stub = model_service_pb2_grpc.ModelServiceStub(channel)
  85. publish_model_request = model_service_pb2.PublishRequest()
  86. publish_model_request.serviceId = service_id
  87. # {"role": {"guest": ["9999"], "host": ["10000"], "arbiter": ["9999"]}}
  88. for role_name, role_party in config_data.get("role").items():
  89. publish_model_request.role[role_name].partyId.extend([str(party_id) for party_id in role_party])
  90. party_model_id = model_utils.gen_party_model_id(model_id, initiator_role, initiator_party_id)
  91. publish_model_request.model[initiator_role].roleModelInfo[initiator_party_id].tableName = model_version
  92. publish_model_request.model[initiator_role].roleModelInfo[initiator_party_id].namespace = party_model_id
  93. publish_model_request.local.role = initiator_role
  94. publish_model_request.local.partyId = initiator_party_id
  95. stat_logger.info(publish_model_request)
  96. response = stub.publishBind(publish_model_request)
  97. stat_logger.info(response)
  98. if response.statusCode != 0:
  99. return response.statusCode, response.message
  100. return 0, None
  101. def download_model(party_model_id, model_version):
  102. if ENABLE_MODEL_STORE:
  103. sync_model = SyncModel(
  104. party_model_id=party_model_id,
  105. model_version=model_version,
  106. )
  107. if sync_model.remote_exists():
  108. sync_model.download(True)
  109. model = pipelined_model.PipelinedModel(party_model_id, model_version)
  110. if not model.exists():
  111. return {}
  112. return model.collect_models(in_bytes=True)
  113. def convert_homo_model(request_data):
  114. party_model_id = model_utils.gen_party_model_id(model_id=request_data["model_id"],
  115. role=request_data["role"],
  116. party_id=request_data["party_id"])
  117. model_version = request_data.get("model_version")
  118. model = pipelined_model.PipelinedModel(model_id=party_model_id, model_version=model_version)
  119. if not model.exists():
  120. return 100, 'Model {} {} does not exist'.format(party_model_id, model_version), None
  121. define_meta = model.pipelined_component.get_define_meta()
  122. framework_name = request_data.get("framework_name")
  123. detail = []
  124. # todo: use subprocess?
  125. convert_tool = model.get_homo_model_convert_tool()
  126. for key, value in define_meta.get("model_proto", {}).items():
  127. if key == 'pipeline':
  128. continue
  129. for model_alias in value.keys():
  130. buffer_obj = model.read_component_model(key, model_alias)
  131. module_name = define_meta.get("component_define", {}).get(key, {}).get('module_name')
  132. converted_framework, converted_model = convert_tool.model_convert(model_contents=buffer_obj,
  133. module_name=module_name,
  134. framework_name=framework_name)
  135. if converted_model:
  136. converted_model_dir = os.path.join(model.pipelined_component.variables_data_path, key, model_alias, "converted_model")
  137. os.makedirs(converted_model_dir, exist_ok=True)
  138. saved_path = convert_tool.save_converted_model(converted_model,
  139. converted_framework,
  140. converted_model_dir)
  141. detail.append({
  142. "component_name": key,
  143. "model_alias": model_alias,
  144. "converted_model_path": saved_path
  145. })
  146. if len(detail) > 0:
  147. return (0,
  148. f"Conversion of homogeneous federated learning component(s) in model "
  149. f"{party_model_id}:{model_version} completed. Use export or homo/deploy "
  150. f"to download or deploy the converted model.",
  151. detail)
  152. else:
  153. return 100, f"No component in model {party_model_id}:{model_version} can be converted.", None
  154. def deploy_homo_model(request_data):
  155. party_model_id = model_utils.gen_party_model_id(model_id=request_data["model_id"],
  156. role=request_data["role"],
  157. party_id=request_data["party_id"])
  158. model_version = request_data["model_version"]
  159. component_name = request_data['component_name']
  160. service_id = request_data['service_id']
  161. framework_name = request_data.get('framework_name')
  162. model = pipelined_model.PipelinedModel(model_id=party_model_id, model_version=model_version)
  163. if not model.exists():
  164. return 100, 'Model {} {} does not exist'.format(party_model_id, model_version), None
  165. # get the model alias from the dsl saved with the pipeline
  166. pipeline = model.read_pipeline_model()
  167. train_dsl = json_loads(pipeline.train_dsl)
  168. if component_name not in train_dsl.get('components', {}):
  169. return 100, 'Model {} {} does not contain component {}'.\
  170. format(party_model_id, model_version, component_name), None
  171. model_alias_list = train_dsl['components'][component_name].get('output', {}).get('model')
  172. if not model_alias_list:
  173. return 100, 'Component {} in Model {} {} does not have output model'. \
  174. format(component_name, party_model_id, model_version), None
  175. # currently there is only one model output
  176. model_alias = model_alias_list[0]
  177. converted_model_dir = os.path.join(model.pipelined_component.variables_data_path, component_name, model_alias, "converted_model")
  178. if not os.path.isdir(converted_model_dir):
  179. return 100, '''Component {} in Model {} {} isn't converted'''.\
  180. format(component_name, party_model_id, model_version), None
  181. # todo: use subprocess?
  182. convert_tool = model.get_homo_model_convert_tool()
  183. if not framework_name:
  184. module_name = train_dsl['components'][component_name].get('module')
  185. buffer_obj = model.read_component_model(component_name, model_alias)
  186. framework_name = convert_tool.get_default_target_framework(model_contents=buffer_obj, module_name=module_name)
  187. model_object = convert_tool.load_converted_model(base_dir=converted_model_dir,
  188. framework_name=framework_name)
  189. deployed_service = model_deploy(party_model_id,
  190. model_version,
  191. model_object,
  192. framework_name,
  193. service_id,
  194. request_data['deployment_type'],
  195. request_data['deployment_parameters'])
  196. return (0,
  197. f"An online serving service is started in the {request_data['deployment_type']} system.",
  198. deployed_service)