@@ -22,6 +22,8 @@ from collections import defaultdict, deque
 from statistics import stdev, mean
 
+optimizer_dict = {optim_name: cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
 
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
     activation_dict = {
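Note: the optimizer_dict comprehension added above builds a name-to-class map from everything torch.optim exports. A minimal standalone sketch of what it produces and how the training code later instantiates from it (the lr value is illustrative):

    import inspect
    import torch

    optimizer_dict = {optim_name: cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
    print(sorted(optimizer_dict))  # includes 'Adam', 'AdamW', 'SGD', ... but not the base 'Optimizer'

    # Instantiating by name, as train_hypernetwork() does below:
    weights = [torch.nn.Parameter(torch.zeros(8))]
    optimizer = optimizer_dict["AdamW"](params=weights, lr=1e-3)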
@@ -142,6 +144,8 @@ class Hypernetwork:
         self.use_dropout = use_dropout
         self.activate_output = activate_output
         self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
+        self.optimizer_name = None
+        self.optimizer_state_dict = None
 
         for size in enable_sizes or []:
             self.layers[size] = (
@ -163,6 +167,7 @@ class Hypernetwork:
def save ( self , filename ) :
state_dict = { }
optimizer_saved_dict = { }
for k , v in self . layers . items ( ) :
state_dict [ k ] = ( v [ 0 ] . state_dict ( ) , v [ 1 ] . state_dict ( ) )
@ -178,8 +183,15 @@ class Hypernetwork:
state_dict [ ' sd_checkpoint_name ' ] = self . sd_checkpoint_name
state_dict [ ' activate_output ' ] = self . activate_output
state_dict [ ' last_layer_dropout ' ] = self . last_layer_dropout
if self . optimizer_name is not None :
optimizer_saved_dict [ ' optimizer_name ' ] = self . optimizer_name
torch . save ( state_dict , filename )
if shared . opts . save_optimizer_state and self . optimizer_state_dict :
optimizer_saved_dict [ ' hash ' ] = sd_models . model_hash ( filename )
optimizer_saved_dict [ ' optimizer_state_dict ' ] = self . optimizer_state_dict
torch . save ( optimizer_saved_dict , filename + ' .optim ' )
def load ( self , filename ) :
self . filename = filename
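For context, the new save path writes two files: the usual <name>.pt and a <name>.pt.optim sidecar holding a plain dict. A sketch of the sidecar round trip outside the class (file name and hash value are placeholders; the real code derives the hash via sd_models.model_hash):

    import torch

    weights = [torch.nn.Parameter(torch.zeros(8))]
    optimizer = torch.optim.AdamW(weights, lr=1e-3)

    optimizer_saved_dict = {
        'optimizer_name': 'AdamW',
        'hash': 'abcd1234',  # placeholder; real code uses sd_models.model_hash(filename)
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(optimizer_saved_dict, 'my_hypernetwork.pt.optim')

    loaded = torch.load('my_hypernetwork.pt.optim', map_location='cpu')
    optimizer.load_state_dict(loaded['optimizer_state_dict'])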
@@ -202,6 +214,18 @@ class Hypernetwork:
         print(f"Activate last layer is set to {self.activate_output}")
         self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
 
+        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
+        self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
+        print(f"Optimizer name is {self.optimizer_name}")
+        if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+        else:
+            self.optimizer_state_dict = None
+        if self.optimizer_state_dict:
+            print("Loaded existing optimizer from checkpoint")
+        else:
+            print("No saved optimizer exists in checkpoint")
+
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
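The hash comparison above ties a .optim sidecar to the exact .pt it was saved with, so stale optimizer state is silently discarded rather than loaded into a retrained network. A sketch of the guard in isolation, with a hypothetical file_hash() standing in for sd_models.model_hash():

    import hashlib
    import os
    import torch

    def file_hash(filename):
        # hypothetical stand-in; webui's model_hash() hashes part of the file
        with open(filename, 'rb') as f:
            return hashlib.sha256(f.read()).hexdigest()[:8]

    filename = 'my_hypernetwork.pt'  # placeholder
    optim_path = filename + '.optim'
    saved = torch.load(optim_path, map_location='cpu') if os.path.exists(optim_path) else {}
    state = saved.get('optimizer_state_dict') if file_hash(filename) == saved.get('hash') else None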
@@ -219,11 +243,11 @@ class Hypernetwork:
 
 def list_hypernetworks(path):
     res = {}
-    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+    for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
         name = os.path.splitext(os.path.basename(filename))[0]
         # Prevent a hypothetical "None.pt" from being listed.
         if name != "None":
-            res[name] = filename
+            res[name + f"({sd_models.model_hash(filename)})"] = filename
     return res
@@ -358,6 +382,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     shared.state.textinfo = "Initializing hypernetwork training..."
     shared.state.job_count = steps
 
+    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
 
     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
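Since list_hypernetworks() now returns display names like "name(hash)", train_hypernetwork() strips the suffix with rsplit before rebuilding the .pt path. The round trip in isolation; rsplit (rather than split) keeps parentheses that are part of the user-chosen name intact:

    name = 'my_hypernet(abc12345)'
    assert name.rsplit('(', 1)[0] == 'my_hypernet'
    assert 'my(net)(abc12345)'.rsplit('(', 1)[0] == 'my(net)'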
@@ -404,8 +429,22 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     weights = hypernetwork.weights()
     for weight in weights:
         weight.requires_grad = True
 
-    # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
-    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+    # Here we use optimizer from saved HN, or we can specify as UI option.
+    if hypernetwork.optimizer_name in optimizer_dict:
+        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
+        optimizer_name = hypernetwork.optimizer_name
+    else:
+        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
+        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
+        optimizer_name = 'AdamW'
+
+    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
+        try:
+            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+        except RuntimeError as e:
+            print("Cannot resume from saved optimizer!")
+            print(e)
 
     steps_without_grad = 0
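One review note on the try/except: in recent PyTorch versions, Optimizer.load_state_dict() raises ValueError (not RuntimeError) when the saved parameter groups don't line up, so a mismatched state dict may still escape this handler. A self-contained repro under that assumption:

    import torch

    w1 = torch.nn.Parameter(torch.zeros(4))
    w2 = torch.nn.Parameter(torch.zeros(4))

    saved = torch.optim.AdamW([w1], lr=1e-3).state_dict()  # state for a group of one tensor
    optimizer = torch.optim.AdamW([w1, w2], lr=1e-3)       # now two tensors in the group

    try:
        optimizer.load_state_dict(saved)
    except ValueError as e:  # "...doesn't match the size of optimizer's group"
        print("Cannot resume from saved optimizer!")
        print(e)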
@@ -467,7 +506,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                 # Before saving, change name to match current checkpoint.
                 hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
                 last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+                hypernetwork.optimizer_name = optimizer_name
+                if shared.opts.save_optimizer_state:
+                    hypernetwork.optimizer_state_dict = optimizer.state_dict()
                 save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
+                hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
 
             textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
                 "loss": f"{previous_mean_loss:.7f}",
@@ -530,8 +573,12 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     report_statistics(loss_dict)
 
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+    hypernetwork.optimizer_name = optimizer_name
+    if shared.opts.save_optimizer_state:
+        hypernetwork.optimizer_state_dict = optimizer.state_dict()
     save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-
+    del optimizer
+    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
     return hypernetwork, filename
 
 def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
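The "dereference it after saving" comments matter because Adam-family state is large: AdamW keeps two momentum tensors (exp_avg, exp_avg_sq) per parameter, roughly twice the hypernetwork's own weight memory. A rough sketch measuring that:

    import torch

    w = torch.nn.Parameter(torch.zeros(1000, 1000))
    optimizer = torch.optim.AdamW([w], lr=1e-3)
    w.sum().backward()
    optimizer.step()  # populates exp_avg / exp_avg_sq

    state = optimizer.state_dict()['state'][0]
    extra = sum(t.numel() * t.element_size() for t in state.values() if torch.is_tensor(t))
    print(f"optimizer state is ~{extra / (w.numel() * w.element_size()):.1f}x the parameter memory")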