vision.py 1.0 KB

1234567891011121314151617181920212223242526
  1. import torch as t
  2. from torchvision.models import get_model
  class TorchVisionModels(t.nn.Module):
      """
      Build any ``torchvision.models`` model by name, optionally with pretrained weights.

      Thin ``torch.nn.Module`` wrapper around ``torchvision.models.get_model``:
      both the architecture and the pretrained weights are selected with plain
      strings, so the model can be chosen from configuration.

      Parameters
      ----------
      vision_model_name : str
          Name of a model provided by ``torchvision.models`` (for example
          ``"resnet50"``). For all available classification models, see:
          https://pytorch.org/vision/stable/models.html#table-of-all-available-classification-weights
      pretrain_weights : str or None, optional
          Name of the pretrained weights to load (for example ``"DEFAULT"``).
          ``None`` (the default) is forwarded as ``weights=None``. For the
          available weights, see:
          https://pytorch.org/vision/stable/models.html#table-of-all-available-classification-weights
      **model_kwargs
          Extra keyword arguments forwarded unchanged to ``get_model`` and thus
          to the model builder (for example ``num_classes``).

      Notes
      -----
      The wrapped network is stored as the submodule ``self.model``, so its
      parameters appear in ``state_dict()`` under a ``model.`` prefix.
      """

      def __init__(
          self,
          vision_model_name: str,
          pretrain_weights: "str | None" = None,
          **model_kwargs,
      ):
          super().__init__()
          # get_model resolves the name in torchvision's model registry; extra
          # builder options (e.g. num_classes) are passed straight through.
          self.model = get_model(
              vision_model_name, weights=pretrain_weights, **model_kwargs
          )

      def forward(self, x, *args, **kwargs):
          """Run the wrapped model on ``x``; any extra arguments are passed through unchanged."""
          return self.model(x, *args, **kwargs)

      def __repr__(self):
          """Return the wrapped model's repr so the wrapper is transparent when printed."""
          return repr(self.model)