operation.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import torch as t
  2. import copy
  3. from torch.nn import Module
  4. class OpBase(object):
  5. def __init__(self):
  6. self.param_dict = {}
  7. def to_dict(self):
  8. ret = copy.deepcopy(self.param_dict)
  9. ret['op'] = type(self).__name__
  10. return ret
  11. class Astype(Module, OpBase):
  12. def __init__(self, cast_type: str):
  13. OpBase.__init__(self)
  14. Module.__init__(self)
  15. assert cast_type in [
  16. 'float',
  17. 'int',
  18. 'bool',
  19. 'float32',
  20. 'float64',
  21. 'int8',
  22. 'int16',
  23. 'int32',
  24. 'int64',
  25. 'float16']
  26. self.param_dict['cast_type'] = cast_type
  27. self.cast_type = cast_type
  28. self.cast_type_map = {
  29. 'float': t.float,
  30. 'int': t.int,
  31. 'bool': t.bool,
  32. 'float32': t.float32,
  33. 'float64': t.float64,
  34. 'float16': t.float16,
  35. 'int8': t.int8,
  36. 'int16': t.int16,
  37. 'int32': t.int32,
  38. 'int64': t.int64,
  39. }
  40. def forward(self, tensor: t.Tensor, **kwargs):
  41. return tensor.type(self.cast_type_map[self.cast_type])
  42. class Flatten(Module, OpBase):
  43. def __init__(self, start_dim=0, end_dim=-1):
  44. OpBase.__init__(self)
  45. Module.__init__(self)
  46. self.param_dict['start_dim'] = start_dim
  47. self.param_dict['end_dim'] = end_dim
  48. def forward(self, tensor):
  49. return tensor.flatten(**self.param_dict)
  50. class Reshape(Module, OpBase):
  51. def __init__(self, shape):
  52. OpBase.__init__(self)
  53. Module.__init__(self)
  54. assert isinstance(shape, tuple) or isinstance(shape, list)
  55. self.shape = shape
  56. self.param_dict['shape'] = list(shape)
  57. def forward(self, tensor: t.Tensor):
  58. return tensor.reshape(shape=self.shape)
  59. class Index(Module, OpBase):
  60. def __init__(self, index):
  61. OpBase.__init__(self)
  62. Module.__init__(self)
  63. assert isinstance(index, int)
  64. self.param_dict['index'] = index
  65. def forward(self, content):
  66. return content[self.param_dict['index']]
  67. class Select(Module, OpBase):
  68. def __init__(self, dim, idx):
  69. OpBase.__init__(self)
  70. Module.__init__(self)
  71. self.param_dict = {'dim': dim, 'index': idx}
  72. def forward(self, tensor):
  73. return tensor.select(self.param_dict['dim'], self.param_dict['index'])
  74. class SelectRange(Module, OpBase):
  75. def __init__(self, dim, start, end):
  76. OpBase.__init__(self)
  77. Module.__init__(self)
  78. self.param_dict = {'dim': dim, 'start': start, 'end': end}
  79. def forward(self, tensor):
  80. return tensor.select(
  81. self.param_dict['dim'], -1)[self.param_dict['start']: self.param_dict['end']]
  82. class Sum(Module, OpBase):
  83. def __init__(self, dim):
  84. OpBase.__init__(self)
  85. Module.__init__(self)
  86. assert isinstance(dim, int)
  87. self.param_dict['dim'] = dim
  88. def forward(self, tensor):
  89. return tensor.sum(dim=self.param_dict['dim'])
  90. class Squeeze(Module, OpBase):
  91. def __init__(self, **kwargs):
  92. OpBase.__init__(self)
  93. Module.__init__(self)
  94. def forward(self, tensor: t.Tensor):
  95. return tensor.squeeze()
  96. class Unsqueeze(Sum, OpBase):
  97. def __init__(self, dim):
  98. super(Unsqueeze, self).__init__(dim)
  99. def forward(self, tensor: t.Tensor):
  100. return tensor.unsqueeze(self.param_dict['dim'])