lora-checkpoint.py 2.8 KB

import math

import torch


class LoRAModule(torch.nn.Module):
    """
    Replaces the forward method of the original Linear, instead of replacing the
    original Linear module itself.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=64,
        alpha=32,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
    ):
        """If alpha is 0 or None, alpha is set to the rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name

        in_dim = org_module.in_features
        out_dim = org_module.out_features

        self.lora_dim = lora_dim
        self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
        self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes an error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # same initialization as Microsoft's LoRA implementation
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # removed in apply_to()
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        # capture the original forward, then reroute the wrapped module through this LoRA layer
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        org_forwarded = self.org_forward(x)

        # module dropout: with probability module_dropout, skip the LoRA branch entirely
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        lx = self.lora_down(x)

        # normal dropout
        if self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)

        # rank dropout: randomly zero out individual rank channels
        if self.rank_dropout is not None and self.training:
            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
            if len(lx.size()) == 3:
                mask = mask.unsqueeze(1)  # for Text Encoder
            elif len(lx.size()) == 4:
                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
            lx = lx * mask

            # scaling for rank dropout: treat as if the rank is changed
            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
        else:
            scale = self.scale

        lx = self.lora_up(lx)

        return org_forwarded + lx * self.multiplier * scale
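

# --- Usage sketch (illustrative addition, not part of the original script) ---
# A minimal example of wrapping a single torch.nn.Linear with LoRAModule and
# rerouting its forward pass via apply_to(). The layer sizes, module name, and
# hyperparameters below are arbitrary placeholders chosen for demonstration.
if __name__ == "__main__":
    base = torch.nn.Linear(128, 256)
    lora = LoRAModule("lora_demo_linear", base, multiplier=1.0, lora_dim=8, alpha=4)
    lora.apply_to()  # base.forward now points at LoRAModule.forward

    x = torch.randn(2, 128)
    y = base(x)  # original output + scaled low-rank update (zero at init, since lora_up starts at 0)
    print(y.shape)  # torch.Size([2, 256])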