@ -131,7 +131,7 @@ def load_lora(name, filename):
with torch . no_grad ( ) :
module . weight . copy_ ( weight )
module . to ( device = devices . device , dtype = devices . dtype )
module . to ( device = devices . cpu , dtype = devices . dtype )
if lora_key == " lora_up.weight " :
lora_module . up = module
@ -177,29 +177,69 @@ def load_loras(names, multipliers=None):
loaded_loras . append ( lora )
def lora_forward ( module , input , res ) :
input = devices . cond_cast_unet ( input )
if len ( loaded_loras ) == 0 :
return res
def lora_apply_weights ( self : torch . nn . Conv2d | torch . nn . Linear ) :
"""
Applies the currently selected set of Loras to the weight of torch layer self .
If weights already have this particular set of loras applied , does nothing .
If not , restores orginal weights from backup and alters weights according to loras .
"""
lora_layer_name = getattr ( module , ' lora_layer_name ' , None )
for lora in loaded_loras :
module = lora . modules . get ( lora_layer_name , None )
if module is not None :
if shared . opts . lora_apply_to_outputs and res . shape == input . shape :
res = res + module . up ( module . down ( res ) ) * lora . multiplier * ( module . alpha / module . up . weight . shape [ 1 ] if module . alpha else 1.0 )
else :
res = res + module . up ( module . down ( input ) ) * lora . multiplier * ( module . alpha / module . up . weight . shape [ 1 ] if module . alpha else 1.0 )
current_names = getattr ( self , " lora_current_names " , ( ) )
wanted_names = tuple ( ( x . name , x . multiplier ) for x in loaded_loras )
weights_backup = getattr ( self , " lora_weights_backup " , None )
if weights_backup is None :
weights_backup = self . weight . to ( devices . cpu , copy = True )
self . lora_weights_backup = weights_backup
if current_names != wanted_names :
if weights_backup is not None :
self . weight . copy_ ( weights_backup )
lora_layer_name = getattr ( self , ' lora_layer_name ' , None )
for lora in loaded_loras :
module = lora . modules . get ( lora_layer_name , None )
if module is None :
continue
return res
with torch . no_grad ( ) :
up = module . up . weight . to ( self . weight . device , dtype = self . weight . dtype )
down = module . down . weight . to ( self . weight . device , dtype = self . weight . dtype )
if up . shape [ 2 : ] == ( 1 , 1 ) and down . shape [ 2 : ] == ( 1 , 1 ) :
updown = ( up . squeeze ( 2 ) . squeeze ( 2 ) @ down . squeeze ( 2 ) . squeeze ( 2 ) ) . unsqueeze ( 2 ) . unsqueeze ( 3 )
else :
updown = up @ down
self . weight + = updown * lora . multiplier * ( module . alpha / module . up . weight . shape [ 1 ] if module . alpha else 1.0 )
setattr ( self , " lora_current_names " , wanted_names )
def lora_Linear_forward ( self , input ) :
return lora_forward ( self , input , torch . nn . Linear_forward_before_lora ( self , input ) )
lora_apply_weights ( self )
return torch . nn . Linear_forward_before_lora ( self , input )
def lora_Linear_load_state_dict ( self : torch . nn . Linear , * args , * * kwargs ) :
setattr ( self , " lora_current_names " , ( ) )
setattr ( self , " lora_weights_backup " , None )
return torch . nn . Linear_load_state_dict_before_lora ( self , * args , * * kwargs )
def lora_Conv2d_forward ( self , input ) :
return lora_forward ( self , input , torch . nn . Conv2d_forward_before_lora ( self , input ) )
lora_apply_weights ( self )
return torch . nn . Conv2d_forward_before_lora ( self , input )
def lora_Conv2d_load_state_dict ( self : torch . nn . Conv2d , * args , * * kwargs ) :
setattr ( self , " lora_current_names " , ( ) )
setattr ( self , " lora_weights_backup " , None )
return torch . nn . Conv2d_load_state_dict_before_lora ( self , * args , * * kwargs )
def list_available_loras ( ) :