123456789101112131415161718192021 |
- from CNNetwork.CNNetwork import run_cnn_network, evaluate_cnn_network
- from Data.Preprocess import get_data
- from keras.models import load_model
def parse_mode(flag):
    """Translate the answer to the operation menu into ``'load'`` or ``'train'``.

    The menu prompt labels its options with circled digits (① / ②), so those
    are accepted alongside anything ``int()`` understands (ASCII ``1``/``2``,
    full-width digits, surrounding whitespace).

    :param flag: raw text typed by the user.
    :return: ``'load'`` for option 1, ``'train'`` for option 2, or ``None``
             when the answer is not a recognised option.
    """
    choice = flag.strip()
    # int() rejects circled digits (they are not Unicode *decimal* digits),
    # so map them explicitly before falling back to int().
    circled = {'①': 'load', '②': 'train'}
    if choice in circled:
        return circled[choice]
    try:
        number = int(choice)
    except ValueError:
        return None
    return {1: 'load', 2: 'train'}.get(number)


def parse_epochs(text):
    """Return *text* as a positive epoch count, or ``None`` if it is not one."""
    try:
        epochs = int(text.strip())
    except ValueError:
        return None
    return epochs if epochs > 0 else None


def prompt_until_valid(prompt, parse):
    """Ask *prompt* repeatedly until *parse* accepts the answer.

    :param prompt: text shown by ``input()``.
    :param parse: callable mapping the raw answer to a value, or ``None`` when
                  the answer is invalid.
    :return: the first non-``None`` value produced by *parse*.
    """
    while True:
        value = parse(input(prompt))
        if value is not None:
            return value
        print('[-]输入无效, 请重新输入')


def main():
    """Train or load the CNN, then report loss/accuracy on the three KDD test sets.

    The model is stored as / loaded from ``<model name>.h5``.
    """
    mode = prompt_until_valid('[-]请输入操作( ① 加载模型 ②训练模型 ):', parse_mode)
    # Strip so a stray space does not end up inside the file name.
    model_path = input('[-]请输入模型名称:').strip() + '.h5'
    (train, target, test, test_target, normal_test, normal_test_target,
     abnormal_test, abnormal_test_target) = get_data()

    if mode == 'train':
        epochs = prompt_until_valid('[-]请输入训练次数:', parse_epochs)
        cnn = run_cnn_network(train, target, epochs)
        cnn.save(model_path)
    else:
        cnn = load_model(model_path)

    # (data-set name, samples, labels); the name is shown in the report and is
    # also passed through as evaluate_cnn_network's fourth argument.
    test_sets = (
        ('KDDTest', test, test_target),
        ('KDDTest_normal', normal_test, normal_test_target),
        ('KDDTest_abnormal', abnormal_test, abnormal_test_target),
    )
    # Run every evaluation before printing anything, matching the original
    # ordering of evaluation output vs. the summary lines.
    results = [(name, evaluate_cnn_network(cnn, samples, labels, name))
               for name, samples, labels in test_sets]
    for name, (loss, accuracy) in results:
        print('[-]( 测试集: ' + name + ' ) 测试损失为:' + str(loss) + ' 准确度为:' + str(accuracy))


if __name__ == '__main__':
    main()
|