|
|
|
|
@ -6,6 +6,7 @@ import torch
|
|
|
|
|
import tqdm
|
|
|
|
|
import html
|
|
|
|
|
import datetime
|
|
|
|
|
import csv
|
|
|
|
|
|
|
|
|
|
from PIL import Image, PngImagePlugin
|
|
|
|
|
|
|
|
|
|
@ -172,7 +173,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
|
|
|
|
return fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
|
|
|
|
|
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, write_csv_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
|
|
|
|
|
assert embedding_name, 'embedding not selected'
|
|
|
|
|
|
|
|
|
|
shared.state.textinfo = "Initializing textual inversion training..."
|
|
|
|
|
@ -256,6 +257,20 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
|
|
|
|
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
|
|
|
|
embedding.save(last_saved_file)
|
|
|
|
|
|
|
|
|
|
if write_csv_every > 0 and log_directory is not None and embedding.step % write_csv_every == 0:
|
|
|
|
|
write_csv_header = False if os.path.exists(os.path.join(log_directory, "textual_inversion_loss.csv")) else True
|
|
|
|
|
|
|
|
|
|
with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout:
|
|
|
|
|
|
|
|
|
|
csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss"])
|
|
|
|
|
|
|
|
|
|
if write_csv_header:
|
|
|
|
|
csv_writer.writeheader()
|
|
|
|
|
|
|
|
|
|
csv_writer.writerow({"epoch": epoch_num + 1,
|
|
|
|
|
"epoch_step": epoch_step - 1,
|
|
|
|
|
"loss": f"{losses.mean():.7f}"})
|
|
|
|
|
|
|
|
|
|
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
|
|
|
|
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
|
|
|
|
|
|
|
|
|
|