# services_test.py (6.2 KB) -- tests for fate_flow.db.db_services

import contextlib
import os
import re
import time
import unittest
from unittest.mock import patch

from kazoo.client import KazooClient
from kazoo.exceptions import NodeExistsError, NoNodeError

from fate_flow.db import db_services
from fate_flow.errors.error_services import *
from fate_flow.db.db_models import DB, MachineLearningModelInfo as MLModel
from fate_flow import settings

  11. model_download_url = 'http://127.0.0.1:9380/v1/model/transfer/arbiter-10000_guest-9999_host-10000_model/202105060929263278441'
  12. escaped_model_download_url = '/FATE-SERVICES/flow/online/transfer/providers/http%3A%2F%2F127.0.0.1%3A9380%2Fv1%2Fmodel%2Ftransfer%2Farbiter-10000_guest-9999_host-10000_model%2F202105060929263278441'
  13. class TestZooKeeperDB(unittest.TestCase):
  14. def setUp(self):
  15. # required environment: ZOOKEEPER_HOSTS
  16. # optional environment: ZOOKEEPER_USERNAME, ZOOKEEPER_PASSWORD
  17. config = {
  18. 'hosts': os.environ['ZOOKEEPER_HOSTS'].split(','),
  19. 'use_acl': False,
  20. }
  21. username = os.environ.get('ZOOKEEPER_USERNAME')
  22. password = os.environ.get('ZOOKEEPER_PASSWORD')
  23. if username and password:
  24. config.update({
  25. 'use_acl': True,
  26. 'username': username,
  27. 'password': password,
  28. })
  29. with patch.object(db_services.ServiceRegistry, 'USE_REGISTRY', 'ZooKeeper'), \
  30. patch.object(db_services.ServiceRegistry, 'ZOOKEEPER', config):
  31. self.service_db = db_services.service_db()
  32. def test_services_db(self):
  33. self.assertEqual(type(self.service_db), db_services.ZooKeeperDB)
  34. self.assertNotEqual(type(self.service_db), db_services.FallbackDB)
  35. self.assertEqual(type(self.service_db.client), KazooClient)
  36. def test_zookeeper_not_configured(self):
  37. with patch.object(db_services.ServiceRegistry, 'USE_REGISTRY', True), \
  38. patch.object(db_services.ServiceRegistry, 'ZOOKEEPER', {'hosts': None}), \
  39. self.assertRaisesRegex(ZooKeeperNotConfigured, ZooKeeperNotConfigured.message):
  40. db_services.service_db()
  41. def test_missing_zookeeper_username_or_password(self):
  42. with patch.object(db_services.ServiceRegistry, 'USE_REGISTRY', True), \
  43. patch.object(db_services.ServiceRegistry, 'ZOOKEEPER', {
  44. 'hosts': ['127.0.0.1:2281'],
  45. 'use_acl': True,
  46. }), self.assertRaisesRegex(
  47. MissingZooKeeperUsernameOrPassword, MissingZooKeeperUsernameOrPassword.message):
  48. db_services.service_db()
  49. def test_get_znode_path(self):
  50. self.assertEqual(self.service_db._get_znode_path('fateflow', model_download_url), escaped_model_download_url)
  51. def test_crud(self):
  52. self.service_db._insert('fateflow', model_download_url)
  53. self.assertIn(model_download_url, self.service_db.get_urls('fateflow'))
  54. self.service_db._delete('fateflow', model_download_url)
  55. self.assertNotIn(model_download_url, self.service_db.get_urls('fateflow'))
  56. def test_insert_exists_node(self):
  57. self.service_db._delete('servings', 'http://foo/bar')
  58. self.service_db._insert('servings', 'http://foo/bar')
  59. with self.assertRaises(NodeExistsError):
  60. self.service_db.client.create(self.service_db._get_znode_path('servings', 'http://foo/bar'), makepath=True)
  61. self.service_db._insert('servings', 'http://foo/bar')
  62. self.service_db._delete('servings', 'http://foo/bar')
  63. def test_delete_not_exists_node(self):
  64. self.service_db._delete('servings', 'http://foo/bar')
  65. with self.assertRaises(NoNodeError):
  66. self.service_db.client.delete(self.service_db._get_znode_path('servings', 'http://foo/bar'))
  67. self.service_db._delete('servings', 'http://foo/bar')
  68. def test_connection_closed(self):
  69. self.service_db._insert('fateflow', model_download_url)
  70. self.assertIn(model_download_url, self.service_db.get_urls('fateflow'))
  71. self.service_db.client.stop()
  72. self.service_db.client.start()
  73. self.assertNotIn(model_download_url, self.service_db.get_urls('fateflow'))
  74. def test_register_models(self):
  75. try:
  76. os.remove(DB.database)
  77. except FileNotFoundError:
  78. pass
  79. MLModel.create_table()
  80. for x in range(1, 101):
  81. job_id = str(time.time())
  82. model = MLModel(
  83. f_role='host', f_party_id='100', f_job_id=job_id,
  84. f_model_id=f'foobar#{x}', f_model_version=job_id,
  85. f_initiator_role='host', f_work_mode=0
  86. )
  87. model.save(force_insert=True)
  88. self.assertEqual(db_services.models_group_by_party_model_id_and_model_version().count(), 100)
  89. with patch.object(self.service_db, '_insert') as insert:
  90. self.service_db.register_models()
  91. self.assertEqual(insert.call_count, 100)
  92. with patch.object(self.service_db, '_delete') as delete:
  93. self.service_db.unregister_models()
  94. self.assertEqual(delete.call_count, 100)
  95. os.remove(DB.database)
  96. class TestFallbackDB(unittest.TestCase):
  97. def setUp(self):
  98. with patch.object(db_services.ServiceRegistry, 'USE_REGISTRY', False):
  99. self.service_db = db_services.service_db()
  100. def test_get_urls(self):
  101. self.assertEqual(self.service_db._get_urls('fateflow'), ['http://127.0.0.1:9380/v1/model/transfer'])
  102. self.assertEqual(self.service_db._get_urls('servings'), ['http://127.0.0.1:8000'])
  103. def test_crud(self):
  104. self.service_db._insert('fateflow', model_download_url)
  105. self.assertNotIn(model_download_url, self.service_db.get_urls('fateflow'))
  106. self.service_db._delete('fateflow', model_download_url)
  107. self.assertNotIn(model_download_url, self.service_db.get_urls('fateflow'))
  108. def test_get_model_download_url(self):
  109. self.assertEqual(db_services.get_model_download_url('foo-111#bar-222', '20210616'),
  110. 'http://127.0.0.1:9380/v1/model/transfer/foo-111_bar-222/20210616')
  111. def test_not_supported_service(self):
  112. with self.assertRaisesRegex(ServiceNotSupported, 'The service foobar is not supported'):
  113. self.service_db.get_urls('foobar')
  114. if __name__ == '__main__':
  115. unittest.main()