# tracking_test.py
import json
import os
import time
import unittest

import requests

from fate_arch.common.file_utils import load_json_conf
from fate_flow.entity.run_status import EndStatus, JobStatus
from fate_flow.settings import API_VERSION, HOST, HTTP_PORT, IS_STANDALONE
from fate_flow.utils.base_utils import get_fate_flow_python_directory
  10. WORK_MODE = 1 if not IS_STANDALONE else 0
  11. class TestTracking(unittest.TestCase):
  12. def setUp(self):
  13. self.sleep_time = 10
  14. self.success_job_dir = './jobs/'
  15. self.dsl_path = 'fate_flow/examples/test_hetero_lr_job_dsl.json'
  16. self.config_path = 'fate_flow/examples/test_hetero_lr_job_conf.json'
  17. self.test_component_name = 'hetero_feature_selection_0'
  18. self.server_url = "http://{}:{}/{}".format(HOST, HTTP_PORT, API_VERSION)
  19. self.party_info = load_json_conf(os.path.abspath(os.path.join('./jobs', 'party_info.json'))) if WORK_MODE else None
  20. self.guest_party_id = self.party_info['guest'] if WORK_MODE else 9999
  21. self.host_party_id = self.party_info['host'] if WORK_MODE else 10000
  22. def test_tracking(self):
  23. with open(os.path.join(get_fate_flow_python_directory(), self.dsl_path), 'r') as f:
  24. dsl_data = json.load(f)
  25. with open(os.path.join(get_fate_flow_python_directory(), self.config_path), 'r') as f:
  26. config_data = json.load(f)
  27. config_data[ "initiator"]["party_id"] = self.guest_party_id
  28. config_data["role"] = {
  29. "guest": [self.guest_party_id],
  30. "host": [self.host_party_id],
  31. "arbiter": [self.host_party_id]
  32. }
  33. response = requests.post("/".join([self.server_url, 'job', 'submit']),
  34. json={'job_dsl': dsl_data, 'job_runtime_conf': config_data})
  35. self.assertTrue(response.status_code in [200, 201])
  36. self.assertTrue(int(response.json()['retcode']) == 0)
  37. job_id = response.json()['jobId']
  38. job_info = {'f_status': 'running'}
  39. for i in range(60):
  40. response = requests.post("/".join([self.server_url, 'job', 'query']), json={'job_id': job_id, 'role': 'guest'})
  41. self.assertTrue(response.status_code in [200, 201])
  42. job_info = response.json()['data'][0]
  43. if EndStatus.contains(job_info['f_status']):
  44. break
  45. time.sleep(self.sleep_time)
  46. print('waiting job run success, the job has been running for {}s'.format((i+1)*self.sleep_time))
  47. self.assertTrue(job_info['f_status'] == JobStatus.SUCCESS)
  48. os.makedirs(self.success_job_dir, exist_ok=True)
  49. with open(os.path.join(self.success_job_dir, job_id), 'w') as fw:
  50. json.dump(job_info, fw)
  51. self.assertTrue(os.path.exists(os.path.join(self.success_job_dir, job_id)))
  52. # test_component_parameters
  53. test_component(self, 'component/parameters')
  54. # test_component_metric_all
  55. test_component(self, 'component/metric/all')
  56. # test_component_metric
  57. test_component(self, 'component/metrics')
  58. # test_component_output_model
  59. test_component(self, 'component/output/model')
  60. # test_component_output_data_download
  61. test_component(self, 'component/output/data')
  62. # test_component_output_data_download
  63. test_component(self, 'component/output/data/download')
  64. # test_job_data_view
  65. test_component(self, 'job/data_view')
  66. def test_component(self, fun):
  67. job_id = os.listdir(os.path.abspath(os.path.join(self.success_job_dir)))[-1]
  68. job_info = load_json_conf(os.path.abspath(os.path.join(self.success_job_dir, job_id)))
  69. data = {'job_id': job_id, 'role': job_info['f_role'], 'party_id': job_info['f_party_id'], 'component_name': self.test_component_name}
  70. if 'download' in fun:
  71. response = requests.get("/".join([self.server_url, "tracking", fun]), json=data, stream=True)
  72. self.assertTrue(response.status_code in [200, 201])
  73. else:
  74. response = requests.post("/".join([self.server_url, 'tracking', fun]), json=data)
  75. self.assertTrue(response.status_code in [200, 201])
  76. self.assertTrue(int(response.json()['retcode']) == 0)
  77. if __name__ == '__main__':
  78. unittest.main()