@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
filename = os . path . join ( shared . cmd_opts . embeddings_dir , f ' { embedding_name } .pt ' )
log_directory = os . path . join ( log_directory , datetime . datetime . now ( ) . strftime ( " % Y- % m- %d " ) , embedding_name )
unload = shared . opts . unload_models_when_training
if save_embedding_every > 0 :
embedding_dir = os . path . join ( log_directory , " embeddings " )
@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
shared . state . textinfo = f " Preparing dataset from { html . escape ( data_root ) } ... "
with torch . autocast ( " cuda " ) :
ds = modules . textual_inversion . dataset . PersonalizedBase ( data_root = data_root , width = training_width , height = training_height , repeats = shared . opts . training_image_repeats_per_epoch , placeholder_token = embedding_name , model = shared . sd_model , device = devices . device , template_file = template_file , batch_size = batch_size )
if unload :
shared . sd_model . first_stage_model . to ( devices . cpu )
embedding . vec . requires_grad = True
optimizer = torch . optim . AdamW ( [ embedding . vec ] , lr = scheduler . learn_rate )
@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if images_dir is not None and steps_done % create_image_every == 0 :
forced_filename = f ' { embedding_name } - { steps_done } '
last_saved_image = os . path . join ( images_dir , forced_filename )
shared . sd_model . first_stage_model . to ( devices . device )
p = processing . StableDiffusionProcessingTxt2Img (
sd_model = shared . sd_model ,
do_not_save_grid = True ,
@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
processed = processing . process_images ( p )
image = processed . images [ 0 ]
if unload :
shared . sd_model . first_stage_model . to ( devices . cpu )
shared . state . current_image = image
if save_image_with_stored_embedding and os . path . exists ( last_saved_file ) and embedding_yet_to_be_embedded :
@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
filename = os . path . join ( shared . cmd_opts . embeddings_dir , f ' { embedding_name } .pt ' )
save_embedding ( embedding , checkpoint , embedding_name , filename , remove_cached_checksum = True )
shared . sd_model . first_stage_model . to ( devices . device )
return embedding , filename