@ -11,7 +11,7 @@ from PIL import Image, PngImagePlugin
from modules import shared , devices , sd_hijack , processing , sd_models
import modules . textual_inversion . dataset
from modules . textual_inversion . learn_schedule import Learn Schedule
from modules . textual_inversion . learn_schedule import Learn Rate Scheduler
from modules . textual_inversion . image_embedding import ( embedding_to_b64 , embedding_from_b64 ,
insert_image_data_embed , extract_image_data_embed ,
@ -172,8 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
def train_embedding ( embedding_name , learn_rate , data_root , log_directory , training_width , training_height , steps , num_repeats , create_image_every , save_embedding_every , template_file , save_image_with_stored_embedding , preview_image_prompt ) :
def train_embedding ( embedding_name , learn_rate , data_root , log_directory , training_width , training_height , steps , create_image_every , save_embedding_every , template_file , save_image_with_stored_embedding , preview_image_prompt ) :
assert embedding_name , ' embedding not selected '
shared . state . textinfo = " Initializing textual inversion training... "
@ -205,7 +204,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared . state . textinfo = f " Preparing dataset from { html . escape ( data_root ) } ... "
with torch . autocast ( " cuda " ) :
ds = modules . textual_inversion . dataset . PersonalizedBase ( data_root = data_root , width = training_width , height = training_height , repeats = num_repeats , placeholder_token = embedding_name , model = shared . sd_model , device = devices . device , template_file = template_file )
ds = modules . textual_inversion . dataset . PersonalizedBase ( data_root = data_root , width = training_width , height = training_height , repeats = shared. opts . training_image_repeats_per_epoch , placeholder_token = embedding_name , model = shared . sd_model , device = devices . device , template_file = template_file )
hijack = sd_hijack . model_hijack
@ -221,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps :
return embedding , filename
schedules = iter ( LearnSchedule ( learn_rate , steps , ititial_step ) )
( learn_rate , end_step ) = next ( schedules )
print ( f ' Training at rate of { learn_rate } until step { end_step } ' )
optimizer = torch . optim . AdamW ( [ embedding . vec ] , lr = learn_rate )
scheduler = LearnRateScheduler ( learn_rate , steps , ititial_step )
optimizer = torch . optim . AdamW ( [ embedding . vec ] , lr = scheduler . learn_rate )
pbar = tqdm . tqdm ( enumerate ( ds ) , total = steps - ititial_step )
for i , ( x , text , _ ) in pbar :
for i , entry in pbar :
embedding . step = i + ititial_step
if embedding . step > end_step :
try :
( learn_rate , end_step ) = next ( schedules )
except :
break
tqdm . tqdm . write ( f ' Training at rate of { learn_rate } until step { end_step } ' )
for pg in optimizer . param_groups :
pg [ ' lr ' ] = learn_rate
scheduler . apply ( optimizer , embedding . step )
if scheduler . finished :
break
if shared . state . interrupted :
break
with torch . autocast ( " cuda " ) :
c = cond_model ( [ text] )
c = cond_model ( [ entry. cond_ text] )
x = x . to ( devices . device )
x = entry. latent . to ( devices . device )
loss = shared . sd_model ( x . unsqueeze ( 0 ) , c ) [ 0 ]
del x
@ -268,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding . step > 0 and images_dir is not None and embedding . step % create_image_every == 0 :
last_saved_image = os . path . join ( images_dir , f ' { embedding_name } - { embedding . step } .png ' )
preview_text = text if preview_image_prompt == " " else preview_image_prompt
preview_text = entry. cond_ text if preview_image_prompt == " " else preview_image_prompt
p = processing . StableDiffusionProcessingTxt2Img (
sd_model = shared . sd_model ,
@ -314,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
< p >
Loss : { losses . mean ( ) : .7 f } < br / >
Step : { embedding . step } < br / >
Last prompt : { html . escape ( text) } < br / >
Last prompt : { html . escape ( entry. cond_ text) } < br / >
Last saved embedding : { html . escape ( last_saved_file ) } < br / >
Last saved image : { html . escape ( last_saved_image ) } < br / >
< / p >