rnn.py

import torch
import torch.nn as nn

from easyfl.models.model import BaseModel


def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)


class Model(BaseModel):
    def __init__(self, embedding_dim=8, voc_size=80, lstm_unit=256, batch_first=True, n_layers=2):
        super(Model, self).__init__()
        self.encoder = nn.Embedding(voc_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, lstm_unit, n_layers, batch_first=batch_first)
        self.decoder = nn.Linear(lstm_unit, voc_size)
        self.init_weights()

    def forward(self, inp):
        inp = self.encoder(inp)
        inp, _ = self.lstm(inp)
        # extract the last state of output for prediction
        hidden = inp[:, -1]
        output = self.decoder(hidden)
        return output

    def init_weights(self):
        init_range = 0.1
        self.encoder.weight.data.uniform_(-init_range, init_range)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-init_range, init_range)
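

# The smoke test below is an illustrative sketch, not part of the original file:
# it pushes a dummy batch of random token ids through the model, assuming the
# default 80-token vocabulary and batch_first inputs of shape (batch, seq_len).
if __name__ == "__main__":
    model = Model()
    tokens = torch.randint(0, 80, (4, 10))  # 4 sequences of 10 token ids
    logits = model(tokens)                  # (4, 80): one prediction per sequence
    print(logits.shape)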