@@ -28,6 +28,7 @@ class Embedding:
         self.cached_checksum = None
         self.sd_checkpoint = None
         self.sd_checkpoint_name = None
+        self.optimizer_state_dict = None
 
     def save(self, filename):
         embedding_data = {
@@ -41,6 +42,13 @@ class Embedding:
 
         torch.save(embedding_data, filename)
 
+        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
+            optimizer_saved_dict = {
+                'hash': self.checksum(),
+                'optimizer_state_dict': self.optimizer_state_dict,
+            }
+            torch.save(optimizer_saved_dict, filename + '.optim')
+
     def checksum(self):
         if self.cached_checksum is not None:
             return self.cached_checksum
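Note: Embedding.save now writes the optimizer state as a sidecar file next to the embedding, named "<filename>.optim" and keyed by the embedding checksum. A minimal sketch of inspecting such a sidecar outside the webui (the file path is hypothetical; only PyTorch is assumed):

    import torch

    # hypothetical sidecar produced by Embedding.save(); adjust the path to your embedding
    saved = torch.load('my-embedding.pt.optim', map_location='cpu')
    print(saved['hash'])                         # checksum of the embedding this state belongs to
    print(saved['optimizer_state_dict'].keys())  # AdamW state_dict: 'state' and 'param_groups'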
@@ -95,9 +103,10 @@ class EmbeddingDatabase:
         self.expected_shape = self.get_expected_shape()
 
         def process_file(path, filename):
-            name = os.path.splitext(filename)[0]
+            name, ext = os.path.splitext(filename)
+            ext = ext.upper()
 
-            if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+            if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
                 embed_image = Image.open(path)
                 if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                     data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
@@ -105,8 +114,10 @@ class EmbeddingDatabase:
                 else:
                     data = extract_image_data_embed(embed_image)
                     name = data.get('name', name)
-            else:
+            elif ext in ['.BIN', '.PT']:
                 data = torch.load(path, map_location="cpu")
+            else:
+                return
 
             # textual inversion embeddings
             if 'string_to_param' in data:
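The rewritten extension check relies on os.path.splitext returning the suffix with its leading dot, which is upper-cased once and compared against the known lists; any other extension now hits the new else branch and returns early instead of being passed to torch.load. For reference, standard-library behaviour (the filename is illustrative):

    import os.path

    name, ext = os.path.splitext('learned_embeds.bin')
    print(name, ext.upper())  # -> learned_embeds .BIN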
@@ -300,6 +311,20 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
 
     embedding.vec.requires_grad = True
     optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
+
+    if shared.opts.save_optimizer_state:
+        optimizer_state_dict = None
+        if os.path.exists(filename + '.optim'):
+            optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
+            if embedding.checksum() == optimizer_saved_dict.get('hash', None):
+                optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+
+        if optimizer_state_dict is not None:
+            optimizer.load_state_dict(optimizer_state_dict)
+            print("Loaded existing optimizer from checkpoint")
+        else:
+            print("No saved optimizer exists in checkpoint")
+
     scaler = torch.cuda.amp.GradScaler()
 
     batch_size = ds.batch_size
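Restoring is guarded by the checksum, so stale AdamW moments are never applied to a different embedding vector than the one they were saved with. The same save/restore round trip in isolation, as a rough sketch (the vector shape, learning rate, hash value and file name are illustrative, not taken from the patch):

    import torch

    vec = torch.zeros(1, 768, requires_grad=True)                    # stand-in embedding vector
    optimizer = torch.optim.AdamW([vec], lr=5e-3, weight_decay=0.0)

    # save: same layout as the '.optim' sidecar written above
    torch.save({'hash': 'abc123', 'optimizer_state_dict': optimizer.state_dict()}, 'vec.pt.optim')

    # restore: only reuse the state if it belongs to this vector
    saved = torch.load('vec.pt.optim', map_location='cpu')
    if saved.get('hash') == 'abc123':
        optimizer.load_state_dict(saved['optimizer_state_dict'])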
@@ -366,9 +391,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                     # Before saving, change name to match current checkpoint.
                     embedding_name_every = f'{embedding_name}-{steps_done}'
                     last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
-                    #if shared.opts.save_optimizer_state:
-                    #embedding.optimizer_state_dict = optimizer.state_dict()
-                    save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+                    save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
                     embedding_yet_to_be_embedded = True
 
                 write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
@@ -458,7 +481,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
         filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
-        save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+        save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
     except Exception:
         print(traceback.format_exc(), file=sys.stderr)
         pass
@@ -470,7 +493,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
 
     return embedding, filename
 
-def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
     old_embedding_name = embedding.name
     old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
     old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
@@ -481,6 +504,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
         if remove_cached_checksum:
             embedding.cached_checksum = None
         embedding.name = embedding_name
+        embedding.optimizer_state_dict = optimizer.state_dict()
         embedding.save(filename)
     except:
         embedding.sd_checkpoint = old_sd_checkpoint