@@ -35,7 +35,8 @@ class HypernetworkModule(torch.nn.Module):
     }
     activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
@@ -48,8 +49,8 @@ class HypernetworkModule(torch.nn.Module):
             # Add a fully-connected layer
             linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
 
-            # Add an activation func
-            if activation_func == "linear" or activation_func is None:
+            # Add an activation func except last layer
+            if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
                 pass
             elif activation_func in self.activation_dict:
                 linears.append(self.activation_dict[activation_func]())
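Not part of the patch: a minimal sketch of what the new `i >= len(layer_structure) - 2 and not activate_output` guard does. The helper name and the example `layer_structure` values are illustrative only.

def activation_positions(layer_structure, activate_output):
    # Indices i whose Linear layer is followed by an activation, mirroring the
    # patched condition: the final Linear (i == len - 2) is skipped unless
    # activate_output is True.
    return [
        i for i in range(len(layer_structure) - 1)
        if not (i >= len(layer_structure) - 2 and not activate_output)
    ]

print(activation_positions([1, 2, 1], activate_output=False))  # [0]
print(activation_positions([1, 2, 1], activate_output=True))   # [0, 1]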
@@ -60,8 +61,8 @@ class HypernetworkModule(torch.nn.Module):
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
-            # Add dropout expect last layer
-            if use_dropout and i < len(layer_structure) - 3:
+            # Add dropout except last layer
+            if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
                 linears.append(torch.nn.Dropout(p=0.3))
 
         self.linear = torch.nn.Sequential(*linears)
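Not part of the patch: a sketch of which layers the new dropout condition covers, with an example `layer_structure`. Since `and` binds tighter than `or`, the added clause reads as `i < len - 3 or (last_layer_dropout and i < len - 2)`: dropout after every layer except the last two, extended to the second-to-last one when `last_layer_dropout` is enabled; the output Linear itself never gets dropout.

def dropout_positions(layer_structure, use_dropout, last_layer_dropout):
    # Indices i whose Linear layer is followed by Dropout under the new rule.
    n = len(layer_structure)
    return [
        i for i in range(n - 1)
        if use_dropout and (i < n - 3 or last_layer_dropout and i < n - 2)
    ]

print(dropout_positions([1, 2, 2, 1], True, last_layer_dropout=False))  # [0]
print(dropout_positions([1, 2, 2, 1], True, last_layer_dropout=True))   # [0, 1]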
@@ -75,7 +76,7 @@ class HypernetworkModule(torch.nn.Module):
                     w, b = layer.weight.data, layer.bias.data
                     if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
                         normal_(w, mean=0.0, std=0.01)
-                        normal_(b, mean=0.0, std=0.005)
+                        normal_(b, mean=0.0, std=0)
                     elif weight_init == 'XavierUniform':
                         xavier_uniform_(w)
                         zeros_(b)
@@ -127,7 +128,7 @@ class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -139,11 +140,15 @@ class Hypernetwork:
         self.weight_init = weight_init
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
+        self.activate_output = activate_output
+        self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
 
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
             )
 
     def weights(self):
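Illustrative usage only (the argument values are examples, not defaults from the patch): `last_layer_dropout` is not a named parameter of `Hypernetwork.__init__`, so it travels through `**kwargs` and falls back to `True` when omitted, equivalent to `kwargs.get('last_layer_dropout', True)`.

hn = Hypernetwork(
    name="example",                      # example values throughout
    enable_sizes=[320, 640, 768, 1280],
    layer_structure=[1, 2, 1],
    activation_func="relu",
    weight_init="Normal",
    add_layer_norm=False,
    use_dropout=True,
    activate_output=False,               # new: no activation after the last Linear
    last_layer_dropout=True,             # new: picked up via **kwargs
)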
@@ -171,7 +176,9 @@ class Hypernetwork:
         state_dict['use_dropout'] = self.use_dropout
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
+        state_dict['activate_output'] = self.activate_output
+        state_dict['last_layer_dropout'] = self.last_layer_dropout
 
         torch.save(state_dict, filename)
 
     def load(self, filename):
@@ -191,12 +198,17 @@ class Hypernetwork:
         print(f"Layer norm is set to {self.add_layer_norm}")
         self.use_dropout = state_dict.get('use_dropout', False)
         print(f"Dropout usage is set to {self.use_dropout}")
+        self.activate_output = state_dict.get('activate_output', True)
+        print(f"Activate last layer is set to {self.activate_output}")
+        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
 
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
                 )
 
         self.name = state_dict.get('name', self.name)
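Worth noting: the `load` defaults are deliberately the opposite of the constructor defaults. A checkpoint saved before this change has neither key, so it loads with `activate_output=True` and `last_layer_dropout=False`, which rebuilds the module exactly as the old code did (activation after every Linear, dropout only where the old rule placed it). A sketch of how the two cases resolve, with made-up dictionaries standing in for loaded checkpoints:

old_checkpoint = {'use_dropout': True}                 # saved before this patch
new_checkpoint = {'use_dropout': True,
                  'activate_output': False,
                  'last_layer_dropout': True}          # saved after this patch

for sd in (old_checkpoint, new_checkpoint):
    print(sd.get('activate_output', True), sd.get('last_layer_dropout', False))
# True False   -> pre-patch file keeps its original layout
# False True   -> post-patch file restores exactly what was saved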