123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- import unittest
- from unittest.mock import patch
- import hashlib
- from pathlib import Path
- from datetime import datetime
- from collections import deque
- from tempfile import TemporaryDirectory
- from ruamel import yaml
- from fate_flow.model import checkpoint
# Raw bytes of the DataIOMeta.pb fixture kept in the sibling ``misc`` directory.
_fixture_path = Path(__file__).parent.parent.joinpath('misc', 'DataIOMeta.pb')
model_string = _fixture_path.read_bytes()

# SHA-1 hex digest of the fixture bytes; embedded in the YAML fixture below.
sha1 = hashlib.sha1(model_string).hexdigest()

buffer_name = 'DataIOMeta'

# Model name -> object parsed from the fixture bytes by the checkpoint module.
model_buffers = {}
model_buffers['my_model'] = checkpoint.parse_proto_object(buffer_name, model_string)

# Per-model entry of the database fixture (key order is significant for the dump).
_model_entry = {
    'filename': 'my_model.pb',
    'sha1': sha1,
    'buffer_name': buffer_name,
}

# Full database fixture for step 123 / 'foobar' with a fixed creation timestamp.
_database_content = {
    'step_index': 123,
    'step_name': 'foobar',
    'create_time': '2021-07-08T07:51:01.963423',
    'models': {'my_model': _model_entry},
}

# YAML text the tests write to, and compare against, ``database.yaml``.
data = yaml.dump(_database_content, Dumper=yaml.RoundTripDumper)
class TestCheckpoint(unittest.TestCase):
    """Tests for ``checkpoint.Checkpoint`` using a fresh temporary directory per test."""

    def setUp(self):
        """Create a checkpoint for step 123 / 'foobar' inside a new temporary directory."""
        self.tmpdir = TemporaryDirectory()
        self.checkpoint = checkpoint.Checkpoint(Path(self.tmpdir.name), 123, 'foobar')
        self.filepath = self.checkpoint.directory / 'my_model.pb'

    def tearDown(self):
        """Remove the temporary directory created in ``setUp``."""
        self.tmpdir.cleanup()

    def test_path(self):
        """The checkpoint directory is '<index>#<name>' and the database lives inside it."""
        directory = Path(self.tmpdir.name) / '123#foobar'
        self.assertEqual(self.checkpoint.directory, directory)
        self.assertEqual(self.checkpoint.database, directory / 'database.yaml')

    def test_save_checkpoint(self):
        """Saving writes the model file and a database matching the YAML fixture."""
        self.assertTrue(self.checkpoint.directory.exists())
        self.assertFalse(self.checkpoint.available)
        self.assertFalse(self.filepath.exists())
        self.assertIsNone(self.checkpoint.create_time)

        self.checkpoint.save(model_buffers)

        self.assertTrue(self.checkpoint.available)
        self.assertTrue(self.filepath.exists())
        self.assertIsNotNone(self.checkpoint.create_time)

        # The fixture carries a fixed timestamp; swap in the real creation time.
        # The count of 1 belongs to str.replace -- it used to be passed to
        # assertEqual as its ``msg`` argument by mistake.
        expected_database = data.replace(
            '2021-07-08T07:51:01.963423', self.checkpoint.create_time.isoformat(), 1)
        self.assertEqual(self.checkpoint.database.read_text('utf8'), expected_database)
        self.assertEqual(self.filepath.read_bytes(), model_string)

    def test_read_checkpoint(self):
        """Reading returns the model buffers and fills in ``create_time`` from the database."""
        self.assertTrue(self.checkpoint.directory.exists())
        self.assertFalse(self.checkpoint.available)
        self.assertFalse(self.filepath.exists())

        # The model file alone does not make the checkpoint available ...
        self.filepath.write_bytes(model_string)
        self.assertFalse(self.checkpoint.available)

        # ... but it is available once the database file exists too.
        self.checkpoint.database.write_text(data, 'utf8')
        self.assertTrue(self.checkpoint.available)
        self.assertIsNone(self.checkpoint.create_time)

        self.assertEqual(self.checkpoint.read(), model_buffers)
        self.assertEqual(self.checkpoint.step_index, 123)
        self.assertEqual(self.checkpoint.step_name, 'foobar')
        self.assertEqual(self.checkpoint.create_time,
                         datetime.fromisoformat('2021-07-08T07:51:01.963423'))

    def test_remove_checkpoint(self):
        """Removing keeps the directory but drops the model file and resets the state."""
        self.checkpoint.save(model_buffers)
        self.checkpoint.database.write_text(data, 'utf8')
        self.checkpoint.remove()

        self.assertTrue(self.checkpoint.directory.exists())
        self.assertFalse(self.filepath.exists())
        self.assertFalse(self.checkpoint.available)
        self.assertIsNone(self.checkpoint.create_time)

    def test_read_checkpoint_step_index_or_step_name_not_match(self):
        """Reading raises ValueError when the stored step index differs from the checkpoint's."""
        self.filepath.write_bytes(model_string)
        self.checkpoint.database.write_text(data.replace('123', '233', 1), 'utf8')

        with self.assertRaisesRegex(
                ValueError, 'Checkpoint may be incorrect: step_index or step_name dose not match.'):
            self.checkpoint.read()

    def test_read_checkpoint_no_pb_file(self):
        """Reading raises FileNotFoundError when the database exists but the model file does not."""
        self.checkpoint.database.write_text(data, 'utf8')

        with self.assertRaisesRegex(
                FileNotFoundError, 'Checkpoint is incorrect: protobuf file not found.'):
            self.checkpoint.read()

    def test_read_checkpoint_hash_not_match(self):
        """Reading raises ValueError when the stored SHA-1 differs from the model file's."""
        self.filepath.write_bytes(model_string)
        self.checkpoint.database.write_text(data.replace(sha1, 'abcdef', 1), 'utf8')

        with self.assertRaisesRegex(ValueError, 'Checkpoint may be incorrect: hash dose not match.'):
            self.checkpoint.read()
