123456789101112131415161718192021 |
- from __future__ import absolute_import
- import torch
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array; pass NumPy objects through.

    Args:
        tensor: a ``torch.Tensor`` or any object whose type is defined in the
            ``numpy`` module (e.g. ``ndarray`` or a NumPy scalar).

    Returns:
        A NumPy array for tensor input (sharing memory with the tensor when it
        already lives on the CPU and does not require grad), or the input
        object itself when it is already a NumPy object.

    Raises:
        ValueError: if ``tensor`` is neither a torch tensor nor a NumPy object.
    """
    if torch.is_tensor(tensor):
        # detach() first: .numpy() refuses tensors that require grad, and
        # detaching before .cpu() keeps the device copy out of the autograd graph.
        return tensor.detach().cpu().numpy()
    elif type(tensor).__module__ != 'numpy':
        raise ValueError("Cannot convert {} to numpy array"
                         .format(type(tensor)))
    return tensor
def to_torch(ndarray):
    """Convert a NumPy array or scalar to a torch tensor; pass tensors through.

    Args:
        ndarray: a NumPy object (``ndarray`` or NumPy scalar such as
            ``np.float64``) or a ``torch.Tensor``.

    Returns:
        A ``torch.Tensor``. For an ndarray with non-negative strides the
        tensor shares memory with the array (``torch.from_numpy`` semantics);
        arrays with negative strides and NumPy scalars are copied first,
        because ``torch.from_numpy`` cannot represent them directly. Tensor
        input is returned unchanged.

    Raises:
        ValueError: if ``ndarray`` is neither a NumPy object nor a tensor.
    """
    if type(ndarray).__module__ == 'numpy':
        # Local import: reaching this branch means numpy is installed.
        import numpy as np
        # asarray is a no-op on ndarrays (memory sharing is kept) and turns
        # NumPy scalars into 0-d arrays, which torch.from_numpy requires.
        array = np.asarray(ndarray)
        if any(stride < 0 for stride in array.strides):
            # torch tensors cannot have negative strides (e.g. arr[::-1]).
            array = array.copy()
        return torch.from_numpy(array)
    elif not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor"
                         .format(type(ndarray)))
    return ndarray
|