model_base.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import copy
  19. import typing
  20. import numpy as np
  21. from google.protobuf import json_format
  22. from fate_arch.computing import is_table
  23. from federatedml.callbacks.callback_list import CallbackList
  24. from federatedml.feature.instance import Instance
  25. from federatedml.param.evaluation_param import EvaluateParam
  26. from federatedml.protobuf import deserialize_models
  27. from federatedml.statistic.data_overview import header_alignment, predict_detail_dict_to_str
  28. from federatedml.util import LOGGER, abnormal_detection
  29. from federatedml.util.anonymous_generator_util import Anonymous
  30. from federatedml.util.component_properties import ComponentProperties, RunningFuncs
  31. from federatedml.util.io_check import assert_match_id_consistent
  32. def serialize_models(models):
  33. serialized_models: typing.Dict[str, typing.Tuple[str, bytes, dict]] = {}
  34. for model_name, buffer_object in models.items():
  35. serialized_string = buffer_object.SerializeToString()
  36. pb_name = type(buffer_object).__name__
  37. json_format_dict = json_format.MessageToDict(
  38. buffer_object, including_default_value_fields=True
  39. )
  40. serialized_models[model_name] = (
  41. pb_name,
  42. serialized_string,
  43. json_format_dict,
  44. )
  45. return serialized_models
  46. class ComponentOutput:
  47. def __init__(self, data, models, cache: typing.List[tuple]) -> None:
  48. self._data = data
  49. if not isinstance(self._data, list):
  50. self._data = [data]
  51. self._models = models
  52. if self._models is None:
  53. self._models = {}
  54. self._cache = cache
  55. if not isinstance(self._cache, list):
  56. self._cache = [cache]
  57. @property
  58. def data(self) -> list:
  59. return self._data
  60. @property
  61. def model(self):
  62. return serialize_models(self._models)
  63. @property
  64. def cache(self):
  65. return self._cache
  66. class MetricType:
  67. LOSS = "LOSS"
  68. class Metric:
  69. def __init__(self, key, value: float, timestamp: float = None):
  70. self.key = key
  71. self.value = value
  72. self.timestamp = timestamp
  73. def to_dict(self):
  74. return dict(key=self.key, value=self.value, timestamp=self.timestamp)
  75. class MetricMeta:
  76. def __init__(self, name: str, metric_type: MetricType, extra_metas: dict = None):
  77. self.name = name
  78. self.metric_type = metric_type
  79. self.metas = {}
  80. self.extra_metas = extra_metas
  81. def update_metas(self, metas: dict):
  82. self.metas.update(metas)
  83. def to_dict(self):
  84. return dict(
  85. name=self.name,
  86. metric_type=self.metric_type,
  87. metas=self.metas,
  88. extra_metas=self.extra_metas,
  89. )
  90. class CallbacksVariable(object):
  91. def __init__(self):
  92. self.stop_training = False
  93. self.best_iteration = -1
  94. self.validation_summary = None
  95. class WarpedTrackerClient:
  96. def __init__(self, tracker) -> None:
  97. self._tracker = tracker
  98. def log_metric_data(
  99. self, metric_namespace: str, metric_name: str, metrics: typing.List[Metric]
  100. ):
  101. return self._tracker.log_metric_data(
  102. metric_namespace=metric_namespace,
  103. metric_name=metric_name,
  104. metrics=[metric.to_dict() for metric in metrics],
  105. )
  106. def set_metric_meta(
  107. self, metric_namespace: str, metric_name: str, metric_meta: MetricMeta
  108. ):
  109. return self._tracker.set_metric_meta(
  110. metric_namespace=metric_namespace,
  111. metric_name=metric_name,
  112. metric_meta=metric_meta.to_dict(),
  113. )
  114. def log_component_summary(self, summary_data: dict):
  115. return self._tracker.log_component_summary(summary_data=summary_data)
  116. class ModelBase(object):
  117. component_name = None
  118. @classmethod
  119. def set_component_name(cls, name):
  120. cls.component_name = name
  121. @classmethod
  122. def get_component_name(cls):
  123. return cls.component_name
  124. def __init__(self):
  125. self.model_output = None
  126. self.mode = None
  127. self.role = None
  128. self.data_output = None
  129. self.cache_output = None
  130. self.model_param = None
  131. self.transfer_variable = None
  132. self.flowid = ""
  133. self.task_version_id = ""
  134. self.need_one_vs_rest = False
  135. self.callback_one_vs_rest = False
  136. self.checkpoint_manager = None
  137. self.cv_fold = 0
  138. self.validation_freqs = None
  139. self.component_properties = ComponentProperties()
  140. self._summary = dict()
  141. self._align_cache = dict()
  142. self._tracker = None
  143. self.step_name = "step_name"
  144. self.callback_list: CallbackList
  145. self.callback_variables = CallbacksVariable()
  146. self.anonymous_generator = None
  147. @property
  148. def tracker(self) -> WarpedTrackerClient:
  149. if self._tracker is None:
  150. raise RuntimeError(f"use tracker before set")
  151. return self._tracker
  152. @tracker.setter
  153. def tracker(self, value):
  154. self._tracker = WarpedTrackerClient(value)
  155. @property
  156. def stop_training(self):
  157. return self.callback_variables.stop_training
  158. @property
  159. def need_cv(self):
  160. return self.component_properties.need_cv
  161. @property
  162. def need_run(self):
  163. return self.component_properties.need_run
  164. @need_run.setter
  165. def need_run(self, value: bool):
  166. self.component_properties.need_run = value
  167. def _init_model(self, model):
  168. pass
  169. def load_model(self, model_dict):
  170. pass
  171. def _parse_need_run(self, model_dict, model_meta_name):
  172. meta_obj = list(model_dict.get("model").values())[0].get(model_meta_name)
  173. need_run = meta_obj.need_run
  174. # self.need_run = need_run
  175. self.component_properties.need_run = need_run
  176. def run(self, cpn_input, retry: bool = True):
  177. self.task_version_id = cpn_input.task_version_id
  178. self.tracker = cpn_input.tracker
  179. self.checkpoint_manager = cpn_input.checkpoint_manager
  180. deserialize_models(cpn_input.models)
  181. # retry
  182. if (
  183. retry
  184. and hasattr(self, '_retry')
  185. and callable(self._retry)
  186. and self.checkpoint_manager is not None
  187. and self.checkpoint_manager.latest_checkpoint is not None
  188. ):
  189. self._retry(cpn_input=cpn_input)
  190. # normal
  191. else:
  192. self._run(cpn_input=cpn_input)
  193. return ComponentOutput(self.save_data(), self._export(), self.save_cache())
  194. def _export(self):
  195. # export model
  196. try:
  197. model = self._export_model()
  198. meta = self._export_meta()
  199. export_dict = {"Meta": meta, "Param": model}
  200. except NotImplementedError:
  201. export_dict = self.export_model()
  202. # export nothing, return
  203. if export_dict is None:
  204. return export_dict
  205. try:
  206. meta_name = [k for k in export_dict if k.endswith("Meta")][0]
  207. except BaseException:
  208. raise KeyError("Meta not found in export model")
  209. try:
  210. param_name = [k for k in export_dict if k.endswith("Param")][0]
  211. except BaseException:
  212. raise KeyError("Param not found in export model")
  213. meta = export_dict[meta_name]
  214. # set component name
  215. if hasattr(meta, "component"):
  216. meta.component = self.get_component_name()
  217. else:
  218. import warnings
  219. warnings.warn(f"{meta} should add `component` field")
  220. return export_dict
  221. def _export_meta(self):
  222. raise NotImplementedError("_export_meta not implemented")
  223. def _export_model(self):
  224. raise NotImplementedError("_export_model not implemented")
  225. def _run(self, cpn_input) -> None:
  226. # paramters
  227. self.model_param.update(cpn_input.parameters)
  228. self.model_param.check()
  229. self.component_properties.parse_component_param(
  230. cpn_input.roles, self.model_param
  231. )
  232. self.role = self.component_properties.role
  233. self.component_properties.parse_dsl_args(cpn_input.datasets, cpn_input.models)
  234. self.component_properties.parse_caches(cpn_input.caches)
  235. self.anonymous_generator = Anonymous(role=self.role, party_id=self.component_properties.local_partyid)
  236. # init component, implemented by subclasses
  237. self._init_model(self.model_param)
  238. self.callback_list = CallbackList(self.role, self.mode, self)
  239. if hasattr(self.model_param, "callback_param"):
  240. callback_param = getattr(self.model_param, "callback_param")
  241. self.callback_list.init_callback_list(callback_param)
  242. running_funcs = self.component_properties.extract_running_rules(
  243. datasets=cpn_input.datasets, models=cpn_input.models, cpn=self
  244. )
  245. LOGGER.debug(f"running_funcs: {running_funcs.todo_func_list}")
  246. saved_result = []
  247. for func, params, save_result, use_previews in running_funcs:
  248. # for func, params in zip(todo_func_list, todo_func_params):
  249. if use_previews:
  250. if params:
  251. real_param = [saved_result, params]
  252. else:
  253. real_param = saved_result
  254. LOGGER.debug("func: {}".format(func))
  255. this_data_output = func(*real_param)
  256. saved_result = []
  257. else:
  258. this_data_output = func(*params)
  259. if save_result:
  260. saved_result.append(this_data_output)
  261. if len(saved_result) == 1:
  262. self.data_output = saved_result[0]
  263. # LOGGER.debug("One data: {}".format(self.data_output.first()[1].features))
  264. LOGGER.debug(
  265. "saved_result is : {}, data_output: {}".format(
  266. saved_result, self.data_output
  267. )
  268. )
  269. # self.check_consistency()
  270. self.save_summary()
  271. def _retry(self, cpn_input) -> None:
  272. self.model_param.update(cpn_input.parameters)
  273. self.model_param.check()
  274. self.component_properties.parse_component_param(
  275. cpn_input.roles, self.model_param
  276. )
  277. self.role = self.component_properties.role
  278. self.component_properties.parse_dsl_args(cpn_input.datasets, cpn_input.models)
  279. self.component_properties.parse_caches(cpn_input.caches)
  280. # init component, implemented by subclasses
  281. self._init_model(self.model_param)
  282. self.callback_list = CallbackList(self.role, self.mode, self)
  283. if hasattr(self.model_param, "callback_param"):
  284. callback_param = getattr(self.model_param, "callback_param")
  285. self.callback_list.init_callback_list(callback_param)
  286. (
  287. train_data,
  288. validate_data,
  289. test_data,
  290. data,
  291. ) = self.component_properties.extract_input_data(
  292. datasets=cpn_input.datasets, model=self
  293. )
  294. running_funcs = RunningFuncs()
  295. latest_checkpoint = self.get_latest_checkpoint()
  296. running_funcs.add_func(self.load_model, [latest_checkpoint])
  297. running_funcs = self.component_properties.warm_start_process(
  298. running_funcs, self, train_data, validate_data
  299. )
  300. LOGGER.debug(f"running_funcs: {running_funcs.todo_func_list}")
  301. self._execute_running_funcs(running_funcs)
  302. def _execute_running_funcs(self, running_funcs):
  303. saved_result = []
  304. for func, params, save_result, use_previews in running_funcs:
  305. # for func, params in zip(todo_func_list, todo_func_params):
  306. if use_previews:
  307. if params:
  308. real_param = [saved_result, params]
  309. else:
  310. real_param = saved_result
  311. LOGGER.debug("func: {}".format(func))
  312. detected_func = assert_match_id_consistent(func)
  313. this_data_output = detected_func(*real_param)
  314. saved_result = []
  315. else:
  316. detected_func = assert_match_id_consistent(func)
  317. this_data_output = detected_func(*params)
  318. if save_result:
  319. saved_result.append(this_data_output)
  320. if len(saved_result) == 1:
  321. self.data_output = saved_result[0]
  322. LOGGER.debug(
  323. "saved_result is : {}, data_output: {}".format(
  324. saved_result, self.data_output
  325. )
  326. )
  327. self.save_summary()
  328. def export_serialized_models(self):
  329. return serialize_models(self.export_model())
  330. def get_metrics_param(self):
  331. return EvaluateParam(eval_type="binary", pos_label=1)
  332. def check_consistency(self):
  333. if not is_table(self.data_output):
  334. return
  335. if (
  336. self.component_properties.input_data_count
  337. + self.component_properties.input_eval_data_count
  338. != self.data_output.count()
  339. and self.component_properties.input_data_count
  340. != self.component_properties.input_eval_data_count
  341. ):
  342. raise ValueError("Input data count does not match with output data count")
  343. def predict(self, data_inst):
  344. pass
  345. def fit(self, *args):
  346. pass
  347. def transform(self, data_inst):
  348. pass
  349. def cross_validation(self, data_inst):
  350. pass
  351. def stepwise(self, data_inst):
  352. pass
  353. def one_vs_rest_fit(self, train_data=None):
  354. pass
  355. def one_vs_rest_predict(self, train_data):
  356. pass
  357. def init_validation_strategy(self, train_data=None, validate_data=None):
  358. pass
  359. def save_data(self):
  360. return self.data_output
  361. def export_model(self):
  362. return self.model_output
  363. def save_cache(self):
  364. return self.cache_output
  365. def set_flowid(self, flowid):
  366. # self.flowid = '.'.join([self.task_version_id, str(flowid)])
  367. self.flowid = flowid
  368. self.set_transfer_variable()
  369. def set_transfer_variable(self):
  370. if self.transfer_variable is not None:
  371. LOGGER.debug(
  372. "set flowid to transfer_variable, flowid: {}".format(self.flowid)
  373. )
  374. self.transfer_variable.set_flowid(self.flowid)
  375. def set_task_version_id(self, task_version_id):
  376. """task_version_id: jobid + component_name, reserved variable"""
  377. self.task_version_id = task_version_id
  378. def get_metric_name(self, name_prefix):
  379. if not self.need_cv:
  380. return name_prefix
  381. return "_".join(map(str, [name_prefix, self.flowid]))
  382. def set_tracker(self, tracker):
  383. self._tracker = tracker
  384. def set_checkpoint_manager(self, checkpoint_manager):
  385. checkpoint_manager.load_checkpoints_from_disk()
  386. self.checkpoint_manager = checkpoint_manager
  387. @staticmethod
  388. def set_predict_data_schema(predict_datas, schemas):
  389. if predict_datas is None:
  390. return predict_datas
  391. if isinstance(predict_datas, list):
  392. predict_data = predict_datas[0]
  393. schema = schemas[0]
  394. else:
  395. predict_data = predict_datas
  396. schema = schemas
  397. if predict_data is not None:
  398. predict_data.schema = {
  399. "header": [
  400. "label",
  401. "predict_result",
  402. "predict_score",
  403. "predict_detail",
  404. "type",
  405. ],
  406. "sid": schema.get("sid"),
  407. "content_type": "predict_result"
  408. }
  409. if schema.get("match_id_name") is not None:
  410. predict_data.schema["match_id_name"] = schema.get("match_id_name")
  411. return predict_data
  412. @staticmethod
  413. def predict_score_to_output(
  414. data_instances, predict_score, classes=None, threshold=0.5
  415. ):
  416. """
  417. Get predict result output
  418. Parameters
  419. ----------
  420. data_instances: table, data used for prediction
  421. predict_score: table, probability scores
  422. classes: list or None, all classes/label names
  423. threshold: float, predict threshold, used for binary label
  424. Returns
  425. -------
  426. Table, predict result
  427. """
  428. # regression
  429. if classes is None:
  430. predict_result = data_instances.join(
  431. predict_score, lambda d, pred: [d.label,
  432. pred,
  433. pred,
  434. predict_detail_dict_to_str({"label": pred})]
  435. )
  436. # binary
  437. elif isinstance(classes, list) and len(classes) == 2:
  438. class_neg, class_pos = classes[0], classes[1]
  439. pred_label = predict_score.mapValues(
  440. lambda x: class_pos if x > threshold else class_neg
  441. )
  442. predict_result = data_instances.mapValues(lambda x: x.label)
  443. predict_result = predict_result.join(predict_score, lambda x, y: (x, y))
  444. class_neg_name, class_pos_name = str(class_neg), str(class_pos)
  445. predict_result = predict_result.join(
  446. pred_label,
  447. lambda x, y: [
  448. x[0],
  449. y,
  450. x[1],
  451. predict_detail_dict_to_str({class_neg_name: (1 - x[1]), class_pos_name: x[1]})
  452. ],
  453. )
  454. # multi-label: input = array of predicted score of all labels
  455. elif isinstance(classes, list) and len(classes) > 2:
  456. # pred_label = predict_score.mapValues(lambda x: classes[x.index(max(x))])
  457. classes = [str(val) for val in classes]
  458. predict_result = data_instances.mapValues(lambda x: x.label)
  459. predict_result = predict_result.join(
  460. predict_score,
  461. lambda x, y: [
  462. x,
  463. int(classes[np.argmax(y)]),
  464. float(np.max(y)),
  465. predict_detail_dict_to_str(dict(zip(classes, list(y))))
  466. ],
  467. )
  468. else:
  469. raise ValueError(
  470. f"Model's classes type is {type(classes)}, classes must be None or list of length no less than 2."
  471. )
  472. def _transfer(instance, pred_res):
  473. return Instance(features=pred_res, inst_id=instance.inst_id)
  474. predict_result = data_instances.join(predict_result, _transfer)
  475. return predict_result
  476. def callback_meta(self, metric_name, metric_namespace, metric_meta: MetricMeta):
  477. if self.need_cv:
  478. metric_name = ".".join([metric_name, str(self.cv_fold)])
  479. flow_id_list = self.flowid.split(".")
  480. LOGGER.debug(
  481. "Need cv, change callback_meta, flow_id_list: {}".format(flow_id_list)
  482. )
  483. if len(flow_id_list) > 1:
  484. curve_name = ".".join(flow_id_list[1:])
  485. metric_meta.update_metas({"curve_name": curve_name})
  486. else:
  487. metric_meta.update_metas({"curve_name": metric_name})
  488. self.tracker.set_metric_meta(
  489. metric_name=metric_name,
  490. metric_namespace=metric_namespace,
  491. metric_meta=metric_meta,
  492. )
  493. def callback_metric(
  494. self, metric_name, metric_namespace, metric_data: typing.List[Metric]
  495. ):
  496. if self.need_cv:
  497. metric_name = ".".join([metric_name, str(self.cv_fold)])
  498. self.tracker.log_metric_data(
  499. metric_name=metric_name,
  500. metric_namespace=metric_namespace,
  501. metrics=metric_data,
  502. )
  503. def callback_warm_start_init_iter(self, iter_num):
  504. metric_meta = MetricMeta(
  505. name="train",
  506. metric_type="init_iter",
  507. extra_metas={
  508. "unit_name": "iters",
  509. },
  510. )
  511. self.callback_meta(
  512. metric_name="init_iter", metric_namespace="train", metric_meta=metric_meta
  513. )
  514. self.callback_metric(
  515. metric_name="init_iter",
  516. metric_namespace="train",
  517. metric_data=[Metric("init_iter", iter_num)],
  518. )
  519. def get_latest_checkpoint(self):
  520. return self.checkpoint_manager.latest_checkpoint.read()
  521. def save_summary(self):
  522. self.tracker.log_component_summary(summary_data=self.summary())
  523. def set_cv_fold(self, cv_fold):
  524. self.cv_fold = cv_fold
  525. def summary(self):
  526. return copy.deepcopy(self._summary)
  527. def set_summary(self, new_summary):
  528. """
  529. Model summary setter
  530. Parameters
  531. ----------
  532. new_summary: dict, summary to replace the original one
  533. Returns
  534. -------
  535. """
  536. if not isinstance(new_summary, dict):
  537. raise ValueError(
  538. f"summary should be of dict type, received {type(new_summary)} instead."
  539. )
  540. self._summary = copy.deepcopy(new_summary)
  541. def add_summary(self, new_key, new_value):
  542. """
  543. Add key:value pair to model summary
  544. Parameters
  545. ----------
  546. new_key: str
  547. new_value: object
  548. Returns
  549. -------
  550. """
  551. original_value = self._summary.get(new_key, None)
  552. if original_value is not None:
  553. LOGGER.warning(
  554. f"{new_key} already exists in model summary."
  555. f"Corresponding value {original_value} will be replaced by {new_value}"
  556. )
  557. self._summary[new_key] = new_value
  558. # LOGGER.debug(f"{new_key}: {new_value} added to summary.")
  559. def merge_summary(self, new_content, suffix=None, suffix_sep="_"):
  560. """
  561. Merge new content into model summary
  562. Parameters
  563. ----------
  564. new_content: dict, content to be merged into summary
  565. suffix: str or None, suffix used to create new key if any key in new_content already exixts in model summary
  566. suffix_sep: string, default '_', suffix separator used to create new key
  567. Returns
  568. -------
  569. """
  570. if not isinstance(new_content, dict):
  571. raise ValueError(
  572. f"To merge new content into model summary, "
  573. f"value must be of dict type, received {type(new_content)} instead."
  574. )
  575. new_summary = self.summary()
  576. keyset = new_summary.keys() | new_content.keys()
  577. for key in keyset:
  578. if key in new_summary and key in new_content:
  579. if suffix is not None:
  580. new_key = f"{key}{suffix_sep}{suffix}"
  581. else:
  582. new_key = key
  583. new_value = new_content.get(key)
  584. new_summary[new_key] = new_value
  585. elif key in new_content:
  586. new_summary[key] = new_content.get(key)
  587. else:
  588. pass
  589. self.set_summary(new_summary)
  590. @staticmethod
  591. def extract_data(data: dict):
  592. LOGGER.debug("In extract_data, data input: {}".format(data))
  593. if len(data) == 0:
  594. return data
  595. if len(data) == 1:
  596. return list(data.values())[0]
  597. return data
  598. @staticmethod
  599. def check_schema_content(schema):
  600. """
  601. check for repeated header & illegal/non-printable chars except for space
  602. allow non-ascii chars
  603. :param schema: dict
  604. :return:
  605. """
  606. abnormal_detection.check_legal_schema(schema)
  607. def align_data_header(self, data_instances, pre_header):
  608. """
  609. align features of given data, raise error if value in given schema not found
  610. :param data_instances: data table
  611. :param pre_header: list, header of model
  612. :return: dtable, aligned data
  613. """
  614. result_data = self._align_cache.get(id(data_instances))
  615. if result_data is None:
  616. result_data = header_alignment(
  617. data_instances=data_instances, pre_header=pre_header
  618. )
  619. self._align_cache[id(data_instances)] = result_data
  620. return result_data
  621. @staticmethod
  622. def pass_data(data):
  623. if isinstance(data, dict) and len(data) >= 1:
  624. data = list(data.values())[0]
  625. return data
  626. def obtain_data(self, data_list):
  627. if isinstance(data_list, list):
  628. return data_list[0]
  629. return data_list