pretrained_bert.py

from transformers.models.bert import BertModel
from torch.nn import Module

from federatedml.util import LOGGER


class PretrainedBert(Module):

    def __init__(self, pretrained_model_name_or_path: str = 'bert-base-uncased', freeze_weight=False):
        """
        A pretrained Bert model based on transformers.

        Parameters
        ----------
        pretrained_model_name_or_path: string, specifies the version of the pretrained bert model;
            for all available bert models, see:
            https://huggingface.co/bert-base-uncased?text=The+goal+of+life+is+%5BMASK%5D.#model-variations
            It can also be a path to a downloaded bert model.
        freeze_weight: bool, whether to freeze the weights during training. If True, the bert
            model's parameters are excluded from parameters() and gradient computation is skipped.
        """
        super(PretrainedBert, self).__init__()
        self.pretrained_model_str = pretrained_model_name_or_path
        self.freeze_weight = freeze_weight
        LOGGER.info(
            'if you are using non-local models, the pretrained model will be downloaded, '
            'which may take some time')
        self.model = BertModel.from_pretrained(
            pretrained_model_name_or_path=self.pretrained_model_str)
        if self.freeze_weight:
            # disable gradient tracking for every parameter of the bert model
            self.model.requires_grad_(False)

    def forward(self, x):
        return self.model(x)

    def parameters(self, recurse: bool = True):
        # when weights are frozen, expose no parameters to the optimizer
        if self.freeze_weight:
            return iter(())
        else:
            return self.model.parameters(recurse=recurse)
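

# --- Usage sketch (not part of the original file; a minimal, assumed example) ---
# Shows how PretrainedBert might be driven end to end: tokenize a sentence with
# the matching Hugging Face tokenizer, run a forward pass, and observe that
# freeze_weight=True leaves parameters() empty. `BertTokenizer` and the sample
# sentence are illustrative assumptions, not part of the federatedml API.
if __name__ == '__main__':
    import torch
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = PretrainedBert(freeze_weight=True)

    # encode one sentence into token ids; forward() hands them to BertModel
    input_ids = tokenizer('The goal of life is happiness.', return_tensors='pt')['input_ids']
    with torch.no_grad():
        output = model(input_ids)
    print(output.last_hidden_state.shape)  # (1, seq_len, 768) for bert-base-uncased

    # with frozen weights, an optimizer built from parameters() gets nothing to update
    print(list(model.parameters()))  # []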