component.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import os
  17. from contextlib import closing
  18. from typing import List
  19. from flow_sdk.client.api.base import BaseFlowAPI
  20. from flow_sdk.utils import preprocess, check_config, download_from_request
  21. class Component(BaseFlowAPI):
  22. def list(self, job_id):
  23. kwargs = locals()
  24. config_data, dsl_data = preprocess(**kwargs)
  25. return self._post(url='tracking/component/list', json=config_data)
  26. def metrics(self, job_id, role, party_id, component_name):
  27. kwargs = locals()
  28. config_data, dsl_data = preprocess(**kwargs)
  29. check_config(config=config_data,
  30. required_arguments=['job_id', 'component_name', 'role', 'party_id'])
  31. return self._post(url='tracking/component/metrics', json=config_data)
  32. def metric_all(self, job_id, role, party_id, component_name):
  33. kwargs = locals()
  34. config_data, dsl_data = preprocess(**kwargs)
  35. check_config(config=config_data,
  36. required_arguments=['job_id', 'component_name', 'role', 'party_id'])
  37. return self._post(url='tracking/component/metric/all', json=config_data)
  38. def metric_delete(self, date=None, job_id=None):
  39. kwargs = locals()
  40. config_data, dsl_data = preprocess(**kwargs)
  41. if config_data.get('date'):
  42. config_data['model'] = config_data.pop('date')
  43. return self._post(url='tracking/component/metric/delete', json=config_data)
  44. def parameters(self, job_id, role, party_id, component_name):
  45. kwargs = locals()
  46. config_data, dsl_data = preprocess(**kwargs)
  47. check_config(config=config_data,
  48. required_arguments=['job_id', 'component_name', 'role', 'party_id'])
  49. return self._post(url='tracking/component/parameters', json=config_data)
  50. def output_data(self, job_id, role, party_id, component_name, output_path, limit=-1):
  51. kwargs = locals()
  52. config_data, dsl_data = preprocess(**kwargs)
  53. check_config(config=config_data,
  54. required_arguments=['job_id', 'component_name', 'role', 'party_id', 'output_path'])
  55. tar_file_name = 'job_{}_{}_{}_{}_output_data.tar.gz'.format(config_data['job_id'],
  56. config_data['component_name'],
  57. config_data['role'],
  58. config_data['party_id'])
  59. extract_dir = os.path.join(config_data['output_path'], tar_file_name.replace('.tar.gz', ''))
  60. with closing(self._get(url='tracking/component/output/data/download',
  61. handle_result=False, json=config_data, stream=True)) as response:
  62. if response.status_code == 200:
  63. try:
  64. download_from_request(http_response=response, tar_file_name=tar_file_name, extract_dir=extract_dir)
  65. response = {'retcode': 0,
  66. 'directory': extract_dir,
  67. 'retmsg': 'download successfully, please check {} directory'.format(extract_dir)}
  68. except BaseException:
  69. response = {'retcode': 100,
  70. 'retmsg': 'download failed, please check if the parameters are correct'}
  71. else:
  72. response = response.json()
  73. return response
  74. def output_model(self, job_id, role, party_id, component_name):
  75. kwargs = locals()
  76. config_data, dsl_data = preprocess(**kwargs)
  77. check_config(config=config_data,
  78. required_arguments=['job_id', 'component_name', 'role', 'party_id'])
  79. return self._post(url='tracking/component/output/model', json=config_data)
  80. def output_data_table(self, job_id, role, party_id, component_name):
  81. kwargs = locals()
  82. config_data, dsl_data = preprocess(**kwargs)
  83. check_config(config=config_data,
  84. required_arguments=['job_id', 'component_name', 'role', 'party_id'])
  85. return self._post(url='tracking/component/output/data/table', json=config_data)
  86. def get_summary(self, job_id, role, party_id, component_name):
  87. kwargs = locals()
  88. config_data, dsl_data = preprocess(**kwargs)
  89. check_config(config=config_data,
  90. required_arguments=['job_id', 'component_name', 'role', 'party_id'])
  91. res = self._post(url='tracking/component/summary/download', json=config_data)
  92. if not res.get('data'):
  93. res['data'] = {}
  94. return res
  95. def hetero_model_merge(
  96. self,
  97. model_id: str, model_version: str, guest_party_id: str, host_party_ids: List[str],
  98. component_name: str, model_type: str, output_format: str, target_name: str = None,
  99. host_rename: bool = None, include_guest_coef: bool = None,
  100. ):
  101. kwargs = locals()
  102. config_data, dsl_data = preprocess(**kwargs)
  103. check_config(config=config_data, required_arguments=(
  104. 'model_id', 'model_version', 'guest_party_id', 'host_party_ids',
  105. 'component_name', 'model_type', 'output_format',
  106. ))
  107. res = self._post(url='component/hetero/merge', json=config_data)
  108. return res
  109. def woe_array_extract(
  110. self,
  111. model_id: str, model_version: str, party_id: str, role: str, component_name: str,
  112. ):
  113. kwargs = locals()
  114. config_data, dsl_data = preprocess(**kwargs)
  115. check_config(config=config_data, required_arguments=(
  116. 'model_id', 'model_version', 'party_id', 'role', 'component_name',
  117. ))
  118. res = self._post(url='component/woe_array/extract', json=config_data)
  119. return res
  120. def woe_array_merge(
  121. self,
  122. model_id: str, model_version: str, party_id: str, role: str, component_name: str,
  123. woe_array: dict,
  124. ):
  125. kwargs = locals()
  126. config_data, dsl_data = preprocess(**kwargs)
  127. check_config(config=config_data, required_arguments=(
  128. 'model_id', 'model_version', 'party_id', 'role', 'component_name',
  129. 'woe_array',
  130. ))
  131. res = self._post(url='component/woe_array/merge', json=config_data)
  132. return res