@ -22,16 +22,20 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule ( torch . nn . Module ) :
multiplier = 1.0
def __init__ ( self , dim , state_dict = None , layer_structure = None , add_layer_norm = False ):
def __init__ ( self , dim , state_dict = None , layer_structure = None , add_layer_norm = False , activation_func = None ):
super ( ) . __init__ ( )
assert layer_structure is not None , " layer_structure mu t not be None"
assert layer_structure is not None , " layer_structure mu s t not be None"
assert layer_structure [ 0 ] == 1 , " Multiplier Sequence should start with size 1! "
assert layer_structure [ - 1 ] == 1 , " Multiplier Sequence should end with size 1! "
linears = [ ]
for i in range ( len ( layer_structure ) - 1 ) :
linears . append ( torch . nn . Linear ( int ( dim * layer_structure [ i ] ) , int ( dim * layer_structure [ i + 1 ] ) ) )
if activation_func == " relu " :
linears . append ( torch . nn . ReLU ( ) )
if activation_func == " leakyrelu " :
linears . append ( torch . nn . LeakyReLU ( ) )
if add_layer_norm :
linears . append ( torch . nn . LayerNorm ( int ( dim * layer_structure [ i + 1 ] ) ) )
@ -42,8 +46,9 @@ class HypernetworkModule(torch.nn.Module):
self . load_state_dict ( state_dict )
else :
for layer in self . linear :
layer . weight . data . normal_ ( mean = 0.0 , std = 0.01 )
layer . bias . data . zero_ ( )
if not " ReLU " in layer . __str__ ( ) :
layer . weight . data . normal_ ( mean = 0.0 , std = 0.01 )
layer . bias . data . zero_ ( )
self . to ( devices . device )
@ -69,7 +74,8 @@ class HypernetworkModule(torch.nn.Module):
def trainables ( self ) :
layer_structure = [ ]
for layer in self . linear :
layer_structure + = [ layer . weight , layer . bias ]
if not " ReLU " in layer . __str__ ( ) :
layer_structure + = [ layer . weight , layer . bias ]
return layer_structure
@ -81,7 +87,7 @@ class Hypernetwork:
filename = None
name = None
def __init__ ( self , name = None , enable_sizes = None , layer_structure = None , add_layer_norm = False ):
def __init__ ( self , name = None , enable_sizes = None , layer_structure = None , add_layer_norm = False , activation_func = None ):
self . filename = None
self . name = name
self . layers = { }
@ -90,11 +96,12 @@ class Hypernetwork:
self . sd_checkpoint_name = None
self . layer_structure = layer_structure
self . add_layer_norm = add_layer_norm
self . activation_func = activation_func
for size in enable_sizes or [ ] :
self . layers [ size ] = (
HypernetworkModule ( size , None , self . layer_structure , self . add_layer_norm ),
HypernetworkModule ( size , None , self . layer_structure , self . add_layer_norm ),
HypernetworkModule ( size , None , self . layer_structure , self . add_layer_norm , self . activation_func ),
HypernetworkModule ( size , None , self . layer_structure , self . add_layer_norm , self . activation_func ),
)
def weights ( self ) :
@ -117,6 +124,7 @@ class Hypernetwork:
state_dict [ ' name ' ] = self . name
state_dict [ ' layer_structure ' ] = self . layer_structure
state_dict [ ' is_layer_norm ' ] = self . add_layer_norm
state_dict [ ' activation_func ' ] = self . activation_func
state_dict [ ' sd_checkpoint ' ] = self . sd_checkpoint
state_dict [ ' sd_checkpoint_name ' ] = self . sd_checkpoint_name
@ -131,12 +139,13 @@ class Hypernetwork:
self . layer_structure = state_dict . get ( ' layer_structure ' , [ 1 , 2 , 1 ] )
self . add_layer_norm = state_dict . get ( ' is_layer_norm ' , False )
self . activation_func = state_dict . get ( ' activation_func ' , None )
for size , sd in state_dict . items ( ) :
if type ( size ) == int :
self . layers [ size ] = (
HypernetworkModule ( size , sd [ 0 ] , self . layer_structure , self . add_layer_norm ),
HypernetworkModule ( size , sd [ 1 ] , self . layer_structure , self . add_layer_norm ),
HypernetworkModule ( size , sd [ 0 ] , self . layer_structure , self . add_layer_norm , self . activation_func ),
HypernetworkModule ( size , sd [ 1 ] , self . layer_structure , self . add_layer_norm , self . activation_func ),
)
self . name = state_dict . get ( ' name ' , self . name )