# standard_scale_test.py (19 KB)
  1. import copy
  2. import time
  3. import unittest
  4. import numpy as np
  5. from fate_arch.session import computing_session as session
  6. from sklearn.preprocessing import StandardScaler as SSL
  7. from federatedml.feature.feature_scale.standard_scale import StandardScale
  8. from federatedml.feature.instance import Instance
  9. from federatedml.param.scale_param import ScaleParam
  10. from federatedml.util.param_extract import ParamExtract
  11. class TestStandardScaler(unittest.TestCase):
  12. def setUp(self):
  13. self.test_data = [
  14. [0, 1.0, 10, 2, 3, 1],
  15. [1.0, 2, 9, 2, 4, 2],
  16. [0, 3.0, 8, 3, 3, 3],
  17. [1.0, 4, 7, 4, 4, 4],
  18. [1.0, 5, 6, 5, 5, 5],
  19. [1.0, 6, 5, 6, 6, -100],
  20. [0, 7.0, 4, 7, 7, 7],
  21. [0, 8, 3.0, 8, 6, 8],
  22. [0, 9, 2, 9.0, 9, 9],
  23. [0, 10, 1, 10.0, 10, 10]
  24. ]
  25. str_time = time.strftime("%Y%m%d%H%M%S", time.localtime())
  26. session.init(str_time)
  27. self.test_instance = []
  28. for td in self.test_data:
  29. self.test_instance.append(Instance(features=np.array(td)))
  30. self.table_instance = self.data_to_table(self.test_instance)
  31. self.table_instance.schema['header'] = ["fid" + str(i) for i in range(len(self.test_data[0]))]
  32. self.table_instance.schema['anonymous_header'] = [
  33. "guest_9999_x" + str(i) for i in range(len(self.test_data[0]))]
  34. def print_table(self, table):
  35. for v in (list(table.collect())):
  36. print(v[1].features)
  37. def data_to_table(self, data, partition=10):
  38. data_table = session.parallelize(data, include_key=False, partition=partition)
  39. return data_table
  40. def get_table_instance_feature(self, table_instance):
  41. res_list = []
  42. for k, v in list(table_instance.collect()):
  43. res_list.append(list(np.around(v.features, 4)))
  44. return res_list
  45. def get_scale_param(self):
  46. component_param = {
  47. "method": "standard_scale",
  48. "mode": "normal",
  49. "scale_col_indexes": [],
  50. "with_mean": True,
  51. "with_std": True,
  52. }
  53. scale_param = ScaleParam()
  54. param_extracter = ParamExtract()
  55. param_extracter.parse_param_from_config(scale_param, component_param)
  56. return scale_param
  57. # test with (with_mean=True, with_std=True):
  58. def test_fit1(self):
  59. scale_param = self.get_scale_param()
  60. standard_scaler = StandardScale(scale_param)
  61. fit_instance = standard_scaler.fit(self.table_instance)
  62. mean = standard_scaler.mean
  63. std = standard_scaler.std
  64. scaler = SSL()
  65. scaler.fit(self.test_data)
  66. self.assertListEqual(self.get_table_instance_feature(fit_instance),
  67. np.around(scaler.transform(self.test_data), 4).tolist())
  68. self.assertListEqual(list(np.around(mean, 4)), list(np.around(scaler.mean_, 4)))
  69. self.assertListEqual(list(np.around(std, 4)), list(np.around(scaler.scale_, 4)))
  70. # test with (with_mean=False, with_std=True):
  71. def test_fit2(self):
  72. scale_param = self.get_scale_param()
  73. scale_param.with_mean = False
  74. standard_scaler = StandardScale(scale_param)
  75. fit_instance = standard_scaler.fit(self.table_instance)
  76. mean = standard_scaler.mean
  77. std = standard_scaler.std
  78. scaler = SSL(with_mean=False)
  79. scaler.fit(self.test_data)
  80. self.assertListEqual(self.get_table_instance_feature(fit_instance),
  81. np.around(scaler.transform(self.test_data), 4).tolist())
  82. self.assertListEqual(list(np.around(mean, 4)), [0 for _ in mean])
  83. self.assertListEqual(list(np.around(std, 4)), list(np.around(scaler.scale_, 4)))
  84. # test with (with_mean=True, with_std=False):
  85. def test_fit3(self):
  86. scale_param = self.get_scale_param()
  87. scale_param.with_std = False
  88. standard_scaler = StandardScale(scale_param)
  89. fit_instance = standard_scaler.fit(self.table_instance)
  90. mean = standard_scaler.mean
  91. std = standard_scaler.std
  92. scaler = SSL(with_std=False)
  93. scaler.fit(self.test_data)
  94. self.assertListEqual(self.get_table_instance_feature(fit_instance),
  95. np.around(scaler.transform(self.test_data), 4).tolist())
  96. self.assertListEqual(list(np.around(mean, 4)), list(np.around(scaler.mean_, 4)))
  97. self.assertListEqual(list(np.around(std, 4)), [1 for _ in std])
  98. # test with (with_mean=False, with_std=False):
  99. def test_fit4(self):
  100. scale_param = self.get_scale_param()
  101. scale_param.with_std = False
  102. scale_param.with_mean = False
  103. standard_scaler = StandardScale(scale_param)
  104. fit_instance = standard_scaler.fit(self.table_instance)
  105. mean = standard_scaler.mean
  106. std = standard_scaler.std
  107. scaler = SSL(with_mean=False, with_std=False)
  108. scaler.fit(self.test_data)
  109. self.assertListEqual(self.get_table_instance_feature(fit_instance),
  110. np.around(scaler.transform(self.test_data), 4).tolist())
  111. self.assertEqual(mean, [0 for _ in range(len(self.test_data[0]))])
  112. self.assertEqual(std, [1 for _ in range(len(self.test_data[0]))])
  113. # test with (area="all", scale_column_idx=[], with_mean=True, with_std=True):
  114. def test_fit5(self):
  115. scale_param = self.get_scale_param()
  116. scale_param.scale_column_idx = []
  117. standard_scaler = StandardScale(scale_param)
  118. fit_instance = standard_scaler.fit(self.table_instance)
  119. mean = standard_scaler.mean
  120. std = standard_scaler.std
  121. scaler = SSL()
  122. scaler.fit(self.test_data)
  123. self.assertListEqual(self.get_table_instance_feature(fit_instance),
  124. np.around(scaler.transform(self.test_data), 4).tolist())
  125. self.assertListEqual(list(np.around(mean, 4)), list(np.around(scaler.mean_, 4)))
  126. self.assertListEqual(list(np.around(std, 4)), list(np.around(scaler.scale_, 4)))
  127. # test with (area="col", scale_column_idx=[], with_mean=True, with_std=True):
  128. def test_fit6(self):
  129. scale_param = self.get_scale_param()
  130. scale_param.scale_col_indexes = []
  131. standard_scaler = StandardScale(scale_param)
  132. fit_instance = standard_scaler.fit(self.table_instance)
  133. mean = standard_scaler.mean
  134. std = standard_scaler.std
  135. scaler = SSL()
  136. scaler.fit(self.test_data)
  137. self.assertListEqual(self.get_table_instance_feature(fit_instance),
  138. np.around(self.test_data, 4).tolist())
  139. self.assertListEqual(list(np.around(mean, 4)), list(np.around(scaler.mean_, 4)))
  140. self.assertListEqual(list(np.around(std, 4)), list(np.around(scaler.scale_, 4)))
  141. # test with (area="all", upper=2, lower=1, with_mean=False, with_std=False):
  142. def test_fit7(self):
  143. scale_param = self.get_scale_param()
  144. scale_param.scale_column_idx = []
  145. scale_param.feat_upper = 2
  146. scale_param.feat_lower = 1
  147. scale_param.with_mean = False
  148. scale_param.with_std = False
  149. standard_scaler = StandardScale(scale_param)
  150. fit_instance = standard_scaler.fit(self.table_instance)
  151. mean = standard_scaler.mean
  152. std = standard_scaler.std
  153. column_max_value = standard_scaler.column_max_value
  154. column_min_value = standard_scaler.column_min_value
  155. for i, line in enumerate(self.test_data):
  156. for j, value in enumerate(line):
  157. if value > 2:
  158. self.test_data[i][j] = 2
  159. elif value < 1:
  160. self.test_data[i][j] = 1
  161. self.assertListEqual(self.get_table_instance_feature(fit_instance),
  162. np.around(self.test_data, 4).tolist())
  163. self.assertEqual(mean, [0 for _ in range(len(self.test_data[0]))])
  164. self.assertEqual(std, [1 for _ in range(len(self.test_data[0]))])
  165. self.assertEqual(column_max_value, [1, 2, 2, 2, 2, 2])
  166. self.assertEqual(column_min_value, [1, 1, 1, 2, 2, 1])
  167. # test with (area="all", upper=[2,2,2,2,2,2], lower=[1,1,1,1,1,1], with_mean=False, with_std=False):
  168. def test_fit8(self):
  169. scale_param = self.get_scale_param()
  170. scale_param.scale_column_idx = []
  171. scale_param.feat_upper = [2, 2, 2, 2, 2, 2]
  172. scale_param.feat_lower = [1, 1, 1, 1, 1, 1]
  173. scale_param.with_mean = False
  174. scale_param.with_std = False
  175. standard_scaler = StandardScale(scale_param)
  176. fit_instance = standard_scaler.fit(self.table_instance)
  177. mean = standard_scaler.mean
  178. std = standard_scaler.std
  179. column_max_value = standard_scaler.column_max_value
  180. column_min_value = standard_scaler.column_min_value
  181. for i, line in enumerate(self.test_data):
  182. for j, value in enumerate(line):
  183. if value > 2:
  184. self.test_data[i][j] = 2
  185. elif value < 1:
  186. self.test_data[i][j] = 1
  187. self.assertListEqual(self.get_table_instance_feature(fit_instance),
  188. np.around(self.test_data, 4).tolist())
  189. self.assertEqual(mean, [0 for _ in range(len(self.test_data[0]))])
  190. self.assertEqual(std, [1 for _ in range(len(self.test_data[0]))])
  191. self.assertEqual(column_max_value, [1, 2, 2, 2, 2, 2])
  192. self.assertEqual(column_min_value, [1, 1, 1, 2, 2, 1])
  193. # test with (area="col", upper=[2,2,2,2,2,2], lower=[1,1,1,1,1,1],
  194. # scale_column_idx=[1,2,4], with_mean=True, with_std=True):
  195. def test_fit9(self):
  196. scale_column_idx = [1, 2, 4]
  197. scale_param = self.get_scale_param()
  198. scale_param.feat_upper = [2, 2, 2, 2, 2, 2]
  199. scale_param.feat_lower = [1, 1, 1, 1, 1, 1]
  200. scale_param.with_mean = True
  201. scale_param.with_std = True
  202. scale_param.scale_col_indexes = scale_column_idx
  203. standard_scaler = StandardScale(scale_param)
  204. fit_instance = standard_scaler.fit(self.table_instance)
  205. mean = standard_scaler.mean
  206. std = standard_scaler.std
  207. column_max_value = standard_scaler.column_max_value
  208. column_min_value = standard_scaler.column_min_value
  209. raw_data = copy.deepcopy(self.test_data)
  210. for i, line in enumerate(self.test_data):
  211. for j, value in enumerate(line):
  212. if j in scale_column_idx:
  213. if value > 2:
  214. self.test_data[i][j] = 2
  215. elif value < 1:
  216. self.test_data[i][j] = 1
  217. scaler = SSL(with_mean=True, with_std=True)
  218. scaler.fit(self.test_data)
  219. transform_data = np.around(scaler.transform(self.test_data), 4).tolist()
  220. for i, line in enumerate(transform_data):
  221. for j, cols in enumerate(line):
  222. if j not in scale_column_idx:
  223. transform_data[i][j] = raw_data[i][j]
  224. self.assertListEqual(self.get_table_instance_feature(fit_instance),
  225. transform_data)
  226. self.assertListEqual(list(np.around(mean, 6)), list(np.around(scaler.mean_, 6)))
  227. self.assertListEqual(list(np.around(std, 6)), list(np.around(scaler.scale_, 6)))
  228. self.assertEqual(column_max_value, [1, 2, 2, 10, 2, 10])
  229. self.assertEqual(column_min_value, [0, 1, 1, 2, 2, -100])
  230. raw_data_transform = standard_scaler.transform(self.table_instance)
  231. self.assertListEqual(self.get_table_instance_feature(fit_instance),
  232. self.get_table_instance_feature(raw_data_transform))
  233. # test with (mode="cap", area="col", upper=0.8, lower=0.2, scale_column_idx=[1,2,4], with_mean=True, with_std=True):
  234. def test_fit10(self):
  235. scale_column_idx = [1, 2, 4]
  236. scale_param = self.get_scale_param()
  237. scale_param.scale_col_indexes = []
  238. scale_param.feat_upper = 0.8
  239. scale_param.feat_lower = 0.2
  240. scale_param.with_mean = True
  241. scale_param.with_std = True
  242. scale_param.mode = "cap"
  243. scale_param.scale_col_indexes = scale_column_idx
  244. standard_scaler = StandardScale(scale_param)
  245. fit_instance = standard_scaler.fit(self.table_instance)
  246. mean = standard_scaler.mean
  247. std = standard_scaler.std
  248. column_max_value = standard_scaler.column_max_value
  249. column_min_value = standard_scaler.column_min_value
  250. gt_cap_lower_list = [0, 2, 2, 2, 3, 1]
  251. gt_cap_upper_list = [1, 8, 8, 8, 7, 8]
  252. raw_data = copy.deepcopy(self.test_data)
  253. for i, line in enumerate(self.test_data):
  254. for j, value in enumerate(line):
  255. if j in scale_column_idx:
  256. if value > gt_cap_upper_list[j]:
  257. self.test_data[i][j] = gt_cap_upper_list[j]
  258. elif value < gt_cap_lower_list[j]:
  259. self.test_data[i][j] = gt_cap_lower_list[j]
  260. scaler = SSL(with_mean=True, with_std=True)
  261. scaler.fit(self.test_data)
  262. transform_data = np.around(scaler.transform(self.test_data), 4).tolist()
  263. for i, line in enumerate(transform_data):
  264. for j, cols in enumerate(line):
  265. if j not in scale_column_idx:
  266. transform_data[i][j] = raw_data[i][j]
  267. self.assertListEqual(self.get_table_instance_feature(fit_instance),
  268. transform_data)
  269. self.assertEqual(column_max_value, gt_cap_upper_list)
  270. self.assertEqual(column_min_value, gt_cap_lower_list)
  271. self.assertListEqual(list(np.around(mean, 6)), list(np.around(scaler.mean_, 6)))
  272. self.assertListEqual(list(np.around(std, 6)), list(np.around(scaler.scale_, 6)))
  273. raw_data_transform = standard_scaler.transform(self.table_instance)
  274. self.assertListEqual(self.get_table_instance_feature(fit_instance),
  275. self.get_table_instance_feature(raw_data_transform))
  276. # test with (with_mean=True, with_std=True):
  277. def test_transform1(self):
  278. scale_param = self.get_scale_param()
  279. standard_scaler = StandardScale(scale_param)
  280. fit_instance = standard_scaler.fit(self.table_instance)
  281. transform_data = standard_scaler.transform(self.table_instance)
  282. self.assertListEqual(self.get_table_instance_feature(transform_data),
  283. self.get_table_instance_feature(fit_instance))
  284. # test with (with_mean=True, with_std=False):
  285. def test_transform2(self):
  286. scale_param = self.get_scale_param()
  287. scale_param.with_std = False
  288. standard_scaler = StandardScale(scale_param)
  289. fit_instance = standard_scaler.fit(self.table_instance)
  290. transform_data = standard_scaler.transform(self.table_instance)
  291. self.assertListEqual(self.get_table_instance_feature(transform_data),
  292. self.get_table_instance_feature(fit_instance))
  293. # test with (with_mean=False, with_std=True):
  294. def test_transform3(self):
  295. scale_param = self.get_scale_param()
  296. scale_param.with_mean = False
  297. standard_scaler = StandardScale(scale_param)
  298. fit_instance = standard_scaler.fit(self.table_instance)
  299. transform_data = standard_scaler.transform(self.table_instance)
  300. self.assertListEqual(self.get_table_instance_feature(transform_data),
  301. self.get_table_instance_feature(fit_instance))
  302. # test with (with_mean=False, with_std=False):
  303. def test_transform4(self):
  304. scale_param = self.get_scale_param()
  305. scale_param.with_mean = False
  306. scale_param.with_std = False
  307. standard_scaler = StandardScale(scale_param)
  308. fit_instance = standard_scaler.fit(self.table_instance)
  309. transform_data = standard_scaler.transform(self.table_instance)
  310. self.assertListEqual(self.get_table_instance_feature(transform_data),
  311. self.get_table_instance_feature(fit_instance))
  312. # test with (area='all', scale_column_idx=[], with_mean=False, with_std=False):
  313. def test_transform5(self):
  314. scale_param = self.get_scale_param()
  315. scale_param.with_mean = False
  316. scale_param.with_std = False
  317. scale_param.scale_column_idx = []
  318. standard_scaler = StandardScale(scale_param)
  319. fit_instance = standard_scaler.fit(self.table_instance)
  320. transform_data = standard_scaler.transform(self.table_instance)
  321. self.assertListEqual(self.get_table_instance_feature(transform_data),
  322. self.get_table_instance_feature(fit_instance))
  323. # test with (area='col', with_mean=[], with_std=False):
  324. def test_transform6(self):
  325. scale_param = self.get_scale_param()
  326. scale_param.with_mean = False
  327. scale_param.with_std = False
  328. scale_param.scale_column_idx = []
  329. standard_scaler = StandardScale(scale_param)
  330. fit_instance = standard_scaler.fit(self.table_instance)
  331. transform_data = standard_scaler.transform(self.table_instance)
  332. self.assertListEqual(self.get_table_instance_feature(transform_data),
  333. self.get_table_instance_feature(fit_instance))
  334. def test_cols_select_fit_and_transform(self):
  335. scale_param = self.get_scale_param()
  336. scale_param.scale_column_idx = [1, 2, 4]
  337. standard_scaler = StandardScale(scale_param)
  338. fit_data = standard_scaler.fit(self.table_instance)
  339. scale_column_idx = standard_scaler.scale_column_idx
  340. scaler = SSL(with_mean=True, with_std=True)
  341. scaler.fit(self.test_data)
  342. transform_data = np.around(scaler.transform(self.test_data), 4).tolist()
  343. for i, line in enumerate(transform_data):
  344. for j, cols in enumerate(line):
  345. if j not in scale_column_idx:
  346. transform_data[i][j] = self.test_data[i][j]
  347. self.assertListEqual(self.get_table_instance_feature(fit_data),
  348. transform_data)
  349. std_scale_transform_data = standard_scaler.transform(self.table_instance)
  350. self.assertListEqual(self.get_table_instance_feature(std_scale_transform_data),
  351. transform_data)
  352. def test_cols_select_fit_and_transform_repeat(self):
  353. scale_param = self.get_scale_param()
  354. scale_param.scale_column_idx = [1, 1, 2, 2, 4, 5, 5]
  355. standard_scaler = StandardScale(scale_param)
  356. fit_data = standard_scaler.fit(self.table_instance)
  357. scale_column_idx = standard_scaler.scale_column_idx
  358. scaler = SSL(with_mean=True, with_std=True)
  359. scaler.fit(self.test_data)
  360. transform_data = np.around(scaler.transform(self.test_data), 4).tolist()
  361. for i, line in enumerate(transform_data):
  362. for j, cols in enumerate(line):
  363. if j not in scale_column_idx:
  364. transform_data[i][j] = self.test_data[i][j]
  365. self.assertListEqual(self.get_table_instance_feature(fit_data),
  366. transform_data)
  367. std_scale_transform_data = standard_scaler.transform(self.table_instance)
  368. self.assertListEqual(self.get_table_instance_feature(std_scale_transform_data),
  369. transform_data)
  370. def tearDown(self):
  371. session.stop()
  372. if __name__ == "__main__":
  373. unittest.main()