|
|
|
|
@ -335,6 +335,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|
|
|
|
size = len(ds.indexes)
|
|
|
|
|
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
|
|
|
|
losses = torch.zeros((size,))
|
|
|
|
|
previous_mean_losses = [0]
|
|
|
|
|
previous_mean_loss = 0
|
|
|
|
|
print("Mean loss of {} elements".format(size))
|
|
|
|
|
|
|
|
|
|
@ -356,7 +357,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|
|
|
|
for i, entries in pbar:
|
|
|
|
|
hypernetwork.step = i + ititial_step
|
|
|
|
|
if len(loss_dict) > 0:
|
|
|
|
|
previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
|
|
|
|
|
previous_mean_losses = [i[-1] for i in loss_dict.values()]
|
|
|
|
|
previous_mean_loss = mean(previous_mean_losses)
|
|
|
|
|
|
|
|
|
|
scheduler.apply(optimizer, hypernetwork.step)
|
|
|
|
|
if scheduler.finished:
|
|
|
|
|
@ -391,7 +393,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|
|
|
|
|
|
|
|
|
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
|
|
|
|
|
raise RuntimeError("Loss diverged.")
|
|
|
|
|
pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}")
|
|
|
|
|
|
|
|
|
|
if len(previous_mean_losses) > 1:
|
|
|
|
|
std = stdev(previous_mean_losses)
|
|
|
|
|
else:
|
|
|
|
|
std = 0
|
|
|
|
|
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
|
|
|
|
|
pbar.set_description(dataset_loss_info)
|
|
|
|
|
|
|
|
|
|
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
|
|
|
|
# Before saving, change name to match current checkpoint.
|
|
|
|
|
|