Preprocess.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import pandas as pd
  2. import numpy as np
  3. from sklearn.preprocessing import MinMaxScaler
  4. def get_cols():
  5. with open('Data/kddcup.names', 'r') as infile:
  6. kdd_names = infile.readlines()
  7. kdd_cols = [x.split(':')[0] for x in kdd_names[1:]]
  8. kdd_cols += ['class', 'difficulty']
  9. return kdd_cols
  10. def get_attack_map():
  11. attack_map = [x.strip().split(',') for x in open('Data/AttackTypes.csv', 'r')]
  12. attack_map = {k: v for (k, v) in attack_map}
  13. return attack_map
  14. def cat_encode(df, col):
  15. return pd.concat([df.drop(col, axis=1), pd.get_dummies(df[col].values)], axis=1)
  16. def log_trns(df, col):
  17. return df[col].apply(np.log1p)
  18. # 数据归一化
  19. def rescale_features(data):
  20. min_max_scaler = MinMaxScaler()
  21. data = min_max_scaler.fit_transform(data)
  22. return data
  23. def get_normal_and_abnormal_data(kdd_t):
  24. normal_kdd_t = kdd_t[kdd_t['class'].isin(['normal'])]
  25. abnormal_kdd_t = kdd_t[~kdd_t['class'].isin(['normal'])]
  26. test_target = kdd_t['class']
  27. test_target = pd.get_dummies(test_target)
  28. normal_kdd_t.pop('difficulty')
  29. abnormal_kdd_t.pop('difficulty')
  30. normal_test_target = normal_kdd_t.pop('class')
  31. abnormal_test_target = abnormal_kdd_t.pop('class')
  32. normal_test_target = pd.get_dummies(normal_test_target)
  33. abnormal_test_target = pd.get_dummies(abnormal_test_target)
  34. for col in test_target.columns:
  35. if col not in normal_test_target.columns:
  36. normal_test_target[col] = 0
  37. normal_test_target = normal_test_target[test_target.columns]
  38. for col in test_target.columns:
  39. if col not in abnormal_test_target.columns:
  40. abnormal_test_target[col] = 0
  41. abnormal_test_target = abnormal_test_target[test_target.columns]
  42. normal_test = normal_kdd_t.values
  43. normal_test_target = normal_test_target.values
  44. normal_test = rescale_features(normal_test)
  45. abnormal_test = abnormal_kdd_t.values
  46. abnormal_test_target = abnormal_test_target.values
  47. abnormal_test = rescale_features(abnormal_test)
  48. return normal_test, normal_test_target, abnormal_test, abnormal_test_target
  49. def get_data():
  50. kdd_cols = get_cols()
  51. kdd = pd.read_csv('Data/KDDTrain+.csv', names=kdd_cols)
  52. kdd_t = pd.read_csv('Data/KDDTest+.csv', names=kdd_cols)
  53. kdd_cols = [kdd.columns[0]] + sorted(list(set(kdd.protocol_type.values))) + sorted(
  54. list(set(kdd.service.values))) + sorted(list(set(kdd.flag.values))) + kdd.columns[4:].tolist()
  55. attack_map = get_attack_map()
  56. kdd['class'] = kdd['class'].replace(attack_map)
  57. kdd_t['class'] = kdd_t['class'].replace(attack_map)
  58. cat_lst = ['protocol_type', 'service', 'flag']
  59. for col in cat_lst:
  60. kdd = cat_encode(kdd, col)
  61. kdd_t = cat_encode(kdd_t, col)
  62. log_lst = ['duration', 'src_bytes', 'dst_bytes']
  63. for col in log_lst:
  64. kdd[col] = log_trns(kdd, col)
  65. kdd_t[col] = log_trns(kdd_t, col)
  66. kdd = kdd[kdd_cols]
  67. # 多类别转0、1编码时,kdd_t的类别比kdd少
  68. for col in kdd_cols:
  69. if col not in kdd_t.columns:
  70. kdd_t[col] = 0
  71. kdd_t = kdd_t[kdd_cols]
  72. normal_test, normal_test_target, abnormal_test, abnormal_test_target = get_normal_and_abnormal_data(kdd_t)
  73. kdd.pop('difficulty')
  74. target = kdd.pop('class')
  75. kdd_t.pop('difficulty')
  76. test_target = kdd_t.pop('class')
  77. target = pd.get_dummies(target)
  78. test_target = pd.get_dummies(test_target)
  79. # for idx, col in enumerate(list(test_target.columns)):
  80. # print(idx, col)
  81. train = kdd.values
  82. test = kdd_t.values
  83. target = target.values
  84. test_target = test_target.values
  85. train = rescale_features(train)
  86. test = rescale_features(test)
  87. return train, target, test, test_target, normal_test, normal_test_target, abnormal_test, abnormal_test_target