from transformers.models.bert import BertModel
from torch.nn import Module
from federatedml.util import LOGGER


class PretrainedBert(Module):

    def __init__(self, pretrained_model_name_or_path: str = 'bert-base-uncased', freeze_weight=False):
        """
        A pretrained Bert model based on transformers.

        Parameters
        ----------
        pretrained_model_name_or_path: str, the version of the pretrained Bert model;
            for all available Bert variants, see:
            https://huggingface.co/bert-base-uncased?text=The+goal+of+life+is+%5BMASK%5D.#model-variations
            It can also be a path to a locally downloaded Bert model.
        freeze_weight: bool, whether to freeze the weights during training. If True,
            the Bert weights are excluded from parameters() and skipped in
            gradient computation.
        """
        super(PretrainedBert, self).__init__()
        self.pretrained_model_str = pretrained_model_name_or_path
        self.freeze_weight = freeze_weight
        LOGGER.info(
            'if you are using a non-local model, the pretrained weights will be '
            'downloaded first, which may take some time')
        self.model = BertModel.from_pretrained(
            pretrained_model_name_or_path=self.pretrained_model_str)
        if self.freeze_weight:
            # disable grad tracking so the frozen weights are skipped in backward
            self.model.requires_grad_(False)

    def forward(self, x):
        return self.model(x)

    def parameters(self, recurse: bool = True):
        # when frozen, expose no parameters so optimizers ignore the Bert weights
        if self.freeze_weight:
            return ()
        else:
            return self.model.parameters(recurse=recurse)
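
# A minimal usage sketch (illustrative, not part of the module): with
# freeze_weight=True the Bert encoder is loaded but exposes no trainable
# parameters, so an optimizer built from parameters() would skip it entirely.
# Assumes the transformers tokenizer for 'bert-base-uncased' is available
# (downloaded on first use).
if __name__ == '__main__':
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = PretrainedBert(freeze_weight=True)
    encoded = tokenizer('federated learning with a frozen encoder', return_tensors='pt')
    output = model(encoded['input_ids'])
    # standard transformers BertModel output; hidden size is 768 for bert-base
    print(output.last_hidden_state.shape)  # torch.Size([1, seq_len, 768])
    print(len(list(model.parameters())))   # 0, nothing left to optimize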