@@ -383,11 +383,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     ititial_step = hypernetwork.step or 0
     if ititial_step > steps:
         return hypernetwork, filename
 
     clip_grad_mode_value = clip_grad_mode == "value"
     clip_grad_mode_norm = clip_grad_mode == "norm"
+    clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
+    if clip_grad_enabled:
+        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
+
     scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
     # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
     optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
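The block added above builds a second LearnRateScheduler for the clipping threshold, so it can be given with the same step-schedule syntax as the learning rate rather than as one constant. The sketch below shows the general idea with a simplified stand-in class; SimpleValueScheduler and its schedule format are illustrative, not the webui implementation.

# Illustrative stand-in for a step-keyed value schedule; NOT the webui
# LearnRateScheduler, just a minimal sketch of the idea it implements.
class SimpleValueScheduler:
    def __init__(self, schedule, initial_step=0):
        # schedule: list of (value, last_step) pairs, e.g. [(0.1, 500), (0.01, 1000)]
        self.schedule = schedule
        self.value = schedule[0][0]
        self.step(initial_step)

    def step(self, step_number):
        # use the first entry whose end step has not been passed yet
        for value, last_step in self.schedule:
            if step_number <= last_step:
                self.value = value
                return
        self.value = self.schedule[-1][0]

clip_sched = SimpleValueScheduler([(0.1, 500), (0.01, 1000)])
clip_sched.step(750)
print(clip_sched.value)  # 0.01 -> the threshold passed to the clip_grad_* call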
@@ -407,6 +411,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         if shared.state.interrupted:
             break
 
+        if clip_grad_enabled:
+            clip_grad_sched.step(hypernetwork.step)
+
         with torch.autocast("cuda"):
             c = stack_conds([entry.cond for entry in entries]).to(devices.device)
             # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
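Here the clip schedule is advanced once per training iteration, before the forward pass, so the threshold applied after backward() reflects the current hypernetwork.step. A self-contained sketch of that per-step pattern; the toy model, data, and dict-based schedule are stand-ins, not webui code.

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
batches = [torch.randn(8, 4) for _ in range(5)]

clip_grad_enabled = True
schedule = {0: 0.5, 3: 0.1}   # step -> clip threshold (toy values)
clip_value = 0.5

for step, batch in enumerate(batches):
    # advance the schedule once per step, as the patch does with clip_grad_sched.step()
    clip_value = schedule.get(step, clip_value)

    loss = model(batch).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()

    if clip_grad_enabled:
        # clip with the threshold that is current for this step
        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=clip_value)

    optimizer.step()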
@@ -430,9 +437,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
             assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
 
             if clip_grad_mode_value:
-                torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_value)
+                torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_sched.learn_rate)
             elif clip_grad_mode_norm:
-                torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_value)
+                torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_sched.learn_rate)
 
             optimizer.step()
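For context on the two modes this hunk distinguishes: clip_grad_value_ clamps every gradient element to a fixed range, while clip_grad_norm_ rescales all gradients together so their combined L2 norm stays under the threshold. The only change here is that the threshold now comes from clip_grad_sched.learn_rate instead of the static clip_grad_value. A standalone sketch of the difference, with made-up tensor values rather than anything from the webui code.

import torch

w = torch.nn.Parameter(torch.zeros(3))
threshold = 1.0   # in the patch this is clip_grad_sched.learn_rate

# "value" mode: clamp each gradient element independently to [-threshold, threshold]
w.grad = torch.tensor([5.0, -2.0, 0.5])
torch.nn.utils.clip_grad_value_([w], clip_value=threshold)
print(w.grad)          # tensor([ 1.0000, -1.0000,  0.5000])

# "norm" mode: rescale the whole gradient so its L2 norm is at most threshold
w.grad = torch.tensor([5.0, -2.0, 0.5])
torch.nn.utils.clip_grad_norm_([w], max_norm=threshold)
print(w.grad.norm())   # ~1.0, direction preserved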