# scatter.py
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. class Scatter(object):
  17. def __init__(self, host_variable, guest_variable):
  18. """
  19. scatter values from guest and hosts
  20. Args:
  21. host_variable: a variable represents `Host -> Arbiter`
  22. guest_variable: a variable represent `Guest -> Arbiter`
  23. Examples:
  24. >>> from federatedml.framework.homo.util import scatter
  25. >>> s = scatter.Scatter(host_variable, guest_variable)
  26. >>> for v in s.get():
  27. print(v)
  28. """
  29. self._host_variable = host_variable
  30. self._guest_variable = guest_variable
  31. def get(self, suffix=tuple(), host_ids=None):
  32. """
  33. create a generator of values from guest and hosts.
  34. Args:
  35. suffix: tag suffix
  36. host_ids: ids of hosts to get value from.
  37. If None provided, get values from all hosts.
  38. If a list of int provided, get values from all hosts listed.
  39. Returns:
  40. a generator of scatted values
  41. Raises:
  42. if host_ids is neither None nor a list of int, ValueError raised
  43. """
  44. yield self._guest_variable.get(idx=0, suffix=suffix)
  45. if host_ids is None:
  46. host_ids = -1
  47. for ret in self._host_variable.get(idx=host_ids, suffix=suffix):
  48. yield ret