sampler_test.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import unittest
  17. import numpy as np
  18. from fate_arch.session import computing_session as session
  19. from federatedml.feature.instance import Instance
  20. from federatedml.feature.sampler import RandomSampler
  21. from federatedml.feature.sampler import StratifiedSampler
  22. from federatedml.util import consts
  23. class TestRandomSampler(unittest.TestCase):
  24. def setUp(self):
  25. session.init("test_random_sampler")
  26. self.data = [(i * 10 + 5, i * i) for i in range(100)]
  27. self.table = session.parallelize(self.data, include_key=True, partition=16)
  28. self.data_to_trans = [(i * 10 + 5, i * i * i) for i in range(100)]
  29. self.table_trans = session.parallelize(self.data_to_trans, include_key=True, partition=16)
  30. def test_downsample(self):
  31. sampler = RandomSampler(fraction=0.3, method="downsample")
  32. tracker = TrackerMock()
  33. sampler.set_tracker(tracker)
  34. sample_data, sample_ids = sampler.sample(self.table)
  35. self.assertTrue(sample_data.count() > 25 and sample_data.count() < 35)
  36. self.assertTrue(len(set(sample_ids)) == len(sample_ids))
  37. new_data = list(sample_data.collect())
  38. data_dict = dict(self.data)
  39. for id, value in new_data:
  40. self.assertTrue(id in data_dict)
  41. self.assertTrue(np.abs(value - data_dict.get(id)) < consts.FLOAT_ZERO)
  42. trans_sampler = RandomSampler(method="downsample")
  43. trans_sampler.set_tracker(tracker)
  44. trans_sample_data = trans_sampler.sample(self.table_trans, sample_ids)
  45. trans_data = list(trans_sample_data.collect())
  46. trans_sample_ids = [id for (id, value) in trans_data]
  47. data_to_trans_dict = dict(self.data_to_trans)
  48. sample_id_mapping = dict(zip(sample_ids, range(len(sample_ids))))
  49. self.assertTrue(len(trans_data) == len(sample_ids))
  50. self.assertTrue(set(trans_sample_ids) == set(sample_ids))
  51. for id, value in trans_data:
  52. self.assertTrue(id in sample_id_mapping)
  53. self.assertTrue(np.abs(value - data_to_trans_dict.get(id)) < consts.FLOAT_ZERO)
  54. def test_upsample(self):
  55. sampler = RandomSampler(fraction=3, method="upsample")
  56. tracker = TrackerMock()
  57. sampler.set_tracker(tracker)
  58. sample_data, sample_ids = sampler.sample(self.table)
  59. self.assertTrue(sample_data.count() > 250 and sample_data.count() < 350)
  60. data_dict = dict(self.data)
  61. new_data = list(sample_data.collect())
  62. for id, value in new_data:
  63. self.assertTrue(np.abs(value - data_dict[sample_ids[id]]) < consts.FLOAT_ZERO)
  64. trans_sampler = RandomSampler(method="upsample")
  65. trans_sampler.set_tracker(tracker)
  66. trans_sample_data = trans_sampler.sample(self.table_trans, sample_ids)
  67. trans_data = list(trans_sample_data.collect())
  68. data_to_trans_dict = dict(self.data_to_trans)
  69. self.assertTrue(len(trans_data) == len(sample_ids))
  70. for id, value in trans_data:
  71. self.assertTrue(np.abs(value - data_to_trans_dict[sample_ids[id]]) < consts.FLOAT_ZERO)
  72. def tearDown(self):
  73. session.stop()
  74. class TestStratifiedSampler(unittest.TestCase):
  75. def setUp(self):
  76. session.init("test_stratified_sampler")
  77. self.data = []
  78. self.data_to_trans = []
  79. for i in range(1000):
  80. self.data.append((i, Instance(label=i % 4, features=i * i)))
  81. self.data_to_trans.append((i, Instance(features=i ** 3)))
  82. self.table = session.parallelize(self.data, include_key=True, partition=16)
  83. self.table_trans = session.parallelize(self.data_to_trans, include_key=True, partition=16)
  84. def test_downsample(self):
  85. fractions = [(0, 0.3), (1, 0.4), (2, 0.5), (3, 0.8)]
  86. sampler = StratifiedSampler(fractions=fractions, method="downsample")
  87. tracker = TrackerMock()
  88. sampler.set_tracker(tracker)
  89. sample_data, sample_ids = sampler.sample(self.table)
  90. count_label = [0 for i in range(4)]
  91. new_data = list(sample_data.collect())
  92. data_dict = dict(self.data)
  93. self.assertTrue(set(sample_ids) & set(data_dict.keys()) == set(sample_ids))
  94. for id, inst in new_data:
  95. count_label[inst.label] += 1
  96. self.assertTrue(type(id).__name__ == 'int' and id >= 0 and id < 1000)
  97. self.assertTrue(inst.label == self.data[id][1].label and inst.features == self.data[id][1].features)
  98. for i in range(4):
  99. self.assertTrue(np.abs(count_label[i] - 250 * fractions[i][1]) < 10)
  100. trans_sampler = StratifiedSampler(method="downsample")
  101. trans_sampler.set_tracker(tracker)
  102. trans_sample_data = trans_sampler.sample(self.table_trans, sample_ids)
  103. trans_data = list(trans_sample_data.collect())
  104. trans_sample_ids = [id for (id, value) in trans_data]
  105. data_to_trans_dict = dict(self.data_to_trans)
  106. self.assertTrue(set(trans_sample_ids) == set(sample_ids))
  107. for id, inst in trans_data:
  108. self.assertTrue(inst.features == data_to_trans_dict.get(id).features)
  109. def test_upsample(self):
  110. fractions = [(0, 1.3), (1, 0.5), (2, 0.8), (3, 9)]
  111. sampler = StratifiedSampler(fractions=fractions, method="upsample")
  112. tracker = TrackerMock()
  113. sampler.set_tracker(tracker)
  114. sample_data, sample_ids = sampler.sample(self.table)
  115. new_data = list(sample_data.collect())
  116. count_label = [0 for i in range(4)]
  117. data_dict = dict(self.data)
  118. for id, inst in new_data:
  119. count_label[inst.label] += 1
  120. self.assertTrue(type(id).__name__ == 'int' and id >= 0 and id < len(sample_ids))
  121. real_id = sample_ids[id]
  122. self.assertTrue(inst.label == self.data[real_id][1].label and
  123. inst.features == self.data[real_id][1].features)
  124. for i in range(4):
  125. self.assertTrue(np.abs(count_label[i] - 250 * fractions[i][1]) < 10)
  126. trans_sampler = StratifiedSampler(method="upsample")
  127. trans_sampler.set_tracker(tracker)
  128. trans_sample_data = trans_sampler.sample(self.table_trans, sample_ids)
  129. trans_data = (trans_sample_data.collect())
  130. trans_sample_ids = [id for (id, value) in trans_data]
  131. data_to_trans_dict = dict(self.data_to_trans)
  132. self.assertTrue(sorted(trans_sample_ids) == list(range(len(sample_ids))))
  133. for id, inst in trans_data:
  134. real_id = sample_ids[id]
  135. self.assertTrue(inst.features == data_to_trans_dict[real_id][1].features)
  136. def tearDown(self):
  137. session.stop()
  138. class TrackerMock(object):
  139. def log_component_summary(self, *args, **kwargs):
  140. pass
  141. def log_metric_data(self, *args, **kwargs):
  142. pass
  143. def set_metric_meta(self, *args, **kwargs):
  144. pass
  145. if __name__ == '__main__':
  146. unittest.main()