serialization.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from __future__ import print_function, absolute_import
  2. import json
  3. import os.path as osp
  4. import shutil
  5. import torch
  6. from torch.nn import Parameter
  7. from .osutils import mkdir_if_missing
  8. def read_json(fpath):
  9. with open(fpath, 'r') as f:
  10. obj = json.load(f)
  11. return obj
  12. def write_json(obj, fpath):
  13. mkdir_if_missing(osp.dirname(fpath))
  14. with open(fpath, 'w') as f:
  15. json.dump(obj, f, indent=4, separators=(',', ': '))
  16. def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
  17. mkdir_if_missing(osp.dirname(fpath))
  18. torch.save(state, fpath)
  19. if is_best:
  20. shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))
  21. def load_checkpoint(fpath):
  22. if osp.isfile(fpath):
  23. checkpoint = torch.load(fpath)
  24. print("=> Loaded checkpoint '{}'".format(fpath))
  25. return checkpoint
  26. else:
  27. raise ValueError("=> No checkpoint found at '{}'".format(fpath))
  28. def copy_state_dict(state_dict, model, strip=None):
  29. tgt_state = model.state_dict()
  30. copied_names = set()
  31. for name, param in state_dict.items():
  32. if strip is not None and name.startswith(strip):
  33. name = name[len(strip):]
  34. if name not in tgt_state:
  35. continue
  36. if isinstance(param, Parameter):
  37. param = param.data
  38. if param.size() != tgt_state[name].size():
  39. print('mismatch:', name, param.size(), tgt_state[name].size())
  40. continue
  41. tgt_state[name].copy_(param)
  42. copied_names.add(name)
  43. missing = set(tgt_state.keys()) - copied_names
  44. if len(missing) > 0:
  45. print("missing keys in state_dict:", missing)
  46. return model