import torch
import math


class LoRAModule(torch.nn.Module):
    """
    Low-rank adapter for a Linear layer.

    Replaces the forward method of the original Linear, instead of replacing
    the original Linear module itself.  After ``apply_to()`` is called, every
    call to the wrapped module returns
    ``org_forward(x) + lora_up(lora_down(x)) * multiplier * scale``.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=64,
        alpha=32,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
    ):
        """Build the down/up projections for ``org_module``.

        Args:
            lora_name: Identifier stored on the instance as ``lora_name``.
            org_module: Module exposing ``in_features`` / ``out_features``
                whose ``forward`` will be hooked by ``apply_to()``.
            multiplier: Extra factor applied to the LoRA branch output.
            lora_dim: Rank of the low-rank decomposition.
            alpha: Scaling numerator; ``scale = alpha / lora_dim``.
                If alpha == 0 or None, alpha is rank (no scaling).
                May be a number or a scalar tensor.
            dropout: Probability for element-wise dropout on the down
                projection output (training only); ``None`` disables it.
            rank_dropout: Probability of dropping each rank per sample
                (training only); ``None`` disables it.
            module_dropout: Probability of skipping the whole LoRA branch
                for a forward call (training only); ``None`` disables it.
        """
        super().__init__()
        self.lora_name = lora_name

        in_dim = org_module.in_features
        out_dim = org_module.out_features

        self.lora_dim = lora_dim
        self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
        self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if isinstance(alpha, torch.Tensor):
            # Cast to float first: without casting, bf16 causes an error.
            # Move to CPU first: .numpy() is not available for tensors that
            # live on an accelerator device.
            alpha = alpha.detach().float().cpu().numpy()
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        # Registered as a buffer so it can be treated as a constant
        # (saved with the state dict, not trained).
        self.register_buffer("alpha", torch.tensor(alpha))

        # Same initialization as Microsoft's reference implementation:
        # the up projection starts at zero, so the branch is initially a no-op.
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # removed in apply_to()
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        """Hook this adapter into the original module.

        Saves the original ``forward`` as ``self.org_forward``, points the
        original module's ``forward`` at this adapter's ``forward``, and drops
        the ``org_module`` attribute from this instance.
        """
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def _rank_dropout_mask(self, lx):
        """Return a boolean keep-mask for rank dropout, broadcastable to ``lx``.

        One keep/drop decision is drawn per (sample, rank).  Both projections
        are Linear layers, so the rank axis of ``lx`` is always the last one;
        the mask is therefore shaped ``(batch, 1, ..., 1, lora_dim)`` so the
        same decision is shared across every intermediate position.
        """
        if lx.dim() < 2:
            # Unbatched input: a single decision per rank.
            return torch.rand((self.lora_dim,), device=lx.device) > self.rank_dropout

        batch = lx.size(0)
        mask = torch.rand((batch, self.lora_dim), device=lx.device) > self.rank_dropout
        middle = (1,) * (lx.dim() - 2)
        return mask.view(batch, *middle, self.lora_dim)

    def forward(self, x):
        """Return the original module's output plus the scaled LoRA branch.

        The three dropout variants are only active while ``self.training`` is
        true and the corresponding probability is not ``None``.
        """
        org_forwarded = self.org_forward(x)

        # Module dropout: skip the whole LoRA branch for this call.
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        lx = self.lora_down(x)

        # Normal (element-wise) dropout on the low-rank activations.
        if self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)

        # Rank dropout: zero out whole ranks per sample.
        if self.rank_dropout is not None and self.training:
            lx = lx * self._rank_dropout_mask(lx)

            # Scaling for rank dropout: treat as if the rank is changed.
            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
        else:
            scale = self.scale

        lx = self.lora_up(lx)

        return org_forwarded + lx * self.multiplier * scale