class TestCheckpointManager(unittest.TestCase):
    """Tests for ``checkpoint.CheckpointManager`` rooted in a per-test temporary directory."""

    def setUp(self):
        # The project base directory is patched only while the manager is constructed.
        self.tmpdir = TemporaryDirectory()
        base_directory_patch = patch(
            'fate_flow.model.checkpoint.get_project_base_directory',
            return_value=self.tmpdir.name,
        )
        with base_directory_patch:
            self.checkpoint_manager = checkpoint.CheckpointManager(
                'job_id', 'role', 1000, 'model_id', 'model_version')

    def tearDown(self):
        self.tmpdir.cleanup()

    def _write_checkpoints_to_disk(self, step_indexes):
        # Create one '<step>#foobar<step>' directory per step, holding the model
        # fixture and a database derived from the YAML fixture.
        for step in step_indexes:
            target = self.checkpoint_manager.directory / f'{step}#foobar{step}'
            target.mkdir(0o755)

            target.joinpath('my_model.pb').write_bytes(model_string)
            database_text = data.replace('123', str(step), 1).replace('foobar', f'foobar{step}', 1)
            target.joinpath('database.yaml').write_text(database_text, 'utf8')

    def test_directory(self):
        expected = Path(self.tmpdir.name).joinpath(
            'model_local_cache', 'role#1000#model_id', 'model_version', 'checkpoint', 'pipeline')
        self.assertEqual(self.checkpoint_manager.directory, expected)

    def test_load_checkpoints_from_disk(self):
        # Fifty consecutive checkpoints on disk; the last one must be reported as latest.
        self._write_checkpoints_to_disk(range(1, 51))

        manager = self.checkpoint_manager
        manager.load_checkpoints_from_disk()

        self.assertEqual(manager.checkpoints_number, 50)
        self.assertEqual(manager.latest_step_index, 50)
        self.assertEqual(manager.latest_step_name, 'foobar50')
        self.assertEqual(manager.latest_checkpoint.read(), model_buffers)

    def test_checkpoint_index(self):
        # Only odd step indexes exist on disk.
        odd_steps = list(range(1, 101, 2))
        self._write_checkpoints_to_disk(odd_steps)

        manager = self.checkpoint_manager
        manager.load_checkpoints_from_disk()

        self.assertEqual(list(manager.number_indexed_checkpoints.keys()), odd_steps)
        self.assertEqual(list(manager.name_indexed_checkpoints.keys()),
                         [f'foobar{step}' for step in odd_steps])

        for step in odd_steps:
            step_name = f'foobar{step}'
            found = manager.get_checkpoint_by_index(step)
            # Both lookups must return the very same object.
            self.assertIs(manager.get_checkpoint_by_name(step_name), found)

            self.assertEqual(found.step_index, step)
            self.assertEqual(found.step_name, step_name)
            self.assertIsNone(found.create_time)

            # create_time is only populated once the checkpoint has been read.
            found.read()

            self.assertEqual(found.step_index, step)
            self.assertEqual(found.step_name, step_name)
            self.assertEqual(found.create_time.isoformat(), '2021-07-08T07:51:01.963423')

    def test_new_checkpoint(self):
        manager = self.checkpoint_manager
        # Cap the number of retained checkpoints at ten.
        manager.checkpoints = deque(maxlen=10)

        for step in range(1, 31):
            created = manager.new_checkpoint(step, f'foobar{step}')
            created.save(model_buffers)
            self.assertEqual(manager.latest_step_index, step)
            self.assertEqual(manager.latest_step_name, f'foobar{step}')

        self.assertEqual(manager.checkpoints_number, 10)
        # Ten copies of each checkpoint file are expected to remain ...
        for filename in ('my_model.pb', 'database.yaml', '.lock'):
            remaining = list(manager.directory.rglob(filename))
            self.assertEqual(len(remaining), 10)
        # ... while all thirty step directories are still present.
        self.assertEqual(len(list(manager.directory.glob('*'))), 30)

    def test_clean(self):
        manager = self.checkpoint_manager

        for step in range(10):
            manager.new_checkpoint(step, f'foobar{step}').save(model_buffers)

        self.assertEqual(manager.checkpoints_number, 10)
        self.assertEqual(len(list(manager.directory.glob('*'))), 10)

        manager.clean()

        # The manager directory itself is kept, but it is left empty.
        self.assertEqual(manager.checkpoints_number, 0)
        self.assertTrue(manager.directory.exists())
        self.assertEqual(len(list(manager.directory.glob('*'))), 0)
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|