|
|
|
|
@ -7,6 +7,9 @@ import tqdm
|
|
|
|
|
import html
|
|
|
|
|
import datetime
|
|
|
|
|
|
|
|
|
|
from PIL import Image, PngImagePlugin
|
|
|
|
|
import base64
|
|
|
|
|
from io import BytesIO
|
|
|
|
|
|
|
|
|
|
from modules import shared, devices, sd_hijack, processing, sd_models
|
|
|
|
|
import modules.textual_inversion.dataset
|
|
|
|
|
@ -80,7 +83,15 @@ class EmbeddingDatabase:
|
|
|
|
|
def process_file(path, filename):
|
|
|
|
|
name = os.path.splitext(filename)[0]
|
|
|
|
|
|
|
|
|
|
data = torch.load(path, map_location="cpu")
|
|
|
|
|
data = []
|
|
|
|
|
|
|
|
|
|
if filename.upper().endswith('.PNG'):
|
|
|
|
|
embed_image = Image.open(path)
|
|
|
|
|
if 'sd-embedding' in embed_image.text:
|
|
|
|
|
embeddingData = base64.b64decode(embed_image.text['sd-embedding'])
|
|
|
|
|
data = torch.load(BytesIO(embeddingData), map_location="cpu")
|
|
|
|
|
else:
|
|
|
|
|
data = torch.load(path, map_location="cpu")
|
|
|
|
|
|
|
|
|
|
# textual inversion embeddings
|
|
|
|
|
if 'string_to_param' in data:
|
|
|
|
|
@ -156,7 +167,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
|
|
|
|
return fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
|
|
|
|
|
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding):
|
|
|
|
|
assert embedding_name, 'embedding not selected'
|
|
|
|
|
|
|
|
|
|
shared.state.textinfo = "Initializing textual inversion training..."
|
|
|
|
|
@ -244,7 +255,15 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
|
|
|
|
image = processed.images[0]
|
|
|
|
|
|
|
|
|
|
shared.state.current_image = image
|
|
|
|
|
image.save(last_saved_image)
|
|
|
|
|
|
|
|
|
|
if save_image_with_stored_embedding:
|
|
|
|
|
info = PngImagePlugin.PngInfo()
|
|
|
|
|
info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read()))
|
|
|
|
|
image.save(last_saved_image, "PNG", pnginfo=info)
|
|
|
|
|
else:
|
|
|
|
|
image.save(last_saved_image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
last_saved_image += f", prompt: {text}"
|
|
|
|
|
|
|
|
|
|
|