local-hetero_nn.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import argparse
  2. import numpy as np
  3. import os
  4. from tensorflow import keras
  5. import pandas
  6. import tensorflow as tf
  7. from tensorflow.keras.utils import to_categorical
  8. from tensorflow.keras import optimizers
  9. from sklearn import metrics
  10. from pipeline.utils.tools import JobConfig
  11. from sklearn.preprocessing import LabelEncoder
  12. import torch as t
  13. from torch import nn
  14. from torch.utils.data import Dataset, DataLoader
  15. import tqdm
  16. from pipeline import fate_torch_hook
  17. fate_torch_hook(t)
  18. class TestModel(t.nn.Module):
  19. def __init__(self, guest_input_shape, host_input_shape):
  20. super(TestModel, self).__init__()
  21. self.guest_bottom = t.nn.Sequential(
  22. nn.Linear(guest_input_shape, 10, True),
  23. nn.ReLU(),
  24. nn.Linear(10, 8, True),
  25. nn.ReLU()
  26. )
  27. self.host_bottom = t.nn.Sequential(
  28. nn.Linear(host_input_shape, 10, True),
  29. nn.ReLU(),
  30. nn.Linear(10, 8, True),
  31. nn.ReLU()
  32. )
  33. self.inter_a, self.inter_b = t.nn.Linear(8, 4, True), t.nn.Linear(8, 4, True)
  34. self.top_model_guest = t.nn.Sequential(
  35. nn.Linear(4, 1, True),
  36. nn.Sigmoid()
  37. )
  38. def forward(self, data):
  39. x_guest, x_host = data[0].type(t.float), data[1].type(t.float)
  40. guest_fw = self.inter_a(self.guest_bottom(x_guest))
  41. host_fw = self.inter_b(self.host_bottom(x_host))
  42. out = self.top_model_guest(guest_fw + host_fw)
  43. return out
  44. def predict(self, data):
  45. rs = self.forward(data)
  46. return rs.detach().numpy()
  47. class TestDataset(Dataset):
  48. def __init__(self, guest_data, host_data, label):
  49. super(TestDataset, self).__init__()
  50. self.g = guest_data
  51. self.h = host_data
  52. self.l = label
  53. def __getitem__(self, idx):
  54. return self.g[idx], self.h[idx], self.l[idx]
  55. def __len__(self):
  56. return len(self.l)
  57. def build(param, shape1, shape2):
  58. return TestModel(shape1, shape2)
  59. def main(config="./config.yaml", param="./hetero_nn_breast_config.yaml"):
  60. try:
  61. if isinstance(config, str):
  62. config = JobConfig.load_from_file(config)
  63. data_base_dir = config["data_base_dir"]
  64. else:
  65. data_base_dir = config.data_base_dir
  66. if isinstance(param, str):
  67. param = JobConfig.load_from_file(param)
  68. data_guest = param["data_guest"]
  69. data_host = param["data_host"]
  70. idx = param["idx"]
  71. label_name = param["label_name"]
  72. # prepare data
  73. Xb = pandas.read_csv(os.path.join(data_base_dir, data_guest), index_col=idx)
  74. Xa = pandas.read_csv(os.path.join(data_base_dir, data_host), index_col=idx)
  75. y = Xb[label_name]
  76. out = Xa.drop(Xb.index)
  77. Xa = Xa.drop(out.index)
  78. Xb = Xb.drop(label_name, axis=1)
  79. # torch model
  80. model = build(param, Xb.shape[1], Xa.shape[1])
  81. Xb = t.Tensor(Xb.values)
  82. Xa = t.Tensor(Xa.values)
  83. y = t.Tensor(y.values)
  84. dataset = TestDataset(Xb, Xa, y)
  85. batch_size = len(dataset) if param['batch_size'] == -1 else param['batch_size']
  86. dataloader = DataLoader(dataset, batch_size=batch_size)
  87. optimizer = t.optim.Adam(lr=param['learning_rate']).to_torch_instance(model.parameters())
  88. if param['eval_type'] == 'binary':
  89. loss_fn = t.nn.BCELoss()
  90. for i in tqdm.tqdm(range(param['epochs'])):
  91. for gd, hd, label in dataloader:
  92. optimizer.zero_grad()
  93. pred = model([gd, hd])
  94. loss = loss_fn(pred.flatten(), label.type(t.float32))
  95. loss.backward()
  96. optimizer.step()
  97. eval_result = {}
  98. for metric in param["metrics"]:
  99. if metric.lower() == "auc":
  100. predict_y = model.predict([Xb, Xa])
  101. auc = metrics.roc_auc_score(y, predict_y)
  102. eval_result["auc"] = auc
  103. elif metric == "accuracy":
  104. predict_y = np.argmax(model.predict([Xb, Xa]), axis=1)
  105. predict_y = label_encoder.inverse_transform(predict_y)
  106. acc = metrics.accuracy_score(y_true=labels, y_pred=predict_y)
  107. eval_result["accuracy"] = acc
  108. data_summary = {}
  109. except Exception as e:
  110. print(e)
  111. return data_summary, eval_result
  112. if __name__ == "__main__":
  113. parser = argparse.ArgumentParser("BENCHMARK-QUALITY SKLEARN JOB")
  114. parser.add_argument("-config", type=str,
  115. help="config file")
  116. parser.add_argument("-param", type=str,
  117. help="config file for params")
  118. args = parser.parse_args()
  119. if args.config is not None:
  120. main(args.config, args.param)
  121. else:
  122. main()