|
|
|
|
@ -251,8 +251,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
|
|
|
|
if save_model_every or create_image_every:
|
|
|
|
|
assert log_directory, "Log directory is empty"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
|
|
|
|
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
|
|
|
|
save_embedding_every = save_embedding_every or 0
|
|
|
|
|
create_image_every = create_image_every or 0
|
|
|
|
|
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
|
|
|
|
@ -295,6 +294,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|
|
|
|
return embedding, filename
|
|
|
|
|
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
|
|
|
|
|
|
|
|
|
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
|
|
|
|
|
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
|
|
|
|
|
None
|
|
|
|
|
if clip_grad:
|
|
|
|
|
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
|
|
|
|
|
# dataset loading may take a while, so input validations and early returns should be done before this
|
|
|
|
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
|
|
|
|
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
|
|
|
|
@ -361,6 +365,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|
|
|
|
if shared.state.interrupted:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
if clip_grad:
|
|
|
|
|
clip_grad_sched.step(embedding.step)
|
|
|
|
|
|
|
|
|
|
with devices.autocast():
|
|
|
|
|
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
|
|
|
|
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
|
|
|
|
@ -382,6 +389,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|
|
|
|
# go back until we reach gradient accumulation steps
|
|
|
|
|
if (j + 1) % gradient_step != 0:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if clip_grad:
|
|
|
|
|
clip_grad(embedding.vec, clip_grad_sched.learn_rate)
|
|
|
|
|
|
|
|
|
|
scaler.step(optimizer)
|
|
|
|
|
scaler.update()
|
|
|
|
|
embedding.step += 1
|
|
|
|
|
|