# local-hetero_nn.py (4.5 KB)
  1. import argparse
  2. import numpy as np
  3. import os
  4. import pandas
  5. from sklearn import metrics
  6. from pipeline.utils.tools import JobConfig
  7. import torch as t
  8. from torch import nn
  9. from pipeline import fate_torch_hook
  10. from torch.utils.data import DataLoader, TensorDataset
  11. from federatedml.nn.backend.utils.common import global_seed
  12. fate_torch_hook(t)
  13. class HeteroLocalModel(t.nn.Module):
  14. def __init__(self, guest_btn, host_btn, interactive, top):
  15. super().__init__()
  16. self.guest_btn = guest_btn
  17. self.host_btn = host_btn
  18. self.inter = interactive
  19. self.top = top
  20. def forward(self, x1, x2):
  21. return self.top(self.inter(self.guest_btn(x1), self.host_btn(x2)))
  22. def build(param, shape1, shape2, lr):
  23. global_seed(101)
  24. guest_bottom = t.nn.Sequential(
  25. nn.Linear(shape1, param["bottom_layer_units"]),
  26. nn.ReLU()
  27. )
  28. host_bottom = t.nn.Sequential(
  29. nn.Linear(shape2, param["bottom_layer_units"]),
  30. nn.ReLU()
  31. )
  32. interactive_layer = t.nn.InteractiveLayer(
  33. guest_dim=param["bottom_layer_units"],
  34. host_dim=param["bottom_layer_units"],
  35. host_num=1,
  36. out_dim=param["interactive_layer_units"])
  37. act = nn.Sigmoid() if param["top_layer_units"] == 1 else nn.Softmax(dim=1)
  38. top_layer = t.nn.Sequential(
  39. t.nn.Linear(
  40. param["interactive_layer_units"],
  41. param["top_layer_units"]),
  42. act)
  43. model = HeteroLocalModel(
  44. guest_bottom,
  45. host_bottom,
  46. interactive_layer,
  47. top_layer)
  48. opt = t.optim.Adam(model.parameters(), lr=lr)
  49. return model, opt
  50. def fit(epoch, model, optimizer, loss, batch_size, dataset):
  51. print(
  52. 'model is {}, loss is {}, optimizer is {}'.format(
  53. model,
  54. loss,
  55. optimizer))
  56. dl = DataLoader(dataset, batch_size=batch_size)
  57. for i in range(epoch):
  58. epoch_loss = 0
  59. for xa, xb, label in dl:
  60. optimizer.zero_grad()
  61. pred = model(xa, xb)
  62. l = loss(pred, label)
  63. epoch_loss += l.detach().numpy()
  64. l.backward()
  65. optimizer.step()
  66. print('epoch is {}, epoch loss is {}'.format(i, epoch_loss))
  67. def predict(model, Xa, Xb):
  68. pred_rs = model(Xb, Xa)
  69. return pred_rs.detach().numpy()
  70. def main(config="../../config.yaml", param="./hetero_nn_breast_config.yaml"):
  71. if isinstance(config, str):
  72. config = JobConfig.load_from_file(config)
  73. data_base_dir = config["data_base_dir"]
  74. else:
  75. data_base_dir = config.data_base_dir
  76. if isinstance(param, str):
  77. param = JobConfig.load_from_file(param)
  78. data_guest = param["data_guest"]
  79. data_host = param["data_host"]
  80. idx = param["idx"]
  81. label_name = param["label_name"]
  82. # prepare data
  83. Xb = pandas.read_csv(
  84. os.path.join(
  85. data_base_dir,
  86. data_guest),
  87. index_col=idx)
  88. Xa = pandas.read_csv(os.path.join(data_base_dir, data_host), index_col=idx)
  89. y = Xb[label_name]
  90. out = Xa.drop(Xb.index)
  91. Xa = Xa.drop(out.index)
  92. Xb = Xb.drop(label_name, axis=1)
  93. Xa = t.Tensor(Xa.values)
  94. Xb = t.Tensor(Xb.values)
  95. y = t.Tensor(y.values)
  96. if param["loss"] == "categorical_crossentropy":
  97. loss = t.nn.CrossEntropyLoss()
  98. y = y.type(t.int64).flatten()
  99. else:
  100. loss = t.nn.BCELoss()
  101. y = y.reshape((-1, 1))
  102. model, opt = build(
  103. param, Xb.shape[1], Xa.shape[1], lr=param['learning_rate'])
  104. dataset = TensorDataset(Xb, Xa, y)
  105. fit(epoch=param['epochs'], model=model, optimizer=opt,
  106. batch_size=param['batch_size'], dataset=dataset, loss=loss)
  107. eval_result = {}
  108. for metric in param["metrics"]:
  109. if metric.lower() == "auc":
  110. predict_y = predict(model, Xa, Xb)
  111. auc = metrics.roc_auc_score(y, predict_y)
  112. eval_result["auc"] = auc
  113. elif metric == "accuracy":
  114. predict_y = np.argmax(predict(model, Xa, Xb), axis=1)
  115. acc = metrics.accuracy_score(
  116. y_true=y.detach().numpy(), y_pred=predict_y)
  117. eval_result["accuracy"] = acc
  118. print(eval_result)
  119. data_summary = {}
  120. return data_summary, eval_result
  121. if __name__ == "__main__":
  122. parser = argparse.ArgumentParser("BENCHMARK-QUALITY SKLEARN JOB")
  123. parser.add_argument("-config", type=str,
  124. help="config file")
  125. parser.add_argument("-param", type=str,
  126. help="config file for params")
  127. args = parser.parse_args()
  128. if args.config is not None:
  129. main(args.config, args.param)
  130. main()