  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import time
  17. import click
  18. from flow_client.flow_cli.utils import cli_args
  19. from flow_client.flow_cli.utils.cli_utils import prettify
  20. from flow_sdk.client import FlowClient
  21. from pipeline.backend.pipeline import PipeLine
  22. from pipeline.component import (
  23. DataTransform, Evaluation, HeteroLR,
  24. HeteroSecureBoost, Intersection, Reader,
  25. )
  26. from pipeline.interface import Data
  27. @click.group(short_help="FATE Flow Test Operations")
  28. @click.pass_context
  29. def test(ctx):
  30. """
  31. \b
  32. Provides numbers of component operational commands, including metrics, parameters and etc.
  33. For more details, please check out the help text.
  34. """
  35. pass
  36. @test.command("toy", short_help="Toy Test Command")
  37. @cli_args.GUEST_PARTYID_REQUIRED
  38. @cli_args.HOST_PARTYID_REQUIRED
  39. @cli_args.TIMEOUT
  40. @cli_args.TASK_CORES
  41. @click.pass_context
  42. def toy(ctx, **kwargs):
  43. flow_sdk = FlowClient(ip=ctx.obj["ip"], port=ctx.obj["http_port"], version=ctx.obj["api_version"],
  44. app_key=ctx.obj.get("app_key"), secret_key=ctx.obj.get("secret_key"))
  45. submit_result = flow_sdk.test.toy(**kwargs)
  46. if submit_result["retcode"] == 0:
  47. for t in range(kwargs["timeout"]):
  48. job_id = submit_result["jobId"]
  49. r = flow_sdk.job.query(job_id=job_id, role="guest", party_id=kwargs["guest_party_id"])
  50. if r["retcode"] == 0 and len(r["data"]):
  51. job_status = r["data"][0]["f_status"]
  52. print(f"toy test job {job_id} is {job_status}")
  53. if job_status in {"success", "failed", "canceled"}:
  54. check_log(flow_sdk, kwargs["guest_party_id"], job_id, job_status)
  55. break
  56. time.sleep(1)
  57. else:
  58. print(f"check job status timeout")
  59. check_log(flow_sdk, kwargs["guest_party_id"], job_id, job_status)
  60. else:
  61. prettify(submit_result)
  62. def check_log(flow_sdk, party_id, job_id, job_status):
  63. r = flow_sdk.job.log(job_id=job_id, output_path="./logs/toy")
  64. if r["retcode"] == 0:
  65. log_msg = flow_sdk.test.check_toy(party_id, job_status, r["directory"])
  66. try:
  67. for msg in log_msg:
  68. print(msg)
  69. except BaseException:
  70. print(f"auto check log failed, please check {r['directory']}")
  71. else:
  72. print(f"get log failed, please check PROJECT_BASE/logs/{job_id} on the fateflow server machine")
  73. @test.command("min", short_help="Min Test Command")
  74. @click.option("-t", "--data-type", type=click.Choice(["fast", "normal"]), default="fast", show_default=True,
  75. help="fast for breast data, normal for default credit data")
  76. @click.option("--sbt/--no-sbt", is_flag=True, default=True, show_default=True, help="run sbt test or not")
  77. @cli_args.GUEST_PARTYID_REQUIRED
  78. @cli_args.HOST_PARTYID_REQUIRED
  79. @cli_args.ARBITER_PARTYID_REQUIRED
  80. @click.pass_context
  81. def run_min_test(ctx, data_type, sbt, guest_party_id, host_party_id, arbiter_party_id, **kwargs):
  82. guest_party_id = int(guest_party_id)
  83. host_party_id = int(host_party_id)
  84. arbiter_party_id = int(arbiter_party_id)
  85. if data_type == "fast":
  86. guest_train_data = {"name": "breast_hetero_guest", "namespace": "experiment"}
  87. host_train_data = {"name": "breast_hetero_host", "namespace": "experiment"}
  88. auc_base = 0.98
  89. elif data_type == "normal":
  90. guest_train_data = {"name": "default_credit_hetero_guest", "namespace": "experiment"}
  91. host_train_data = {"name": "default_credit_hetero_host", "namespace": "experiment"}
  92. auc_base = 0.69
  93. else:
  94. click.echo(f"data type {data_type} not supported", err=True)
  95. raise click.Abort()
  96. lr_pipeline = lr_train_pipeline(guest_party_id, host_party_id, arbiter_party_id, guest_train_data, host_train_data)
  97. lr_auc = get_auc(lr_pipeline, "hetero_lr_0")
  98. if lr_auc < auc_base:
  99. click.echo(f"Warning: The LR auc {lr_auc} is lower than expect value {auc_base}")
  100. predict_pipeline(lr_pipeline, guest_party_id, host_party_id, guest_train_data, host_train_data)
  101. if sbt:
  102. sbt_pipeline = sbt_train_pipeline(guest_party_id, host_party_id, guest_train_data, host_train_data)
  103. sbt_auc = get_auc(sbt_pipeline, "hetero_secureboost_0")
  104. if sbt_auc < auc_base:
  105. click.echo(f"Warning: The SBT auc {sbt_auc} is lower than expect value {auc_base}")
  106. predict_pipeline(sbt_pipeline, guest_party_id, host_party_id, guest_train_data, host_train_data)
  107. def lr_train_pipeline(guest, host, arbiter, guest_train_data, host_train_data):
  108. pipeline = PipeLine().set_initiator(role="guest", party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)
  109. reader_0 = Reader(name="reader_0")
  110. reader_0.get_party_instance(role="guest", party_id=guest).component_param(table=guest_train_data)
  111. reader_0.get_party_instance(role="host", party_id=host).component_param(table=host_train_data)
  112. data_transform_0 = DataTransform(name="data_transform_0")
  113. data_transform_0.get_party_instance(role="guest", party_id=guest).component_param(
  114. with_label=True, output_format="dense")
  115. data_transform_0.get_party_instance(role="host", party_id=host).component_param(with_label=False)
  116. intersection_0 = Intersection(name="intersection_0")
  117. lr_param = {
  118. "penalty": "L2",
  119. "tol": 0.0001,
  120. "alpha": 0.01,
  121. "optimizer": "rmsprop",
  122. "batch_size": -1,
  123. "learning_rate": 0.15,
  124. "init_param": {
  125. "init_method": "zeros",
  126. "fit_intercept": True,
  127. },
  128. "max_iter": 30,
  129. "early_stop": "diff",
  130. "encrypt_param": {
  131. "key_length": 1024,
  132. },
  133. "cv_param": {
  134. "n_splits": 5,
  135. "shuffle": False,
  136. "random_seed": 103,
  137. "need_cv": False,
  138. },
  139. "validation_freqs": 3,
  140. }
  141. hetero_lr_0 = HeteroLR(name="hetero_lr_0", **lr_param)
  142. evaluation_0 = Evaluation(name="evaluation_0", eval_type="binary")
  143. pipeline.add_component(reader_0)
  144. pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))
  145. pipeline.add_component(intersection_0, data=Data(data=data_transform_0.output.data))
  146. pipeline.add_component(hetero_lr_0, data=Data(train_data=intersection_0.output.data))
  147. pipeline.add_component(evaluation_0, data=Data(data=hetero_lr_0.output.data))
  148. pipeline.compile()
  149. pipeline.fit()
  150. return pipeline
  151. def sbt_train_pipeline(guest, host, guest_train_data, host_train_data):
  152. pipeline = PipeLine().set_initiator(role="guest", party_id=guest).set_roles(guest=guest, host=host)
  153. reader_0 = Reader(name="reader_0")
  154. reader_0.get_party_instance(role="guest", party_id=guest).component_param(table=guest_train_data)
  155. reader_0.get_party_instance(role="host", party_id=host).component_param(table=host_train_data)
  156. data_transform_0 = DataTransform(name="data_transform_0")
  157. data_transform_0.get_party_instance(role="guest", party_id=guest).component_param(
  158. with_label=True, output_format="dense")
  159. data_transform_0.get_party_instance(role="host", party_id=host).component_param(with_label=False)
  160. intersection_0 = Intersection(name="intersection_0")
  161. sbt_param = {
  162. "task_type": "classification",
  163. "objective_param": {
  164. "objective": "cross_entropy",
  165. },
  166. "num_trees": 3,
  167. "validation_freqs": 1,
  168. "encrypt_param": {
  169. "method": "paillier",
  170. },
  171. "tree_param": {
  172. "max_depth": 3,
  173. }
  174. }
  175. hetero_secure_boost_0 = HeteroSecureBoost(name="hetero_secureboost_0", **sbt_param)
  176. evaluation_0 = Evaluation(name="evaluation_0", eval_type="binary")
  177. pipeline.add_component(reader_0)
  178. pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))
  179. pipeline.add_component(intersection_0, data=Data(data=data_transform_0.output.data))
  180. pipeline.add_component(hetero_secure_boost_0, data=Data(train_data=intersection_0.output.data))
  181. pipeline.add_component(evaluation_0, data=Data(data=hetero_secure_boost_0.output.data))
  182. pipeline.compile()
  183. pipeline.fit()
  184. return pipeline
  185. def get_auc(pipeline, component_name):
  186. cpn_summary = pipeline.get_component(component_name).get_summary()
  187. auc = cpn_summary.get("validation_metrics").get("train").get("auc")[-1]
  188. return auc
  189. def predict_pipeline(train_pipeline, guest, host, guest_train_data, host_train_data):
  190. cpn_list = train_pipeline.get_component_list()[1:]
  191. train_pipeline.deploy_component(cpn_list)
  192. pipeline = PipeLine()
  193. reader_0 = Reader(name="reader_0")
  194. reader_0.get_party_instance(role="guest", party_id=guest).component_param(table=guest_train_data)
  195. reader_0.get_party_instance(role="host", party_id=host).component_param(table=host_train_data)
  196. pipeline.add_component(reader_0)
  197. pipeline.add_component(train_pipeline, data=Data(predict_input={
  198. train_pipeline.data_transform_0.input.data: reader_0.output.data}))
  199. pipeline.predict()