@@ -34,7 +34,8 @@ class HypernetworkModule(torch.nn.Module):
     }
     activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False, activate_output=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+                 add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
@@ -60,7 +61,7 @@ class HypernetworkModule(torch.nn.Module):
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
             # Add dropout except last layer
-            if use_dropout and i < len(layer_structure) - 3:
+            if 'last_layer_dropout' in kwargs and kwargs['last_layer_dropout'] and use_dropout and i < len(layer_structure) - 2:
                 linears.append(torch.nn.Dropout(p=0.3))
 
         self.linear = torch.nn.Sequential(*linears)
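For reference, here is a minimal standalone sketch (not part of the patch) of how the rewritten guard moves the Dropout placement, assuming the build loop runs i over range(len(layer_structure) - 1) as in the surrounding code:

    layer_structure = [1, 2, 2, 1]
    transitions = range(len(layer_structure) - 1)  # i values the build loop visits

    # Old guard: no dropout on the last two transitions.
    old = [i for i in transitions if i < len(layer_structure) - 3]
    # New guard with last_layer_dropout=True: dropout through the penultimate transition.
    new = [i for i in transitions if i < len(layer_structure) - 2]

    print(old)  # [0]
    print(new)  # [0, 1]

Note that the new condition short-circuits when 'last_layer_dropout' is missing or falsy, so use_dropout=True alone no longer inserts any Dropout modules when HypernetworkModule is constructed directly; callers going through Hypernetwork always pass the keyword, as the hunks below show.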
@@ -126,7 +127,7 @@ class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -139,11 +140,14 @@ class Hypernetwork:
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
         self.activate_output = activate_output
+        self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
 
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
             )
 
     def weights(self):
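A hypothetical usage sketch of the widened constructor (names and values are illustrative): the new keyword travels through **kwargs and falls back to True when the caller omits it.

    hn = Hypernetwork(name="example", layer_structure=[1, 2, 1],
                      use_dropout=True, last_layer_dropout=False)
    assert hn.last_layer_dropout is False

    hn_default = Hypernetwork(name="example", layer_structure=[1, 2, 1])
    assert hn_default.last_layer_dropout is True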
@@ -172,7 +176,8 @@ class Hypernetwork:
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
         state_dict['activate_output'] = self.activate_output
+        state_dict['last_layer_dropout'] = self.last_layer_dropout
 
         torch.save(state_dict, filename)
 
     def load(self, filename):
@@ -193,12 +198,16 @@ class Hypernetwork:
         self.use_dropout = state_dict.get('use_dropout', False)
         print(f"Dropout usage is set to {self.use_dropout}")
         self.activate_output = state_dict.get('activate_output', True)
+        print(f"Activate last layer is set to {self.activate_output}")
+        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
 
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
                 )
 
         self.name = state_dict.get('name', self.name)
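One consequence of the defaults above, shown as a small sketch (plain dicts standing in for torch.load results): save() now persists the flag, load() defaults a missing key to False, and the constructor defaults to True, so files saved before this change keep their old behaviour while freshly created networks enable last-layer dropout.

    legacy = {}                             # checkpoint saved before this change
    current = {'last_layer_dropout': True}  # checkpoint written by the new save()

    print(legacy.get('last_layer_dropout', False))   # False
    print(current.get('last_layer_dropout', False))  # True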