task_info.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. import sys
  17. from pipeline.utils.logger import LOGGER
  18. class TaskInfo(object):
  19. def __init__(self, jobid, component, job_client, role='guest', party_id=9999):
  20. self._jobid = jobid
  21. self._component = component
  22. self._job_client = job_client
  23. self._party_id = party_id
  24. self._role = role
  25. @LOGGER.catch(onerror=lambda _: sys.exit(1))
  26. def get_output_data(self, limits=None, to_pandas=True):
  27. '''
  28. gets downloaded data of arbitrary component
  29. Parameters
  30. ----------
  31. limits: int, None, default None. Maximum number of lines returned, including header. If None, return all lines.
  32. to_pandas: bool, default True.
  33. Returns
  34. -------
  35. single output example: pandas.DataFrame
  36. multiple output example:
  37. {
  38. train_data: train_data_df,
  39. validate_data: validate_data_df,
  40. test_data: test_data_df
  41. }
  42. '''
  43. return self._job_client.get_output_data(self._jobid, self._component.name, self._role,
  44. self._party_id, limits, to_pandas=to_pandas)
  45. @LOGGER.catch(onerror=lambda _: sys.exit(1))
  46. def get_model_param(self):
  47. '''
  48. get fitted model parameters
  49. Returns
  50. -------
  51. dict
  52. '''
  53. return self._job_client.get_model_param(self._jobid, self._component.name, self._role, self._party_id)
  54. @LOGGER.catch(onerror=lambda _: sys.exit(1))
  55. def get_output_data_table(self):
  56. '''
  57. get output data table information, including table name and namespace, as given by flow client
  58. Returns
  59. -------
  60. dict
  61. '''
  62. return self._job_client.get_output_data_table(self._jobid, self._component.name, self._role, self._party_id)
  63. @LOGGER.catch(onerror=lambda _: sys.exit(1))
  64. def get_summary(self):
  65. '''
  66. get module summary of arbitrary component
  67. Returns
  68. -------
  69. dict
  70. '''
  71. return self._job_client.get_summary(self._jobid, self._component.name, self._role, self._party_id)