import math

import torch


class LoRAModule(torch.nn.Module):
    """
    Replaces the forward method of the original Linear instead of replacing the
    original Linear module itself.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=64,
        alpha=32,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
    ):
        """If alpha is 0 or None, alpha defaults to the rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name

        in_dim = org_module.in_features
        out_dim = org_module.out_features

        self.lora_dim = lora_dim
        self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
        self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes an error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant (kept in the state dict)

        # same initialization as Microsoft's LoRA reference implementation
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # removed in apply_to()
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        org_forwarded = self.org_forward(x)

        # module dropout: randomly skip the whole LoRA branch during training
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        lx = self.lora_down(x)

        # normal dropout
        if self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)

        # rank dropout: zero out random rank dimensions per sample
        if self.rank_dropout is not None and self.training:
            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
            if len(lx.size()) == 3:
                mask = mask.unsqueeze(1)  # for Text Encoder
            elif len(lx.size()) == 4:
                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
            lx = lx * mask

            # scaling for rank dropout: treat as if the rank is reduced
            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # kept explicit for readability
        else:
            scale = self.scale

        lx = self.lora_up(lx)

        return org_forwarded + lx * self.multiplier * scale
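
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original listing; layer sizes and
# hyperparameters are illustrative): wrap a standalone Linear, monkey-patch its
# forward via apply_to(), and run an input through the patched layer.
# ---------------------------------------------------------------------------
linear = torch.nn.Linear(320, 320)
lora = LoRAModule("lora_demo", linear, multiplier=1.0, lora_dim=16, alpha=8)
lora.apply_to()  # linear.forward now routes through LoRAModule.forward

x = torch.randn(2, 77, 320)
y = linear(x)  # org_forward(x) + lora_up(lora_down(x)) * multiplier * scale
print(y.shape)  # torch.Size([2, 77, 320]); equals the plain Linear output while lora_up is still zero-initialized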
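
# ---------------------------------------------------------------------------
# Sketch of the underlying math (again illustrative, with dropout disabled):
# the LoRA branch is equivalent to a low-rank update of the base weight,
#     W_eff = W + multiplier * (alpha / lora_dim) * (lora_up.weight @ lora_down.weight)
# which is why a trained LoRA can be merged into the base weights.
# ---------------------------------------------------------------------------
torch.nn.init.normal_(lora.lora_up.weight, std=0.02)  # pretend training moved lora_up off its zero init
with torch.no_grad():
    delta_w = lora.lora_up.weight @ lora.lora_down.weight  # (out_dim, in_dim)
    w_eff = linear.weight + lora.multiplier * lora.scale * delta_w
    manual = x @ w_eff.T + linear.bias
    assert torch.allclose(linear(x), manual, atol=1e-5)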