import torch def process_x(raw_x_batch): raw_x_batch = torch.FloatTensor(raw_x_batch) return raw_x_batch.view(-1, 1, 28, 28) def process_y(raw_y_batch): return torch.LongTensor(raw_y_batch)