main.py 1.4 KB

123456789101112131415161718192021
  1. from CNNetwork.CNNetwork import run_cnn_network, evaluate_cnn_network
  2. from Data.Preprocess import get_data
  3. from keras.models import load_model
  4. if __name__ == '__main__':
  5. flag = input('[-]请输入操作( ① 加载模型 ②训练模型 ):')
  6. model_name = input('[-]请输入模型名称:')
  7. train, target, test, test_target, normal_test, normal_test_target, abnormal_test, abnormal_test_target = get_data()
  8. if int(flag) == 2:
  9. epochs = int(input('[-]请输入训练次数:'))
  10. cnn = run_cnn_network(train, target, epochs)
  11. cnn.save(model_name + '.h5')
  12. else:
  13. cnn = load_model(model_name + '.h5')
  14. loss, accuracy = evaluate_cnn_network(cnn, test, test_target, 'KDDTest')
  15. normal_loss, normal_accuracy = evaluate_cnn_network(cnn, normal_test, normal_test_target, 'KDDTest_normal')
  16. abnormal_loss, abnormal_accuracy = evaluate_cnn_network(cnn, abnormal_test, abnormal_test_target,
  17. 'KDDTest_abnormal')
  18. print('[-]( 测试集: KDDTest ) 测试损失为:' + str(loss) + ' 准确度为:' + str(accuracy))
  19. print('[-]( 测试集: KDDTest_normal ) 测试损失为:' + str(normal_loss) + ' 准确度为:' + str(normal_accuracy))
  20. print('[-]( 测试集: KDDTest_abnormal ) 测试损失为:' + str(abnormal_loss) + ' 准确度为:' + str(abnormal_accuracy))