main.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. import numpy as np
  2. from tensorflow.keras.models import load_model
  3. from Data.Preprocess import get_data
  4. from GenerateAbnormalData import generate_abnormal_data
  5. from IDS.CNN.CNNetwork import run_cnn_network, evaluate_cnn_network
  6. from Network.BiGAN.RunBiGanNetwork import run_bigan
  7. from Network.GAN.RunNetwork import run_network
  8. def get_ids_train_data():
  9. data = get_data('train')
  10. _train_data = data['all'][0]
  11. _train_label = data['all'][1]
  12. labels = ['label-dos', 'label-normal', 'label-probe', 'label-r2l', 'label-u2r']
  13. for label in ['normal', 'u2r', 'r2l', 'probe', 'dos']:
  14. print('[-]' + label + '数据量为:' + str(data[label][0].shape[0]))
  15. while True:
  16. model_name = input('[-]请输入生成器模型名称(输入go开始训练IDS):')
  17. if model_name == 'go':
  18. break
  19. label_name = input('[-]请输入生成器数据类型(dos,normal,probe,r2l,u2r):')
  20. try:
  21. generator = load_model('Models/Gan/' + model_name + '.h5')
  22. except Exception as e:
  23. print('[-]没有此模型,名字输错啦')
  24. num = int(input('[-]请输入生成数据量:'))
  25. fake_abnormal_data_x = generate_abnormal_data(generator, num)
  26. _train_data = np.append(_train_data, fake_abnormal_data_x, axis=0)
  27. fake_abnormal_data_y = np.zeros((fake_abnormal_data_x.shape[0], 5))
  28. for i in range(5):
  29. if labels[i].endswith(label_name):
  30. fake_abnormal_data_y[::, i] = 1
  31. _train_label = np.append(_train_label, fake_abnormal_data_y, axis=0)
  32. return _train_data, _train_label
  33. def train_gan():
  34. type_of_gan = input('[-]请输入神经网络模型类型( ① GAN ② BIGAN ):')
  35. model_name = input('[-]请输入模型名称:')
  36. data_class = input('[-]请输入需要生成的数据:')
  37. data = get_data('train')[data_class]
  38. epochs = int(input('[-]请输入训练次数:'))
  39. if int(type_of_gan) == 1:
  40. generator = run_network(data[0], data[1], epochs)
  41. generator.save('Models/Gan/' + 'gan_' + model_name + '_' + data_class + '.h5')
  42. elif int(type_of_gan) == 2:
  43. generator = run_bigan(data[0], data[1], epochs)
  44. generator.save('Models/Gan/' + 'bigan_' + model_name + '_' + data_class + '.h5')
  45. def train_ids_without_gan():
  46. model_name = input('[-]请输入模型名称:')
  47. data = get_data('train')['all']
  48. epochs = int(input('[-]请输入训练次数:'))
  49. cnn = run_cnn_network(data[0], data[1], epochs)
  50. cnn.save('Models/IDS/' + model_name + '.h5')
  51. def train_ids_with_gan():
  52. train_data, train_label = get_ids_train_data()
  53. model_name = input('[-]请输入IDS模型名称:')
  54. epochs = int(input('[-]请输入训练次数:'))
  55. cnn = run_cnn_network(train_data, train_label, epochs)
  56. cnn.save('Models/IDS/' + model_name + '.h5')
  57. def load_and_test_ids():
  58. model_name = input('[-]请输入IDS模型名称:')
  59. cnn = load_model('Models/IDS/' + model_name + '.h5')
  60. data = get_data('test')
  61. losses = []
  62. accuracies = []
  63. flags = []
  64. while True:
  65. opt_flag = input('[-]请输入需要测试的样本集(all,normal,u2r,r2l,probe):')
  66. flags.append(opt_flag)
  67. try:
  68. loss, accuracy = evaluate_cnn_network(cnn, data[opt_flag][0], data[opt_flag][1], 'KDD' + opt_flag)
  69. losses.append(loss)
  70. accuracies.append(accuracy)
  71. for i in range(len(losses)):
  72. print('[-]( 测试集: KDD-' + flags[i] + ' ) 测试损失为:' + str(losses[i]) + ' 准确度为:' + str(accuracies[i]))
  73. except KeyError as e:
  74. break
  75. def main():
  76. flag = input('[-]请输入操作( ① GAN模型 ② IDS模型 ):')
  77. if int(flag) == 1:
  78. train_gan()
  79. elif int(flag) == 2:
  80. flag_for_ids = input('[-] ① 单独训练IDS模型 ② 使用GAN生成数据并训练IDS模型 ③ 加载并测试IDS模型 :')
  81. if int(flag_for_ids) == 1:
  82. train_ids_without_gan()
  83. elif int(flag_for_ids) == 2:
  84. train_ids_with_gan()
  85. elif int(flag_for_ids) == 3:
  86. load_and_test_ids()
  87. if __name__ == '__main__':
  88. main()