test_utils.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import time
  17. import unittest
  18. import uuid
  19. from multiprocessing import Pool
  20. from fate_arch.computing import ComputingType
  21. from fate_arch.session import Session
  22. from federatedml.util import consts
  23. class TestBlocks(unittest.TestCase):
  24. def clean_tables(self):
  25. from fate_arch.session import computing_session as session
  26. session.init(job_id=self.job_id)
  27. try:
  28. session.cleanup("*", self.job_id, True)
  29. except EnvironmentError:
  30. pass
  31. try:
  32. session.cleanup("*", self.job_id, False)
  33. except EnvironmentError:
  34. pass
  35. def setUp(self) -> None:
  36. self.job_id = str(uuid.uuid1())
  37. def tearDown(self) -> None:
  38. self.clean_tables()
  39. @staticmethod
  40. def apply_func(func, job_id, role, num_hosts, ind, *args):
  41. partyid_map = dict(host=[9999 + i for i in range(num_hosts)], guest=[9999], arbiter=[9999])
  42. partyid = 9999
  43. if role == consts.HOST:
  44. partyid = 9999 + ind
  45. with Session() as session:
  46. session.init_computing(job_id, computing_type=ComputingType.STANDALONE)
  47. session.init_federation(federation_session_id=job_id,
  48. runtime_conf={"local": {"role": role, "party_id": partyid}, "role": partyid_map})
  49. return func(job_id, role, ind, *args)
  50. @staticmethod
  51. def run_test(func, job_id, num_hosts, *args):
  52. pool = Pool(num_hosts + 2)
  53. tasks = []
  54. for role, ind in [(consts.ARBITER, 0), (consts.GUEST, 0)] + [(consts.HOST, i) for i in range(num_hosts)]:
  55. tasks.append(
  56. pool.apply_async(func=TestBlocks.apply_func,
  57. args=(func, job_id, role, num_hosts, ind, *args))
  58. )
  59. pool.close()
  60. left = [i for i in range(len(tasks))]
  61. while left:
  62. time.sleep(0.01)
  63. tmp = []
  64. for i in left:
  65. if tasks[i].ready():
  66. tasks[i] = tasks[i].get()
  67. else:
  68. tmp.append(i)
  69. left = tmp
  70. return tasks[0], tasks[1], tasks[2:]