nlp_tokenizer.py

from federatedml.nn.dataset.base import Dataset
import pandas as pd
import torch as t
from transformers import BertTokenizerFast
import os
import numpy as np

# avoid tokenizer parallelism
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class TokenizerDataset(Dataset):
    """
    A Dataset for basic NLP tasks. This dataset automatically transforms raw
    text into word indices using BertTokenizerFast from the transformers
    library; see
    https://huggingface.co/docs/transformers/model_doc/bert?highlight=berttokenizer
    for details of BertTokenizer.

    The CSV read by load() must contain a 'text' column, a 'label' column
    when return_label is True, and optionally an 'id' column of sample ids.

    Parameters
    ----------
    truncation : bool
        Truncate word sequences to 'text_max_length'.
    text_max_length : int
        Max length of word sequences.
    tokenizer_name_or_path : str
        Name of a Bert tokenizer (see the transformers documentation for
        available names) or path to a local transformers tokenizer folder.
    return_label : bool
        Return the label or not; this option is for the host dataset when
        running hetero-NN.
    """

    def __init__(self, truncation=True, text_max_length=128,
                 tokenizer_name_or_path="bert-base-uncased",
                 return_label=True):
        super(TokenizerDataset, self).__init__()
        self.text = None
        self.word_idx = None
        self.label = None
        self.tokenizer = None
        self.sample_ids = None
        self.truncation = truncation
        self.max_length = text_max_length
        self.with_label = return_label
        self.tokenizer_name_or_path = tokenizer_name_or_path

    def load(self, file_path):
        tokenizer = BertTokenizerFast.from_pretrained(
            self.tokenizer_name_or_path)
        self.tokenizer = tokenizer
        self.text = pd.read_csv(file_path)
        text_list = list(self.text.text)
        # tokenize the whole corpus in one batch; padding makes every row the
        # same length, so the result is a single LongTensor of input ids
        self.word_idx = tokenizer(
            text_list,
            padding=True,
            return_tensors='pt',
            truncation=self.truncation,
            max_length=self.max_length)['input_ids']
        if self.with_label:
            self.label = t.Tensor(self.text.label).detach().numpy()
            self.label = self.label.reshape((len(self.word_idx), -1))
        # drop the local name; the tokenizer stays reachable via self.tokenizer
        del tokenizer
        if 'id' in self.text:
            self.sample_ids = self.text['id'].values.tolist()

    def get_classes(self):
        return np.unique(self.label).tolist()

    def get_vocab_size(self):
        return self.tokenizer.vocab_size

    def get_sample_ids(self):
        return self.sample_ids

    def __getitem__(self, item):
        if self.with_label:
            return self.word_idx[item], self.label[item]
        else:
            return self.word_idx[item]

    def __len__(self):
        return len(self.word_idx)

    def __repr__(self):
        return self.tokenizer.__repr__()
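

# A minimal usage sketch (an illustration, not part of the original module):
# it assumes a local CSV with 'text', 'label', and 'id' columns, matching what
# load() reads above, and it downloads 'bert-base-uncased' on first use. The
# file name 'demo_text.csv' is hypothetical.
if __name__ == "__main__":
    demo = pd.DataFrame({
        'id': [0, 1],
        'text': ["hello federated world", "tokenizers turn text into ids"],
        'label': [1, 0],
    })
    demo.to_csv("demo_text.csv", index=False)  # hypothetical demo file

    ds = TokenizerDataset(text_max_length=16)
    ds.load("demo_text.csv")
    print(len(ds))              # 2 samples
    print(ds.get_vocab_size())  # 30522 for bert-base-uncased
    ids, label = ds[0]          # (LongTensor of token ids, numpy label row)
    print(ids.shape, label.shape)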