femnist.py 206 B

12345678910
  1. import torch
  2. def process_x(raw_x_batch):
  3. raw_x_batch = torch.FloatTensor(raw_x_batch)
  4. return raw_x_batch.view(-1, 1, 28, 28)
  5. def process_y(raw_y_batch):
  6. return torch.LongTensor(raw_y_batch)