|
|
|
|
@ -402,10 +402,8 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
|
|
|
|
|
|
|
|
|
shared.reload_hypernetworks()
|
|
|
|
|
|
|
|
|
|
return fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
|
|
|
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
|
|
|
|
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
|
|
|
|
from modules import images
|
|
|
|
|
|
|
|
|
|
@ -448,6 +446,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|
|
|
|
return hypernetwork, filename
|
|
|
|
|
|
|
|
|
|
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
|
|
|
|
|
|
|
|
|
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
|
|
|
|
|
if clip_grad:
|
|
|
|
|
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
|
|
|
|
|
|
|
|
|
# dataset loading may take a while, so input validations and early returns should be done before this
|
|
|
|
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
|
|
|
|
@ -466,7 +468,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|
|
|
|
shared.parallel_processing_allowed = False
|
|
|
|
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
|
|
|
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
weights = hypernetwork.weights()
|
|
|
|
|
hypernetwork.train_mode()
|
|
|
|
|
|
|
|
|
|
@ -525,6 +527,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|
|
|
|
if shared.state.interrupted:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
if clip_grad:
|
|
|
|
|
clip_grad_sched.step(hypernetwork.step)
|
|
|
|
|
|
|
|
|
|
with devices.autocast():
|
|
|
|
|
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
|
|
|
|
if tag_drop_out != 0 or shuffle_tags:
|
|
|
|
|
@ -539,14 +544,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|
|
|
|
|
|
|
|
|
_loss_step += loss.item()
|
|
|
|
|
scaler.scale(loss).backward()
|
|
|
|
|
|
|
|
|
|
# go back until we reach gradient accumulation steps
|
|
|
|
|
if (j + 1) % gradient_step != 0:
|
|
|
|
|
continue
|
|
|
|
|
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
|
|
|
|
|
# scaler.unscale_(optimizer)
|
|
|
|
|
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
|
|
|
|
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
|
|
|
|
|
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
|
|
|
|
|
|
|
|
|
if clip_grad:
|
|
|
|
|
clip_grad(weights, clip_grad_sched.learn_rate)
|
|
|
|
|
|
|
|
|
|
scaler.step(optimizer)
|
|
|
|
|
scaler.update()
|
|
|
|
|
hypernetwork.step += 1
|
|
|
|
|
|