# fate-homo_nn.py
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import copy

import torch as t

from pipeline import fate_torch_hook
from pipeline.backend.pipeline import PipeLine
from pipeline.component import DataTransform, HomoNN, Evaluation
from pipeline.component import Reader
from pipeline.interface import Data, Model
from pipeline.utils.tools import load_job_config, JobConfig
from federatedml.evaluation.metrics import classification_metric
from fate_test.utils import extract_data, parse_summary_result
from pipeline.component.nn import TrainerParam, DatasetParam
# Hook FATE into the torch module before any t.nn / t.optim objects are built.
# The model, loss and optimizer created in main() are handed directly to the
# HomoNN component, which presumably relies on this hook to serialize them
# (NOTE(review): confirm against the FATE pipeline documentation).
fate_torch_hook(t)
  28. class dataset(object):
  29. breast = {
  30. "guest": {"name": "breast_homo_guest", "namespace": "experiment"},
  31. "host": [
  32. {"name": "breast_homo_host", "namespace": "experiment"},
  33. {"name": "breast_homo_host", "namespace": "experiment"}
  34. ]
  35. }
  36. vehicle = {
  37. "guest": {"name": "vehicle_scale_homo_guest", "namespace": "experiment"},
  38. "host": [
  39. {"name": "vehicle_scale_homo_host", "namespace": "experiment"},
  40. {"name": "vehicle_scale_homo_host", "namespace": "experiment"}
  41. ]
  42. }
  43. def main(config="../../config.yaml", param="param_conf.yaml", namespace=""):
  44. num_host = 1
  45. if isinstance(config, str):
  46. config = load_job_config(config)
  47. if isinstance(param, str):
  48. param = JobConfig.load_from_file(param)
  49. epoch = param["epoch"]
  50. lr = param["lr"]
  51. batch_size = param.get("batch_size", -1)
  52. is_multy = param["is_multy"]
  53. data = getattr(dataset, param.get("dataset", "vehicle"))
  54. if is_multy:
  55. loss = t.nn.CrossEntropyLoss()
  56. else:
  57. loss = t.nn.BCELoss()
  58. input_shape = 18 if is_multy else 30
  59. output_shape = 4 if is_multy else 1
  60. out_act = t.nn.Softmax(dim=1) if is_multy else t.nn.Sigmoid()
  61. model = t.nn.Sequential(
  62. t.nn.Linear(input_shape, 16),
  63. t.nn.ReLU(),
  64. t.nn.Linear(16, output_shape),
  65. out_act
  66. )
  67. optimizer = t.optim.Adam(model.parameters(), lr=lr)
  68. guest_train_data = data["guest"]
  69. host_train_data = data["host"][:num_host]
  70. for d in [guest_train_data, *host_train_data]:
  71. d["namespace"] = f"{d['namespace']}{namespace}"
  72. hosts = config.parties.host[:num_host]
  73. pipeline = PipeLine() .set_initiator(
  74. role='guest',
  75. party_id=config.parties.guest[0]) .set_roles(
  76. guest=config.parties.guest[0],
  77. host=hosts,
  78. arbiter=config.parties.arbiter)
  79. reader_0 = Reader(name="reader_0")
  80. reader_0.get_party_instance(
  81. role='guest',
  82. party_id=config.parties.guest[0]).component_param(
  83. table=guest_train_data)
  84. for i in range(num_host):
  85. reader_0.get_party_instance(role='host', party_id=hosts[i]) \
  86. .component_param(table=host_train_data[i])
  87. data_transform_0 = DataTransform(name="data_transform_0", with_label=True)
  88. data_transform_0.get_party_instance(
  89. role='guest', party_id=config.parties.guest[0]) .component_param(
  90. with_label=True, output_format="dense")
  91. data_transform_0.get_party_instance(
  92. role='host',
  93. party_id=hosts).component_param(
  94. with_label=True)
  95. if is_multy:
  96. ds_param = DatasetParam(
  97. dataset_name='table',
  98. flatten_label=True,
  99. label_dtype='long')
  100. else:
  101. ds_param = DatasetParam(dataset_name='table')
  102. homo_nn_0 = HomoNN(
  103. name="homo_nn_0",
  104. trainer=TrainerParam(
  105. trainer_name='fedavg_trainer',
  106. epochs=epoch,
  107. batch_size=batch_size,
  108. ),
  109. dataset=ds_param,
  110. torch_seed=100,
  111. optimizer=optimizer,
  112. loss=loss,
  113. model=model)
  114. homo_nn_1 = HomoNN(name="homo_nn_1")
  115. if is_multy:
  116. eval_type = "multi"
  117. else:
  118. eval_type = "binary"
  119. evaluation_0 = Evaluation(
  120. name='evaluation_0',
  121. eval_type=eval_type,
  122. metrics=[
  123. "accuracy",
  124. "precision",
  125. "recall"])
  126. pipeline.add_component(reader_0)
  127. pipeline.add_component(
  128. data_transform_0, data=Data(
  129. data=reader_0.output.data))
  130. pipeline.add_component(homo_nn_0, data=Data(
  131. train_data=data_transform_0.output.data))
  132. pipeline.add_component(
  133. homo_nn_1, data=Data(
  134. test_data=data_transform_0.output.data), model=Model(
  135. homo_nn_0.output.model))
  136. pipeline.add_component(evaluation_0, data=Data(data=homo_nn_0.output.data))
  137. pipeline.compile()
  138. pipeline.fit()
  139. metric_summary = parse_summary_result(
  140. pipeline.get_component("evaluation_0").get_summary())
  141. nn_0_data = pipeline.get_component("homo_nn_0").get_output_data()
  142. nn_1_data = pipeline.get_component("homo_nn_1").get_output_data()
  143. nn_0_score = extract_data(nn_0_data, "predict_result")
  144. nn_0_label = extract_data(nn_0_data, "label")
  145. nn_1_score = extract_data(nn_1_data, "predict_result")
  146. nn_1_label = extract_data(nn_1_data, "label")
  147. nn_0_score_label = extract_data(nn_0_data, "predict_result", keep_id=True)
  148. nn_1_score_label = extract_data(nn_1_data, "predict_result", keep_id=True)
  149. if eval_type == "binary":
  150. # metric_nn = {
  151. # "score_diversity_ratio": classification_metric.Distribution.compute(nn_0_score_label, nn_1_score_label),
  152. # "ks_2samp": classification_metric.KSTest.compute(nn_0_score, nn_1_score),
  153. # "mAP_D_value": classification_metric.AveragePrecisionScore().compute(nn_0_score, nn_1_score, nn_0_label,
  154. # nn_1_label)}
  155. # metric_summary["distribution_metrics"] = {"homo_nn": metric_nn}
  156. if metric_summary is None:
  157. metric_summary = {}
  158. metric_summary["accuracy"] = (
  159. nn_0_score == nn_0_label).sum() / len(nn_0_label)
  160. # elif eval_type == "multi":
  161. # metric_nn = {
  162. # "score_diversity_ratio": classification_metric.Distribution.compute(nn_0_score_label, nn_1_score_label)}
  163. # metric_summary["distribution_metrics"] = {"homo_nn": metric_nn}
  164. data_summary = dict(
  165. train={"guest": guest_train_data["name"], **{f"host_{i}": host_train_data[i]["name"] for i in range(num_host)}},
  166. test={"guest": guest_train_data["name"], **{f"host_{i}": host_train_data[i]["name"] for i in range(num_host)}}
  167. )
  168. return data_summary, metric_summary
  169. if __name__ == "__main__":
  170. parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB")
  171. parser.add_argument("-config", type=str,
  172. help="config file")
  173. parser.add_argument("-param", type=str,
  174. help="config file for params")
  175. args = parser.parse_args()
  176. if args.config is not None:
  177. main(args.config, args.param)
  178. else:
  179. main(args.param)