# fate-hetero_nn.py — benchmark-quality pipeline job for FATE hetero neural network.
import argparse
from collections import OrderedDict
from pipeline.backend.pipeline import PipeLine
from pipeline.component import DataTransform
from pipeline.component import HeteroNN
from pipeline.component import Intersection
from pipeline.component import Reader
from pipeline.component import Evaluation
from pipeline.interface import Data
from pipeline.utils.tools import load_job_config, JobConfig
from pipeline.interface import Model
from federatedml.evaluation.metrics import classification_metric
from fate_test.utils import extract_data, parse_summary_result
from pipeline import fate_torch_hook
import torch as t
from torch import nn
from torch.nn import init
from torch import optim
from pipeline import fate_torch as ft
# Patch the torch module in place so FATE's pipeline can serialize/track the
# nn/optim objects built below. Must run before any model or optimizer is
# created in main() — the later `optim.Adam(lr=...)` call (no params argument)
# only works against the hooked API.
fate_torch_hook(t)
def main(config="./config.yaml", param="./hetero_nn_breast_config.yaml", namespace=""):
    """Build, run, and score a hetero-NN FATE pipeline job.

    Parameters
    ----------
    config : str or loaded job config
        Path to the job config YAML, or an already-loaded config object
        (a string is passed through ``load_job_config``).
    param : str or dict-like
        Path to the benchmark parameter YAML, or an already-loaded mapping
        (a string is passed through ``JobConfig.load_from_file``). Keys read
        here: guest/host table names, input shapes, learning_rate, epochs,
        batch_size, eval_type.
    namespace : str
        Suffix appended to the ``"experiment"`` data namespace.

    Returns
    -------
    (data_summary, metric_summary) : tuple of dict
        Table names used for train/test, and evaluation metrics augmented
        with train-vs-predict distribution comparisons.
    """
    # obtain config: accept either file paths or pre-loaded objects
    if isinstance(config, str):
        config = load_job_config(config)
    if isinstance(param, str):
        param = JobConfig.load_from_file(param)
    # single guest + single host topology; guest is the initiator
    parties = config.parties
    guest = parties.guest[0]
    host = parties.host[0]
    guest_train_data = {"name": param["guest_table_name"], "namespace": f"experiment{namespace}"}
    host_train_data = {"name": param["host_table_name"], "namespace": f"experiment{namespace}"}
    pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host)
    # reader_0: each party loads its own uploaded table
    reader_0 = Reader(name="reader_0")
    reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=guest_train_data)
    reader_0.get_party_instance(role='host', party_id=host).component_param(table=host_train_data)
    # data_transform_0: only the guest side holds labels in hetero setting
    data_transform_0 = DataTransform(name="data_transform_0")
    data_transform_0.get_party_instance(role='guest', party_id=guest).component_param(with_label=True)
    data_transform_0.get_party_instance(role='host', party_id=host).component_param(with_label=False)
    # align sample IDs across parties before training
    intersection_0 = Intersection(name="intersection_0")
    guest_input_shape = param['guest_input_shape']
    host_input_shape = param['host_input_shape']
    # define model structures: per-party bottom models project raw features
    # to an 8-dim representation fed into the shared interactive layer
    bottom_model_guest = t.nn.Sequential(
        nn.Linear(guest_input_shape, 10, True),
        nn.ReLU(),
        nn.Linear(10, 8, True),
        nn.ReLU()
    )
    bottom_model_host = t.nn.Sequential(
        nn.Linear(host_input_shape, 10, True),
        nn.ReLU(),
        nn.Linear(10, 8, True),
        nn.ReLU()
    )
    # interactive layer fuses guest+host bottom outputs (8 -> 4)
    interactive_layer = t.nn.Linear(8, 4, True)
    # guest-only top model produces the sigmoid score matched to BCELoss
    top_model_guest = t.nn.Sequential(
        nn.Linear(4, 1, True),
        nn.Sigmoid()
    )
    loss_fn = nn.BCELoss()
    # NOTE: no params argument — valid only because fate_torch_hook(t) patched
    # torch.optim at import time; the annotation records the hooked type.
    opt: ft.optim.Adam = optim.Adam(lr=param['learning_rate'])
    hetero_nn_0 = HeteroNN(name="hetero_nn_0", epochs=param["epochs"],
                           interactive_layer_lr=param["learning_rate"], batch_size=param["batch_size"],
                           early_stop="diff")
    guest_nn_0 = hetero_nn_0.get_party_instance(role='guest', party_id=guest)
    guest_nn_0.add_bottom_model(bottom_model_guest)
    guest_nn_0.add_top_model(top_model_guest)
    # NOTE(review): "interactve" is the spelling of the FATE pipeline API
    # method itself — do not "fix" it here.
    guest_nn_0.set_interactve_layer(interactive_layer)
    host_nn_0 = hetero_nn_0.get_party_instance(role='host', party_id=host)
    host_nn_0.add_bottom_model(bottom_model_host)
    # do remember to compile: binds optimizer + loss to the component
    hetero_nn_0.compile(opt, loss=loss_fn)
    # hetero_nn_1 reuses hetero_nn_0's trained model for prediction only
    hetero_nn_1 = HeteroNN(name="hetero_nn_1")
    evaluation_0 = Evaluation(name="evaluation_0", eval_type=param['eval_type'])
    # wire the DAG: read -> transform -> intersect -> train -> predict/eval
    pipeline.add_component(reader_0)
    pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))
    pipeline.add_component(intersection_0, data=Data(data=data_transform_0.output.data))
    pipeline.add_component(hetero_nn_0, data=Data(train_data=intersection_0.output.data))
    pipeline.add_component(hetero_nn_1, data=Data(test_data=intersection_0.output.data),
                           model=Model(hetero_nn_0.output.model))
    # evaluation runs on the training component's predictions
    pipeline.add_component(evaluation_0, data=Data(data=hetero_nn_0.output.data))
    pipeline.compile()
    pipeline.fit()
    # pull predictions from both the train-time (nn_0) and predict-only
    # (nn_1) components to compare their score distributions
    nn_0_data = pipeline.get_component("hetero_nn_0").get_output_data()
    nn_1_data = pipeline.get_component("hetero_nn_1").get_output_data()
    nn_0_score = extract_data(nn_0_data, "predict_result")
    nn_0_label = extract_data(nn_0_data, "label")
    nn_1_score = extract_data(nn_1_data, "predict_result")
    nn_1_label = extract_data(nn_1_data, "label")
    nn_0_score_label = extract_data(nn_0_data, "predict_result", keep_id=True)
    nn_1_score_label = extract_data(nn_1_data, "predict_result", keep_id=True)
    metric_summary = parse_summary_result(pipeline.get_component("evaluation_0").get_summary())
    eval_type = param['eval_type']
    if eval_type == "binary":
        metric_nn = {
            "score_diversity_ratio": classification_metric.Distribution.compute(nn_0_score_label, nn_1_score_label),
            "ks_2samp": classification_metric.KSTest.compute(nn_0_score, nn_1_score),
            "mAP_D_value": classification_metric.AveragePrecisionScore().compute(nn_0_score, nn_1_score, nn_0_label,
                                                                                 nn_1_label)}
        metric_summary["distribution_metrics"] = {"hetero_nn": metric_nn}
    elif eval_type == "multi":
        # multi-class: only the score-diversity comparison applies
        metric_nn = {
            "score_diversity_ratio": classification_metric.Distribution.compute(nn_0_score_label, nn_1_score_label)}
        metric_summary["distribution_metrics"] = {"hetero_nn": metric_nn}
    # train and test reference the same tables (predict reruns on train data)
    data_summary = {"train": {"guest": guest_train_data["name"], "host": host_train_data["name"]},
                    "test": {"guest": guest_train_data["name"], "host": host_train_data["name"]}
                    }
    return data_summary, metric_summary
  109. if __name__ == "__main__":
  110. parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB")
  111. parser.add_argument("-config", type=str,
  112. help="config file")
  113. parser.add_argument("-param", type=str,
  114. help="config file for params")
  115. args = parser.parse_args()
  116. if args.config is not None:
  117. main(args.config, args.param)
  118. else:
  119. main()