# local-homo_nn.py

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import pathlib

import numpy as np
import torch as t
from torch.utils.data import DataLoader, TensorDataset
import pandas

from pipeline.utils.tools import JobConfig
from federatedml.nn.backend.utils.common import global_seed

dataset = {
    "vehicle": {
        "guest": "examples/data/vehicle_scale_homo_guest.csv",
        "host": "examples/data/vehicle_scale_homo_host.csv",
    },
    "breast": {
        "guest": "examples/data/breast_homo_guest.csv",
        "host": "examples/data/breast_homo_host.csv",
    },
}
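

# Plain local PyTorch training loop: run `epoch` passes over `dataset`,
# accumulating and printing the summed batch loss for each epoch.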
def fit(epoch, model, optimizer, loss, batch_size, dataset):
    print(
        'model is {}, loss is {}, optimizer is {}'.format(
            model,
            loss,
            optimizer))
    dl = DataLoader(dataset, batch_size=batch_size)
    for i in range(epoch):
        epoch_loss = 0
        for feat, label in dl:
            optimizer.zero_grad()
            pred = model(feat)
            l = loss(pred, label)
            epoch_loss += l.detach().numpy()
            l.backward()
            optimizer.step()
        print('epoch is {}, epoch loss is {}'.format(i, epoch_loss))
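

# Accuracy helper: argmax over class scores for the multi-class case,
# a 0.5 threshold on the sigmoid output for the binary case.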
def compute_acc(pred, label, is_multy):
    if is_multy:
        pred = pred.argmax(axis=1)
    else:
        pred = (pred > 0.5) + 0
    return float((pred == label).sum() / len(label))
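

# Local (non-federated) counterpart of the homo NN benchmark job: load the job
# and parameter configs, pool the guest and host data, train a small MLP
# locally, and return its accuracy in the benchmark summary format.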
def main(config="../../config.yaml", param="param_conf.yaml"):
    if isinstance(param, str):
        param = JobConfig.load_from_file(param)
    if isinstance(config, str):
        config = JobConfig.load_from_file(config)
        data_base_dir = config["data_base_dir"]
    else:
        data_base_dir = config.data_base_dir

    epoch = param["epoch"]
    lr = param["lr"]
    batch_size = param.get("batch_size", -1)
    is_multy = param["is_multy"]
    data = dataset[param.get("dataset", "vehicle")]

    global_seed(123)

    if is_multy:
        loss = t.nn.CrossEntropyLoss()
    else:
        loss = t.nn.BCELoss()
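
    # Pool the guest and host CSVs into one local dataset; after dropping the id
    # column (used as the index), column 0 is the label and the rest are features.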
    data_path = pathlib.Path(data_base_dir)
    data_with_label = pandas.concat(
        [
            pandas.read_csv(data_path.joinpath(data["guest"]), index_col=0),
            pandas.read_csv(data_path.joinpath(data["host"]), index_col=0),
        ]
    ).values
    data = t.Tensor(data_with_label[:, 1:])
    labels = t.Tensor(data_with_label[:, 0])
    if is_multy:
        labels = labels.type(t.int64)
    else:
        labels = labels.reshape((-1, 1))
    ds = TensorDataset(data, labels)
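
    # Single-hidden-layer MLP: 16 ReLU units, then a 4-way softmax head for the
    # multi-class case or a single sigmoid output for the binary case.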
    input_shape = data.shape[1]
    output_shape = 4 if is_multy else 1
    out_act = t.nn.Softmax(dim=1) if is_multy else t.nn.Sigmoid()
    model = t.nn.Sequential(
        t.nn.Linear(input_shape, 16),
        t.nn.ReLU(),
        t.nn.Linear(16, output_shape),
        out_act
    )
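
    # A negative batch_size (the default -1) falls back to full-batch training.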
    if batch_size < 0:
        batch_size = len(data_with_label)
    optimizer = t.optim.Adam(model.parameters(), lr=lr)
    fit(epoch, model, optimizer, loss, batch_size, ds)

    pred_rs = model(data)
    acc = compute_acc(pred_rs, labels, is_multy)
    metric_summary = {"accuracy": acc}
    print(metric_summary)
    data_summary = {}
    return data_summary, metric_summary
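

# NOTE: the listing above stops at main() even though argparse is imported, so
# the original file presumably closes with a CLI entry point. The block below is
# a minimal sketch under that assumption: the flag names (-c/--config, -p/--param)
# and the parser description are illustrative, while the defaults mirror main()'s
# signature.
if __name__ == "__main__":
    parser = argparse.ArgumentParser("LOCAL HOMO NN BENCHMARK JOB")
    parser.add_argument("-c", "--config", type=str, default="../../config.yaml",
                        help="path to the benchmark job config yaml")
    parser.add_argument("-p", "--param", type=str, default="param_conf.yaml",
                        help="path to the parameter config yaml")
    args = parser.parse_args()
    main(args.config, args.param)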