@@ -6,6 +6,7 @@ import torch
 import tqdm
 import html
 import datetime
+import math
 
 from modules import shared, devices, sd_hijack, processing, sd_models
@@ -156,7 +157,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
     return fn
 
 
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file):
     assert embedding_name, 'embedding not selected'
 
     shared.state.textinfo = "Initializing textual inversion training..."
@@ -182,7 +183,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=training_size, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
 
     hijack = sd_hijack.model_hijack
@@ -200,6 +201,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
     if ititial_step > steps:
         return embedding, filename
 
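+    # one epoch covers every training image num_repeats + 1 times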
+    tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
+    epoch_len = (tr_img_len * num_repeats) + tr_img_len
+
     pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
     for i, (x, text) in pbar:
         embedding.step = i + ititial_step
@@ -223,7 +227,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
         loss.backward()
         optimizer.step()
 
-        pbar.set_description(f"loss: {losses.mean():.7f}")
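+        # convert the global step into an epoch number and a step within the current epoch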
+        epoch_num = math.floor(embedding.step / epoch_len)
+        epoch_step = embedding.step - (epoch_num * epoch_len)
+
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
|
|
|
|
|
|
|
|
|
|
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
|
|
|
|
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
|
|
|
|
@@ -236,6 +243,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
                     sd_model=shared.sd_model,
                     prompt=text,
                     steps=20,
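+                    # render the preview image at the user-selected training resolution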
+                    height=training_size,
+                    width=training_size,
                     do_not_save_grid=True,
                     do_not_save_samples=True,
                 )