@@ -224,7 +224,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
if save_model_every or create_image_every :
assert log_directory , " Log directory is empty "
-def train_embedding ( embedding_name , learn_rate , batch_size , data_root , log_directory , training_width , training_height , steps , create_image_every , save_embedding_every , template_file , save_image_with_stored_embedding , preview_from_txt2img , shuffle_tags, preview_prompt, preview_negative_prompt , preview_steps , preview_sampler_index , preview_cfg_scale , preview_seed , preview_width , preview_height ) :
+def train_embedding ( embedding_name , learn_rate , batch_size , data_root , log_directory , training_width , training_height , steps , create_image_every , save_embedding_every , template_file , save_image_with_stored_embedding , preview_from_txt2img , preview_prompt, preview_negative_prompt , preview_steps , preview_sampler_index , preview_cfg_scale , preview_seed , preview_width , preview_height ) :
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
validate_train_inputs ( embedding_name , learn_rate , batch_size , data_root , template_file , steps , save_embedding_every , create_image_every , log_directory , name = " embedding " )
@@ -272,7 +272,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
# dataset loading may take a while, so input validations and early returns should be done before this
shared . state . textinfo = f " Preparing dataset from { html . escape ( data_root ) } ... "
with torch . autocast ( " cuda " ) :
-        ds = modules . textual_inversion . dataset . PersonalizedBase ( data_root = data_root , width = training_width , height = training_height , repeats = shared . opts . training_image_repeats_per_epoch , placeholder_token = embedding_name , shuffle_tags= shuffle_tags , model= shared . sd_model , device = devices . device , template_file = template_file , batch_size = batch_size )
+        ds = modules . textual_inversion . dataset . PersonalizedBase ( data_root = data_root , width = training_width , height = training_height , repeats = shared . opts . training_image_repeats_per_epoch , placeholder_token = embedding_name , model= shared . sd_model , device = devices . device , template_file = template_file , batch_size = batch_size )
if unload :
shared . sd_model . first_stage_model . to ( devices . cpu )