selector.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import numpy as np
  19. from federatedml.nn.hetero.strategy.comparision import Comparision
  class RelativeSelector(object):
      """Select samples for back propagation using a ``Comparision`` helper.

      Every sample of a batch is first added to the helper. Afterwards each
      sample is kept when a randomised threshold,
      ``max(min_prob, U ** beta)`` with ``U ~ Uniform(0, 1)``, does not exceed
      the rate the helper reports for that sample.

      :param max_size: forwarded to ``Comparision(size=...)``.
      :param beta: exponent applied to the uniform random draw.
      :param random_state: seed for this selector's private random generator;
          ``None`` seeds it from fresh OS entropy. An integer seed yields the
          same draw sequence the former global ``np.random.seed`` call did.
      :param min_prob: lower bound applied to the (powered) random draw.
      """

      def __init__(self, max_size=None, beta=1, random_state=None, min_prob=0):
          self._comparision = Comparision(size=max_size)
          self._beta = beta
          self._min_prob = min_prob
          # Use a private generator rather than ``np.random.seed``. Seeding the
          # global NumPy RNG here would reset (or, for ``None``, randomise) the
          # random state used by every other piece of code in the process.
          self._rng = np.random.RandomState(random_state)

      def select_batch_sample(self, samples):
          """Return one boolean per sample telling whether it is selected.

          All samples are registered with the comparison helper before any
          rate is queried. Every sample in the batch is therefore rated after
          the whole batch has been added.

          :param samples: sequence of values accepted by ``Comparision.add`` /
              ``Comparision.get_rate`` (presumably per-sample losses -- confirm).
          :return: list of ``bool`` aligned with ``samples``.
          """
          for sample in samples:
              self._comparision.add(sample)

          # NOTE(review): ``min_prob`` bounds the random draw from below, so a
          # larger value makes selection *less* likely. Confirm this matches
          # the intended semantics of ``Comparision.get_rate``.
          return [
              bool(max(self._min_prob,
                       np.power(self._rng.uniform(0, 1), self._beta))
                   <= self._comparision.get_rate(sample))
              for sample in samples
          ]
  class SelectorFactory(object):
      """Build back-propagation sample selectors by method name."""

      @staticmethod
      def get_selector(
              method,
              selective_size,
              beta=1,
              random_rate=None,
              min_prob=0):
          """Return the selector registered under ``method``.

          :param method: selector name. A falsy value (``None``, ``""``)
              disables selection and yields ``None``.
          :param selective_size: forwarded as the selector's ``max_size``.
          :param beta: exponent forwarded to the selector.
          :param random_rate: forwarded as the selector's ``random_state``
              (a random seed, despite the parameter name).
          :param min_prob: forwarded to the selector.
          :return: a ``RelativeSelector`` for ``"relative"``, or ``None`` when
              ``method`` is falsy.
          :raises ValueError: for any other ``method`` value.
          """
          if not method:
              return None
          if method == "relative":
              return RelativeSelector(
                  selective_size,
                  beta,
                  random_state=random_rate,
                  min_prob=min_prob)
          # Fill the placeholder so the rejected method name is reported.
          raise ValueError(
              "Back Propagation Selector {} not supported yet".format(method))