@@ -1,6 +1,7 @@
 import os
 import sys
 import traceback
+import inspect
 
 import torch
 import tqdm
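
The new `inspect` import backs the argument introspection used by `save_settings_to_file` below: `inspect.getfullargspec` returns a function's parameter names in declaration order, and each name can then be looked up in `locals()`. A minimal standalone sketch of that pattern (example names only, not webui code):

    import inspect

    def example(alpha, beta, gamma=3):
        values = locals()  # capture before the comprehension, which has its own scope
        return {name: values[name] for name in inspect.getfullargspec(example).args}

    print(example(1, 2))  # {'alpha': 1, 'beta': 2, 'gamma': 3}
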
@@ -229,6 +230,28 @@ def write_loss(log_directory, filename, step, epoch_len, values):
         **values,
     })
 
+
+def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+    checkpoint = sd_models.select_checkpoint()
+    model_name = checkpoint.model_name
+    model_hash = '[{}]'.format(checkpoint.hash)
+
+    # Get a list of the argument names.
+    arg_names = inspect.getfullargspec(save_settings_to_file).args
+
+    # Create a list of the argument names to include in the settings string.
+    names = arg_names[:16]  # Include all arguments up until the preview-related ones.
+    if preview_from_txt2img:
+        names.extend(arg_names[16:])  # Include all remaining arguments if `preview_from_txt2img` is True.
+
+    # Build the settings string.
+    settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n"
+    for name in names:
+        value = locals()[name]
+        settings_str += f"{name}: {value}\n"
+
+    with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout:
+        fout.write(settings_str + "\n\n")
+
+
 def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
     assert model_name, f"{name} not selected"
     assert learn_rate, "Learning rate is empty or 0"
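
For reviewers: the function appends one timestamped block per training run to settings.txt in the log directory, so the file accumulates a history of runs rather than being overwritten. A minimal standalone sketch of that append-only behavior (the `log_run_settings` helper is hypothetical, not part of this PR):

    import os
    import datetime

    def log_run_settings(log_directory, **settings):
        # "a+" appends and creates the file on first use, so runs accumulate.
        lines = ["datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")]
        lines += [f"{name}: {value}" for name, value in settings.items()]
        with open(os.path.join(log_directory, "settings.txt"), "a+") as fout:
            fout.write("\n".join(lines) + "\n\n")

    log_run_settings(".", learn_rate=0.005, batch_size=1, steps=10000)
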
@@ -292,13 +315,13 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
     if initial_step >= steps:
         shared.state.textinfo = "Model has already been trained beyond specified max steps"
         return embedding, filename
 
     scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
     clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
         torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
         None
     if clip_grad:
-        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
+        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     old_parallel_processing_allowed = shared.parallel_processing_allowed
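
The conditional expression above picks one of PyTorch's two in-place gradient clipping helpers; both take an iterable of parameters plus a threshold. clip_grad_value_ caps each gradient element, while clip_grad_norm_ rescales the whole gradient vector. A self-contained sketch of applying the selected callable (the literal 1.0 stands in for the scheduled `clip_grad_sched.learn_rate`):

    import torch

    clip_grad_mode = "norm"  # or "value"; mirrors the UI option
    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
        torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
        None

    param = torch.nn.Parameter(torch.ones(4))
    param.grad = torch.full((4,), 10.0)
    if clip_grad:
        clip_grad([param], 1.0)  # value mode caps each element; norm mode rescales the whole gradient
    print(param.grad)
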
@@ -306,7 +329,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
     pin_memory = shared.opts.pin_memory
 
     ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+
+    if shared.opts.save_train_settings_to_txt:
+        save_settings_to_file(initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height)
 
     latent_sampling_method = ds.latent_sampling_method
 
     dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
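
`pin_memory` is threaded through to the loader so batches are staged in page-locked host memory, which speeds up host-to-GPU copies and allows them to run asynchronously. A generic torch.utils.data sketch of the same option (standard DataLoader API, not the PersonalizedDataLoader wrapper):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.randn(16, 3), torch.randint(0, 2, (16,)))
    dl = DataLoader(ds, batch_size=4, pin_memory=True)  # batches land in pinned host memory
    for x, y in dl:
        if torch.cuda.is_available():
            x = x.to("cuda", non_blocking=True)  # pinned memory enables async copy
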