# pipelined_model.py
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import base64
  17. import hashlib
  18. import json
  19. import os
  20. import shutil
  21. import typing
  22. from google.protobuf import json_format
  23. from fate_arch.common.base_utils import json_dumps, json_loads
  24. from fate_flow.component_env_utils import provider_utils
  25. from fate_flow.db.runtime_config import RuntimeConfig
  26. from fate_flow.model import (
  27. Locker, local_cache_required,
  28. lock, parse_proto_object,
  29. )
  30. from fate_flow.pipelined_model.pipelined_component import PipelinedComponent
  31. from fate_flow.protobuf.python.pipeline_pb2 import Pipeline
  32. from fate_flow.settings import TEMP_DIRECTORY, stat_logger
  33. from fate_flow.utils.job_utils import (
  34. PIPELINE_COMPONENT_NAME, PIPELINE_MODEL_ALIAS,
  35. PIPELINE_COMPONENT_MODULE_NAME, PIPELINE_MODEL_NAME,
  36. )
  37. from fate_flow.utils.base_utils import get_fate_flow_directory
  38. class PipelinedModel(Locker):
  39. def __init__(self, model_id, model_version):
  40. """
  41. Support operations on FATE PipelinedModels
  42. :param model_id: the model id stored at the local party.
  43. :param model_version: the model version.
  44. """
  45. os.makedirs(TEMP_DIRECTORY, exist_ok=True)
  46. self.role, self.party_id, self._model_id = model_id.split('#', 2)
  47. self.party_model_id = self.model_id = model_id
  48. self.model_version = model_version
  49. self.pipelined_component = PipelinedComponent(role=self.role, party_id=self.party_id,
  50. model_id=self._model_id, model_version=self.model_version)
  51. self.model_path = self.pipelined_component.model_path
  52. super().__init__(self.model_path)
  53. def save_pipeline_model(self, pipeline_buffer_object, save_define_meta_file=True):
  54. model_buffers = {
  55. PIPELINE_MODEL_NAME: (
  56. type(pipeline_buffer_object).__name__,
  57. pipeline_buffer_object.SerializeToString(),
  58. json_format.MessageToDict(pipeline_buffer_object, including_default_value_fields=True),
  59. ),
  60. }
  61. self.save_component_model(PIPELINE_COMPONENT_NAME, PIPELINE_COMPONENT_MODULE_NAME, PIPELINE_MODEL_ALIAS, model_buffers)
  62. # only update pipeline model file if save_define_meta_file is False
  63. if save_define_meta_file:
  64. self.pipelined_component.save_define_meta_from_db_to_file()
  65. def save_component_model(self, *args, **kwargs):
  66. component_model = self.create_component_model(*args, **kwargs)
  67. self.write_component_model(component_model)
  68. def create_component_model(self, component_name, component_module_name, model_alias,
  69. model_buffers: typing.Dict[str, typing.Tuple[str, bytes, dict]],
  70. user_specified_run_parameters: dict = None):
  71. component_model = {"buffer": {}}
  72. component_model_storage_path = os.path.join(self.pipelined_component.variables_data_path, component_name, model_alias)
  73. model_proto_index = {}
  74. for model_name, (proto_index, object_serialized, object_json) in model_buffers.items():
  75. storage_path = os.path.join(component_model_storage_path, model_name).replace(get_fate_flow_directory(), "")
  76. component_model["buffer"][storage_path] = (base64.b64encode(object_serialized).decode(), object_json)
  77. model_proto_index[model_name] = proto_index # index of model name and proto buffer class name
  78. stat_logger.info(f"saved {component_name} {model_alias} {model_name} buffer")
  79. component_model["component_name"] = component_name
  80. component_model["component_module_name"] = component_module_name
  81. component_model["model_alias"] = model_alias
  82. component_model["model_proto_index"] = model_proto_index
  83. component_model["run_parameters"] = user_specified_run_parameters
  84. return component_model
  85. @lock
  86. def write_component_model(self, component_model):
  87. for storage_path, (object_serialized_encoded, object_json) in component_model.get("buffer").items():
  88. storage_path = get_fate_flow_directory() + storage_path
  89. os.makedirs(os.path.dirname(storage_path), exist_ok=True)
  90. with open(storage_path, "wb") as fw:
  91. fw.write(base64.b64decode(object_serialized_encoded.encode()))
  92. with open(f"{storage_path}.json", "w", encoding="utf8") as fw:
  93. fw.write(json_dumps(object_json))
  94. self.pipelined_component.save_define_meta(
  95. component_model["component_name"], component_model["component_module_name"],
  96. component_model["model_alias"], component_model["model_proto_index"],
  97. component_model.get("run_parameters") or {},
  98. )
  99. stat_logger.info(f'saved {component_model["component_name"]} {component_model["model_alias"]} successfully')
  100. @local_cache_required(True)
  101. def _read_component_model(self, component_name, model_alias):
  102. component_model_storage_path = os.path.join(self.pipelined_component.variables_data_path, component_name, model_alias)
  103. model_proto_index = self.get_model_proto_index(component_name=component_name, model_alias=model_alias)
  104. model_buffers = {}
  105. for model_name, buffer_name in model_proto_index.items():
  106. storage_path = os.path.join(component_model_storage_path, model_name)
  107. with open(storage_path, "rb") as f:
  108. buffer_object_serialized_string = f.read()
  109. try:
  110. with open(f"{storage_path}.json", encoding="utf-8") as f:
  111. buffer_object_json_format = json_loads(f.read())
  112. except FileNotFoundError:
  113. buffer_object_json_format = ""
  114. # TODO: should be running in worker
  115. """
  116. buffer_object_json_format = json_format.MessageToDict(
  117. parse_proto_object(buffer_name, buffer_object_serialized_string),
  118. including_default_value_fields=True
  119. )
  120. with open(f"{storage_path}.json", "x", encoding="utf-8") as f:
  121. f.write(json_dumps(buffer_object_json_format))
  122. """
  123. model_buffers[model_name] = (
  124. buffer_name,
  125. buffer_object_serialized_string,
  126. buffer_object_json_format,
  127. )
  128. return model_buffers
  129. # TODO: use different functions instead of passing arguments
  130. def read_component_model(self, component_name, model_alias=None, parse=True, output_json=False):
  131. if model_alias is None:
  132. model_alias = self.get_model_alias(component_name)
  133. if not self.pipelined_component.exists(component_name, model_alias):
  134. return {}
  135. _model_buffers = self._read_component_model(component_name, model_alias)
  136. model_buffers = {}
  137. for model_name, (
  138. buffer_name,
  139. buffer_object_serialized_string,
  140. buffer_object_json_format,
  141. ) in _model_buffers.items():
  142. if output_json:
  143. model_buffers[model_name] = buffer_object_json_format
  144. elif parse:
  145. model_buffers[model_name] = parse_proto_object(buffer_name, buffer_object_serialized_string)
  146. else:
  147. model_buffers[model_name] = [
  148. buffer_name,
  149. base64.b64encode(buffer_object_serialized_string).decode(),
  150. ]
  151. return model_buffers
  152. # TODO: integration with read_component_model
  153. @local_cache_required(True)
  154. def read_pipeline_model(self, parse=True):
  155. component_model_storage_path = os.path.join(self.pipelined_component.variables_data_path, PIPELINE_COMPONENT_NAME, PIPELINE_MODEL_ALIAS)
  156. model_proto_index = self.get_model_proto_index(PIPELINE_COMPONENT_NAME, PIPELINE_MODEL_ALIAS)
  157. model_buffers = {}
  158. for model_name, buffer_name in model_proto_index.items():
  159. with open(os.path.join(component_model_storage_path, model_name), "rb") as fr:
  160. buffer_object_serialized_string = fr.read()
  161. model_buffers[model_name] = (parse_proto_object(buffer_name, buffer_object_serialized_string, Pipeline) if parse
  162. else [buffer_name, base64.b64encode(buffer_object_serialized_string).decode()])
  163. return model_buffers[PIPELINE_MODEL_NAME]
  164. @local_cache_required(True)
  165. def collect_models(self, in_bytes=False, b64encode=True):
  166. define_meta = self.pipelined_component.get_define_meta()
  167. model_buffers = {}
  168. for component_name in define_meta.get("model_proto", {}).keys():
  169. for model_alias, model_proto_index in define_meta["model_proto"][component_name].items():
  170. component_model_storage_path = os.path.join(self.pipelined_component.variables_data_path, component_name, model_alias)
  171. for model_name, buffer_name in model_proto_index.items():
  172. with open(os.path.join(component_model_storage_path, model_name), "rb") as fr:
  173. serialized_string = fr.read()
  174. if in_bytes:
  175. if b64encode:
  176. serialized_string = base64.b64encode(serialized_string).decode()
  177. model_buffers[f"{component_name}.{model_alias}:{model_name}"] = serialized_string
  178. else:
  179. model_buffers[model_name] = parse_proto_object(buffer_name, serialized_string)
  180. return model_buffers
  181. @staticmethod
  182. def get_model_migrate_tool():
  183. return provider_utils.get_provider_class_object(RuntimeConfig.COMPONENT_PROVIDER, "model_migrate", True)
  184. @staticmethod
  185. def get_homo_model_convert_tool():
  186. return provider_utils.get_provider_class_object(RuntimeConfig.COMPONENT_PROVIDER, "homo_model_convert", True)
  187. def exists(self):
  188. return self.pipelined_component.exists()
  189. @local_cache_required(True)
  190. def packaging_model(self):
  191. self.gen_model_import_config()
  192. # self.archive_model_file_path
  193. shutil.make_archive(self.archive_model_base_path, 'zip', self.model_path)
  194. with open(self.archive_model_file_path, 'rb') as f:
  195. hash_ = hashlib.sha256(f.read()).hexdigest()
  196. stat_logger.info(f'Make model {self.model_id} {self.model_version} archive successfully. '
  197. f'path: {self.archive_model_file_path} hash: {hash_}')
  198. return hash_
  199. @lock
  200. def unpack_model(self, archive_file_path: str, force_update: bool = False, hash_: str = None):
  201. if self.exists() and not force_update:
  202. raise FileExistsError(f'Model {self.model_id} {self.model_version} local cache already existed.')
  203. if hash_ is not None:
  204. with open(archive_file_path, 'rb') as f:
  205. sha256 = hashlib.sha256(f.read()).hexdigest()
  206. if hash_ != sha256:
  207. raise ValueError(f'Model archive hash mismatch. '
  208. f'path: {archive_file_path} expected: {hash_} actual: {sha256}')
  209. shutil.unpack_archive(archive_file_path, self.model_path, 'zip')
  210. stat_logger.info(f'Unpack model {self.model_id} {self.model_version} archive successfully. path: {self.model_path}')
  211. def get_component_define(self, component_name=None):
  212. component_define = self.pipelined_component.get_define_meta()['component_define']
  213. if component_name is None:
  214. return component_define
  215. return component_define.get(component_name, {})
  216. def get_model_proto_index(self, component_name=None, model_alias=None):
  217. model_proto = self.pipelined_component.get_define_meta()['model_proto']
  218. if component_name is None:
  219. return model_proto
  220. model_proto = model_proto.get(component_name, {})
  221. if model_alias is None:
  222. return model_proto
  223. return model_proto.get(model_alias, {})
  224. def get_model_alias(self, component_name):
  225. model_proto_index = self.get_model_proto_index(component_name)
  226. if len(model_proto_index) != 1:
  227. raise KeyError('Failed to detect "model_alias", please specify it manually.')
  228. return next(iter(model_proto_index.keys()))
  229. @property
  230. def archive_model_base_path(self):
  231. return os.path.join(TEMP_DIRECTORY, f'{self.party_model_id}_{self.model_version}')
  232. @property
  233. def archive_model_file_path(self):
  234. return f'{self.archive_model_base_path}.zip'
  235. @local_cache_required(True)
  236. def calculate_model_file_size(self):
  237. size = 0
  238. for root, dirs, files in os.walk(self.model_path):
  239. size += sum([os.path.getsize(os.path.join(root, name)) for name in files])
  240. return round(size/1024)
  241. @local_cache_required(True)
  242. def gen_model_import_config(self):
  243. config = {
  244. 'role': self.role,
  245. 'party_id': int(self.party_id),
  246. 'model_id': self._model_id,
  247. 'model_version': self.model_version,
  248. 'file': self.archive_model_file_path,
  249. 'force_update': False,
  250. }
  251. with (self.model_path / 'import_model.json').open('w', encoding='utf-8') as f:
  252. json.dump(config, f, indent=4)