__init__.py 594 B

123456789101112131415161718192021
  1. from __future__ import absolute_import
  2. import torch
  3. def to_numpy(tensor):
  4. if torch.is_tensor(tensor):
  5. return tensor.cpu().numpy()
  6. elif type(tensor).__module__ != 'numpy':
  7. raise ValueError("Cannot convert {} to numpy array"
  8. .format(type(tensor)))
  9. return tensor
  10. def to_torch(ndarray):
  11. if type(ndarray).__module__ == 'numpy':
  12. return torch.from_numpy(ndarray)
  13. elif not torch.is_tensor(ndarray):
  14. raise ValueError("Cannot convert {} to torch tensor"
  15. .format(type(ndarray)))
  16. return ndarray