model_loader.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
#
# Copyright 2021 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. from fate_flow.components._base import (
  17. BaseParam, ComponentBase,
  18. ComponentInputProtocol, ComponentMeta,
  19. )
  20. from fate_flow.entity import MetricMeta, MetricType
  21. from fate_flow.model.checkpoint import CheckpointManager
  22. from fate_flow.model.sync_model import SyncComponent
  23. from fate_flow.pipelined_model.pipelined_model import PipelinedModel
  24. from fate_flow.settings import ENABLE_MODEL_STORE
  25. from fate_flow.utils.log_utils import getLogger
  26. from fate_flow.utils.model_utils import gen_party_model_id
# Module-level logger obtained from fate_flow's `getLogger` helper.
LOGGER = getLogger()

# Component metadata registered under the name 'ModelLoader'. Its `bind_runner`
# and `bind_param` attributes are used as class decorators in this module.
model_loader_cpn_meta = ComponentMeta('ModelLoader')
  29. @model_loader_cpn_meta.bind_runner.on_guest.on_host.on_arbiter
  30. class ModelLoader(ComponentBase):
  31. """ ModelLoader is a component for loading models trained by previous jobs.
  32. `self.model_id`, `self.model_version`, `self.component_name` and `self.model_alias`
  33. come from the previous job. However, most of the data in `self.tracker` belongs to the current job.
  34. Such as `self.tracker.job_id`, `self.tracker.task_id`, `self.tracker.task_version`, etc.
  35. Be careful when using them.
  36. """
  37. def __init__(self):
  38. super().__init__()
  39. self.serialize = False
  40. self.model_id = None
  41. self.model_version = None
  42. self.component_name = None
  43. self.model_alias = None
  44. self.step_index = None
  45. self.step_name = None
  46. def read_component_model(self):
  47. pipelined_model = PipelinedModel(gen_party_model_id(
  48. self.model_id, self.tracker.role, self.tracker.party_id
  49. ), self.model_version)
  50. if self.model_alias is None:
  51. self.model_alias = pipelined_model.get_model_alias(self.component_name)
  52. component_model = pipelined_model._read_component_model(self.component_name, self.model_alias)
  53. if not component_model:
  54. raise ValueError('The component model is empty.')
  55. self.model_output = component_model
  56. self.tracker.set_metric_meta('model_loader', f'{self.component_name}-{self.model_alias}',
  57. MetricMeta('component_model', MetricType.COMPONENT_MODEL_INFO, {
  58. 'model_id': self.model_id,
  59. 'model_version': self.model_version,
  60. 'component_name': self.component_name,
  61. 'model_alias': self.model_alias,
  62. }))
  63. def read_checkpoint(self):
  64. checkpoint_manager = CheckpointManager(
  65. role=self.tracker.role, party_id=self.tracker.party_id,
  66. model_id=self.model_id, model_version=self.model_version,
  67. component_name=self.component_name,
  68. )
  69. checkpoint_manager.load_checkpoints_from_disk()
  70. if self.step_index is not None:
  71. checkpoint = checkpoint_manager.get_checkpoint_by_index(self.step_index)
  72. elif self.step_name is not None:
  73. checkpoint = checkpoint_manager.get_checkpoint_by_name(self.step_name)
  74. else:
  75. checkpoint = checkpoint_manager.latest_checkpoint
  76. if checkpoint is None:
  77. raise ValueError('The checkpoint was not found.')
  78. data = checkpoint.read(include_database=True)
  79. data['model_id'] = checkpoint_manager.model_id
  80. data['model_version'] = checkpoint_manager.model_version
  81. data['component_name'] = checkpoint_manager.component_name
  82. self.model_output = data.pop('models')
  83. self.tracker.set_metric_meta('model_loader', f'{checkpoint.step_index}-{checkpoint.step_name}',
  84. MetricMeta('checkpoint', MetricType.CHECKPOINT_INFO, data))
  85. def _run(self, cpn_input: ComponentInputProtocol):
  86. need_run = cpn_input.parameters.get('need_run', True)
  87. if not need_run:
  88. return
  89. for k in ('model_id', 'model_version', 'component_name'):
  90. v = cpn_input.parameters.get(k)
  91. if v is None:
  92. raise KeyError(f'The component ModelLoader needs "{k}"')
  93. setattr(self, k, v)
  94. for k in ('model_alias', 'step_index', 'step_name'):
  95. v = cpn_input.parameters.get(k)
  96. if v is not None:
  97. setattr(self, k, v)
  98. break
  99. if ENABLE_MODEL_STORE:
  100. sync_component = SyncComponent(
  101. role=self.tracker.role, party_id=self.tracker.party_id,
  102. model_id=self.model_id, model_version=self.model_version,
  103. component_name=self.component_name,
  104. )
  105. sync_component.download()
  106. if self.model_alias is not None:
  107. return self.read_component_model()
  108. if self.step_index is not None or self.step_name is not None:
  109. return self.read_checkpoint()
  110. try:
  111. return self.read_component_model()
  112. except Exception:
  113. try:
  114. return self.read_checkpoint()
  115. except Exception:
  116. raise EnvironmentError('Unable to find component model and checkpoint. '
  117. 'Try specifying "model_alias", "step_index" or "step_name".')
  118. @model_loader_cpn_meta.bind_param
  119. class ModelLoaderParam(BaseParam):
  120. def __init__(self, model_id: str = None, model_version: str = None, component_name: str = None,
  121. model_alias: str = None, step_index: int = None, step_name: str = None, need_run: bool = True):
  122. self.model_id = model_id
  123. self.model_version = model_version
  124. self.component_name = component_name
  125. self.model_alias = model_alias
  126. self.step_index = step_index
  127. self.step_name = step_name
  128. self.need_run = need_run
  129. if self.step_index is not None:
  130. self.step_index = int(self.step_index)
  131. def check(self):
  132. for i in ('model_id', 'model_version', 'component_name'):
  133. if getattr(self, i) is None:
  134. raise KeyError(f"The parameter '{i}' is required.")