from __future__ import absolute_import

import torch


def to_numpy(tensor):
    """Convert a torch tensor to a numpy array.

    Objects whose type is defined in the ``numpy`` module are returned
    unchanged (the very same object, no copy).

    Args:
        tensor: A torch tensor, or an object whose type lives in the
            ``numpy`` module (e.g. ``numpy.ndarray``).

    Returns:
        A numpy array for tensor input; otherwise ``tensor`` itself.
        For a tensor that already lives on the CPU the resulting array
        shares memory with the tensor (see ``torch.Tensor.numpy``).

    Raises:
        ValueError: If ``tensor`` is neither a torch tensor nor a numpy
            object.
    """
    if torch.is_tensor(tensor):
        # detach() first: Tensor.numpy() refuses tensors that require grad,
        # and detaching before the device copy keeps that copy out of the
        # autograd graph.
        return tensor.detach().cpu().numpy()
    # Module-name check instead of isinstance so this file does not need
    # to import numpy itself.
    if type(tensor).__module__ != 'numpy':
        raise ValueError("Cannot convert {} to numpy array"
                         .format(type(tensor)))
    return tensor


def to_torch(ndarray):
    """Convert a numpy array to a torch tensor.

    Torch tensors are returned unchanged (the very same object, no copy).

    Args:
        ndarray: An object whose type lives in the ``numpy`` module
            (``torch.from_numpy`` expects a ``numpy.ndarray``), or a
            torch tensor.

    Returns:
        A torch tensor. For numpy input it is created with
        ``torch.from_numpy``, so it shares memory with the source array.

    Raises:
        ValueError: If ``ndarray`` is neither a numpy object nor a torch
            tensor.
    """
    # Module-name check instead of isinstance so this file does not need
    # to import numpy itself.
    if type(ndarray).__module__ == 'numpy':
        return torch.from_numpy(ndarray)
    if not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor"
                         .format(type(ndarray)))
    return ndarray