checkpoint_test.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import unittest
  2. from unittest.mock import patch
  3. import hashlib
  4. from pathlib import Path
  5. from datetime import datetime
  6. from collections import deque
  7. from tempfile import TemporaryDirectory
  8. from ruamel import yaml
  9. from fate_flow.model import checkpoint
  # Serialized protobuf fixture shared by every test in this module, together
  # with the metadata a checkpoint database is expected to record for it.
  _fixture_path = Path(__file__).parent.parent / 'misc' / 'DataIOMeta.pb'
  model_string = _fixture_path.read_bytes()
  sha1 = hashlib.sha1(model_string).hexdigest()
  buffer_name = 'DataIOMeta'

  # Parsed form of the fixture (via checkpoint.parse_proto_object), keyed by
  # model name; passed to Checkpoint.save and compared with Checkpoint.read().
  model_buffers = {'my_model': checkpoint.parse_proto_object(buffer_name, model_string)}

  # Expected text of database.yaml for checkpoint step 123 / "foobar".
  # Key order matters: tests compare this text verbatim and use
  # str.replace(..., 1), relying on step_index / step_name being the first
  # occurrences of '123' / 'foobar'.
  _model_entry = {
      'filename': 'my_model.pb',
      'sha1': sha1,
      'buffer_name': buffer_name,
  }
  _database_content = {
      'step_index': 123,
      'step_name': 'foobar',
      'create_time': '2021-07-08T07:51:01.963423',
      'models': {'my_model': _model_entry},
  }
  data = yaml.dump(_database_content, Dumper=yaml.RoundTripDumper)
  class TestCheckpoint(unittest.TestCase):
      """Tests for checkpoint.Checkpoint: paths, save/read/remove, and the
      integrity checks that Checkpoint.read performs on disk contents."""

      def setUp(self):
          self.tmpdir = TemporaryDirectory()
          # Register cleanup right away: tearDown is not run when setUp raises,
          # so a failing Checkpoint(...) below would otherwise leak the directory.
          self.addCleanup(self.tmpdir.cleanup)
          self.checkpoint = checkpoint.Checkpoint(Path(self.tmpdir.name), 123, 'foobar')
          self.filepath = self.checkpoint.directory / 'my_model.pb'

      def test_path(self):
          directory = Path(self.tmpdir.name) / '123#foobar'
          self.assertEqual(self.checkpoint.directory, directory)
          self.assertEqual(self.checkpoint.database, directory / 'database.yaml')

      def test_save_checkpoint(self):
          self.assertTrue(self.checkpoint.directory.exists())
          self.assertFalse(self.checkpoint.available)
          self.assertFalse(self.filepath.exists())
          self.assertIsNone(self.checkpoint.create_time)

          self.checkpoint.save(model_buffers)

          self.assertTrue(self.checkpoint.available)
          self.assertTrue(self.filepath.exists())
          self.assertIsNotNone(self.checkpoint.create_time)
          # The database must equal the fixture except for create_time, which is
          # assigned at save time. The trailing 1 is str.replace's count argument
          # (replace only the single create_time value), not assertEqual's msg.
          expected = data.replace('2021-07-08T07:51:01.963423',
                                  self.checkpoint.create_time.isoformat(), 1)
          self.assertEqual(self.checkpoint.database.read_text('utf8'), expected)
          self.assertEqual(self.filepath.read_bytes(), model_string)

      def test_read_checkpoint(self):
          self.assertTrue(self.checkpoint.directory.exists())
          self.assertFalse(self.checkpoint.available)
          self.assertFalse(self.filepath.exists())
          # A checkpoint needs both the model file and the database to be available.
          self.filepath.write_bytes(model_string)
          self.assertFalse(self.checkpoint.available)
          self.checkpoint.database.write_text(data, 'utf8')
          self.assertTrue(self.checkpoint.available)
          # create_time is only populated once read() has parsed the database.
          self.assertIsNone(self.checkpoint.create_time)
          self.assertEqual(self.checkpoint.read(), model_buffers)
          self.assertEqual(self.checkpoint.step_index, 123)
          self.assertEqual(self.checkpoint.step_name, 'foobar')
          self.assertEqual(self.checkpoint.create_time,
                           datetime.fromisoformat('2021-07-08T07:51:01.963423'))

      def test_remove_checkpoint(self):
          self.checkpoint.save(model_buffers)
          self.checkpoint.database.write_text(data, 'utf8')
          self.checkpoint.remove()
          # remove() deletes the contents but keeps the checkpoint directory.
          self.assertTrue(self.checkpoint.directory.exists())
          self.assertFalse(self.filepath.exists())
          self.assertFalse(self.checkpoint.available)
          self.assertIsNone(self.checkpoint.create_time)

      # NOTE(review): the expected messages below ("dose not match") presumably
      # mirror the implementation's wording; fix the typo in both places together.

      def test_read_checkpoint_step_index_or_step_name_not_match(self):
          self.filepath.write_bytes(model_string)
          # The first '123' in `data` is the step_index value (step_index is the first key).
          self.checkpoint.database.write_text(data.replace('123', '233', 1), 'utf8')
          with self.assertRaisesRegex(ValueError, 'Checkpoint may be incorrect: step_index or step_name dose not match.'):
              self.checkpoint.read()

      def test_read_checkpoint_no_pb_file(self):
          self.checkpoint.database.write_text(data, 'utf8')
          with self.assertRaisesRegex(FileNotFoundError, 'Checkpoint is incorrect: protobuf file not found.'):
              self.checkpoint.read()

      def test_read_checkpoint_hash_not_match(self):
          self.filepath.write_bytes(model_string)
          self.checkpoint.database.write_text(data.replace(sha1, 'abcdef', 1), 'utf8')
          with self.assertRaisesRegex(ValueError, 'Checkpoint may be incorrect: hash dose not match.'):
              self.checkpoint.read()
  class TestCheckpointManager(unittest.TestCase):
      """Tests for checkpoint.CheckpointManager: directory layout, loading
      checkpoints from disk, index/name lookup, creation with eviction, and clean."""

      def setUp(self):
          self.tmpdir = TemporaryDirectory()
          # Register cleanup right away: tearDown is not run when setUp raises.
          self.addCleanup(self.tmpdir.cleanup)
          # Point the project base directory at the temporary directory so the
          # manager's checkpoint tree lives there (see test_directory).
          with patch('fate_flow.model.checkpoint.get_project_base_directory',
                     return_value=self.tmpdir.name):
              self.checkpoint_manager = checkpoint.CheckpointManager(
                  'job_id', 'role', 1000, 'model_id', 'model_version')

      def _write_checkpoint(self, x):
          """Write checkpoint ``{x}#foobar{x}`` to disk by hand (model file + database)."""
          directory = self.checkpoint_manager.directory / f'{x}#foobar{x}'
          directory.mkdir(0o755)
          (directory / 'my_model.pb').write_bytes(model_string)
          # The first '123' / 'foobar' in `data` are the step_index / step_name values.
          (directory / 'database.yaml').write_text(
              data.replace('123', str(x), 1).replace('foobar', f'foobar{x}', 1), 'utf8')

      def test_directory(self):
          self.assertEqual(self.checkpoint_manager.directory,
                           Path(self.tmpdir.name) / 'model_local_cache' /
                           'role#1000#model_id' / 'model_version' / 'checkpoint' / 'pipeline')

      def test_load_checkpoints_from_disk(self):
          for x in range(1, 51):
              self._write_checkpoint(x)
          self.checkpoint_manager.load_checkpoints_from_disk()
          self.assertEqual(self.checkpoint_manager.checkpoints_number, 50)
          self.assertEqual(self.checkpoint_manager.latest_step_index, 50)
          self.assertEqual(self.checkpoint_manager.latest_step_name, 'foobar50')
          self.assertEqual(self.checkpoint_manager.latest_checkpoint.read(), model_buffers)

      def test_checkpoint_index(self):
          # Use non-contiguous step indices to make sure lookups are by value.
          for x in range(1, 101, 2):
              self._write_checkpoint(x)
          self.checkpoint_manager.load_checkpoints_from_disk()
          self.assertEqual(list(self.checkpoint_manager.number_indexed_checkpoints.keys()),
                           list(range(1, 101, 2)))
          self.assertEqual(list(self.checkpoint_manager.name_indexed_checkpoints.keys()),
                           [f'foobar{x}' for x in range(1, 101, 2)])
          for x in range(1, 101, 2):
              _checkpoint = self.checkpoint_manager.get_checkpoint_by_index(x)
              # Both indexes must resolve to the very same Checkpoint object.
              self.assertIs(self.checkpoint_manager.get_checkpoint_by_name(f'foobar{x}'), _checkpoint)
              self.assertEqual(_checkpoint.step_index, x)
              self.assertEqual(_checkpoint.step_name, f'foobar{x}')
              # create_time is unknown until the database has been read.
              self.assertIsNone(_checkpoint.create_time)
              self.assertEqual(_checkpoint.read(), model_buffers)
              self.assertEqual(_checkpoint.step_index, x)
              self.assertEqual(_checkpoint.step_name, f'foobar{x}')
              self.assertEqual(_checkpoint.create_time.isoformat(), '2021-07-08T07:51:01.963423')

      def test_new_checkpoint(self):
          # Cap retention at 10 checkpoints so older ones are evicted.
          self.checkpoint_manager.checkpoints = deque(maxlen=10)
          for x in range(1, 31):
              _checkpoint = self.checkpoint_manager.new_checkpoint(x, f'foobar{x}')
              _checkpoint.save(model_buffers)
              self.assertEqual(self.checkpoint_manager.latest_step_index, x)
              self.assertEqual(self.checkpoint_manager.latest_step_name, f'foobar{x}')
          self.assertEqual(self.checkpoint_manager.checkpoints_number, 10)
          # Only the 10 retained checkpoints keep their files; the directories of
          # all 30 checkpoints remain on disk.
          self.assertEqual(len(list(self.checkpoint_manager.directory.rglob('my_model.pb'))), 10)
          self.assertEqual(len(list(self.checkpoint_manager.directory.rglob('database.yaml'))), 10)
          self.assertEqual(len(list(self.checkpoint_manager.directory.rglob('.lock'))), 10)
          self.assertEqual(len(list(self.checkpoint_manager.directory.glob('*'))), 30)

      def test_clean(self):
          for x in range(10):
              _checkpoint = self.checkpoint_manager.new_checkpoint(x, f'foobar{x}')
              _checkpoint.save(model_buffers)
          self.assertEqual(self.checkpoint_manager.checkpoints_number, 10)
          self.assertEqual(len(list(self.checkpoint_manager.directory.glob('*'))), 10)
          self.checkpoint_manager.clean()
          # clean() empties the checkpoint directory but keeps the directory itself.
          self.assertEqual(self.checkpoint_manager.checkpoints_number, 0)
          self.assertTrue(self.checkpoint_manager.directory.exists())
          self.assertEqual(len(list(self.checkpoint_manager.directory.glob('*'))), 0)
  if __name__ == '__main__':
      # Run this module's tests when executed directly as a script.
      unittest.main(module='__main__')