shakespeare.py 413 B

123456789101112131415
  1. import numpy as np
  2. import torch
  3. from easyfl.datasets.data_process.language_utils import word_to_indices, letter_to_vec
  4. def process_x(raw_x_batch):
  5. x_batch = [word_to_indices(word) for word in raw_x_batch]
  6. x_batch = np.array(x_batch)
  7. return torch.LongTensor(x_batch)
  8. def process_y(raw_y_batch):
  9. y_batch = [np.argmax(letter_to_vec(c)) for c in raw_y_batch]
  10. return torch.LongTensor(y_batch)