subsample.py

import numpy as np
from fate_arch.session import computing_session


# TODO
def random_sampling():
    """
    Normal random row subsample
    """
    pass
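

# Hypothetical sketch of a plain random row subsample over the same
# (id, (grad, hess)) table layout used by goss_sampling below; the name
# random_sampling_sketch and the subsample_rate parameter are assumptions,
# not part of the FATE API, and only reuse table calls already used in
# goss_sampling (count, collect, parallelize, partitions).
def random_sampling_sketch(grad_and_hess, subsample_rate=0.8):
    sample_num = grad_and_hess.count()
    keep_num = max(1, int(sample_num * subsample_rate))
    rows = list(grad_and_hess.collect())
    keep_idx = np.random.choice(len(rows), size=keep_num, replace=False)
    data = [rows[i] for i in keep_idx]
    return computing_session.parallelize(data, include_key=True, partition=grad_and_hess.partitions)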


def goss_sampling(grad_and_hess, top_rate, other_rate):
    """
    Gradient-based One-Side Sampling (GOSS), the row subsampling method
    introduced in LightGBM: keep the top_rate fraction of samples with the
    largest absolute gradients, randomly sample an other_rate fraction of
    the rest, and re-weight the latter to keep the gradient sum unbiased.
    """
    sample_num = grad_and_hess.count()
    g_h_generator = grad_and_hess.collect()

    # pull (id, (grad, hess)) rows out of the table into local arrays
    id_list, g_list, h_list = [], [], []
    for id_, g_h in g_h_generator:
        id_list.append(id_)
        g_list.append(g_h[0])
        h_list.append(g_h[1])

    id_type = type(id_list[0])
    id_list = np.array(id_list)
    g_arr = np.array(g_list).astype(np.float64)
    h_arr = np.array(h_list).astype(np.float64)

    # in the multi-classification case g has one column per class, so sum |g| over classes
    g_sum_arr = np.abs(g_arr).sum(axis=1)
    abs_g_list_arr = g_sum_arr
    sorted_idx = np.argsort(-abs_g_list_arr, kind='stable')  # stable sort for a reproducible sample

    a_part_num = int(sample_num * top_rate)
    b_part_num = int(sample_num * other_rate)
    if a_part_num == 0 or b_part_num == 0:
        raise ValueError('subsampled result is 0: top sample {}, other sample {}'.format(a_part_num, b_part_num))

    # a part: samples with the largest absolute gradients
    a_sample_idx = sorted_idx[:a_part_num]
    # b part: a uniform random subset of the remaining samples
    rest_sample_idx = sorted_idx[a_part_num:]
    b_sample_idx = np.random.choice(rest_sample_idx, size=b_part_num, replace=False)

    # amplify the small-gradient samples so the gradient sum stays (approximately) unbiased
    amplify_weights = (1 - top_rate) / other_rate
    g_arr[b_sample_idx] *= amplify_weights
    h_arr[b_sample_idx] *= amplify_weights

    # gather the selected samples and rebuild a distributed table from them
    a_idx_set, b_idx_set = set(list(a_sample_idx)), set(list(b_sample_idx))
    idx_set = a_idx_set.union(b_idx_set)
    selected_idx = np.array(list(idx_set))
    selected_g, selected_h = g_arr[selected_idx], h_arr[selected_idx]
    selected_id = id_list[selected_idx]

    data = [(id_type(id_), (g, h)) for id_, g, h in zip(selected_id, selected_g, selected_h)]
    new_g_h_table = computing_session.parallelize(data, include_key=True, partition=grad_and_hess.partitions)
    return new_g_h_table
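

if __name__ == "__main__":
    # Minimal sanity check of the GOSS re-weighting: a sketch using plain numpy
    # only, so no FATE computing session is needed; the rates and sample size
    # below are arbitrary assumptions. The amplified |g| sum over the subsample
    # should roughly match the |g| sum over the full data, because the b part
    # is a uniform sample of the rest scaled by (1 - top_rate) / other_rate.
    np.random.seed(0)
    top_rate, other_rate = 0.2, 0.1
    g = np.random.randn(1000)

    order = np.argsort(-np.abs(g), kind='stable')
    a_num, b_num = int(len(g) * top_rate), int(len(g) * other_rate)
    a_idx = order[:a_num]
    b_idx = np.random.choice(order[a_num:], size=b_num, replace=False)

    amplify = (1 - top_rate) / other_rate
    goss_estimate = np.abs(g[a_idx]).sum() + amplify * np.abs(g[b_idx]).sum()
    print('full |g| sum: {:.2f}, GOSS estimate: {:.2f}'.format(np.abs(g).sum(), goss_estimate))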