from CNNetwork.CNNetwork import run_cnn_network, evaluate_cnn_network
from Data.Preprocess import get_data
from keras.models import load_model


def _main():
    """Interactive entry point: obtain a CNN model, then evaluate it.

    The user is prompted for an operation flag and a model name. When the
    flag parses to the integer 2, a model is trained with
    ``run_cnn_network`` for a user-supplied number of epochs and saved to
    ``<model_name>.h5``. For any other integer, the model is loaded from
    ``<model_name>.h5`` instead. The resulting model is then evaluated on
    the three test splits returned by ``get_data`` and one summary line
    per split is printed.

    Raises:
        ValueError: if the flag or the epoch count entered by the user
            cannot be converted with ``int()``.
    """
    flag = input('[-]请输入操作( ① 加载模型 ②训练模型 ):')
    model_name = input('[-]请输入模型名称:')

    # get_data() yields eight values: the training pair followed by three
    # (features, targets) test pairs, in exactly this order.
    (
        train,
        target,
        test,
        test_target,
        normal_test,
        normal_test_target,
        abnormal_test,
        abnormal_test_target,
    ) = get_data()

    model_path = model_name + '.h5'
    if int(flag) == 2:
        epochs = int(input('[-]请输入训练次数:'))
        cnn = run_cnn_network(train, target, epochs)
        cnn.save(model_path)
    else:
        # Every integer flag other than 2 falls back to loading a model.
        cnn = load_model(model_path)

    # (name, features, targets) for each split; the name is passed to
    # evaluate_cnn_network as its fourth argument and reused in the
    # printed summary line.
    splits = (
        ('KDDTest', test, test_target),
        ('KDDTest_normal', normal_test, normal_test_target),
        ('KDDTest_abnormal', abnormal_test, abnormal_test_target),
    )

    # Run all evaluations before printing anything, so the summary lines
    # appear together after whatever the evaluations themselves output.
    results = []
    for split_name, features, targets in splits:
        loss, accuracy = evaluate_cnn_network(cnn, features, targets, split_name)
        results.append((split_name, loss, accuracy))

    for split_name, loss, accuracy in results:
        print('[-]( 测试集: ' + split_name + ' ) 测试损失为:' + str(loss)
              + ' 准确度为:' + str(accuracy))


if __name__ == '__main__':
    _main()