|
|
|
|
@ -433,15 +433,15 @@ if os.path.exists(cmd_opts.gfpgan_dir):
|
|
|
|
|
print(traceback.format_exc(), file=sys.stderr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TextInversionEmbeddings:
|
|
|
|
|
class StableDiffuionModelHijack:
|
|
|
|
|
ids_lookup = {}
|
|
|
|
|
word_embeddings = {}
|
|
|
|
|
word_embeddings_checksums = {}
|
|
|
|
|
fixes = []
|
|
|
|
|
fixes = None
|
|
|
|
|
used_custom_terms = []
|
|
|
|
|
dir_mtime = None
|
|
|
|
|
|
|
|
|
|
def load(self, dir, model):
|
|
|
|
|
def load_textual_inversion_embeddings(self, dir, model):
|
|
|
|
|
mt = os.path.getmtime(dir)
|
|
|
|
|
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
|
|
|
|
return
|
|
|
|
|
@ -469,6 +469,7 @@ class TextInversionEmbeddings:
|
|
|
|
|
self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
|
|
|
|
|
|
|
|
|
|
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
|
|
|
|
|
|
|
|
|
|
first_id = ids[0]
|
|
|
|
|
if first_id not in self.ids_lookup:
|
|
|
|
|
self.ids_lookup[first_id] = []
|
|
|
|
|
@ -497,6 +498,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|
|
|
|
self.embeddings = embeddings
|
|
|
|
|
self.tokenizer = wrapped.tokenizer
|
|
|
|
|
self.max_length = wrapped.max_length
|
|
|
|
|
self.token_mults = {}
|
|
|
|
|
|
|
|
|
|
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
|
|
|
|
for text, ident in tokens_with_parens:
|
|
|
|
|
mult = 1.0
|
|
|
|
|
for c in text:
|
|
|
|
|
if c == '[':
|
|
|
|
|
mult /= 1.1
|
|
|
|
|
if c == ']':
|
|
|
|
|
mult *= 1.1
|
|
|
|
|
if c == '(':
|
|
|
|
|
mult *= 1.1
|
|
|
|
|
if c == ')':
|
|
|
|
|
mult /= 1.1
|
|
|
|
|
|
|
|
|
|
if mult != 1.0:
|
|
|
|
|
self.token_mults[ident] = mult
|
|
|
|
|
|
|
|
|
|
def forward(self, text):
|
|
|
|
|
self.embeddings.fixes = []
|
|
|
|
|
@ -508,14 +526,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
cache = {}
|
|
|
|
|
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
|
|
|
|
batch_multipliers = []
|
|
|
|
|
for tokens in batch_tokens:
|
|
|
|
|
tuple_tokens = tuple(tokens)
|
|
|
|
|
|
|
|
|
|
if tuple_tokens in cache:
|
|
|
|
|
remade_tokens, fixes = cache[tuple_tokens]
|
|
|
|
|
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
|
|
|
|
else:
|
|
|
|
|
fixes = []
|
|
|
|
|
remade_tokens = []
|
|
|
|
|
multipliers = []
|
|
|
|
|
mult = 1.0
|
|
|
|
|
|
|
|
|
|
i = 0
|
|
|
|
|
while i < len(tokens):
|
|
|
|
|
@ -523,14 +544,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
possible_matches = self.embeddings.ids_lookup.get(token, None)
|
|
|
|
|
|
|
|
|
|
if possible_matches is None:
|
|
|
|
|
mult_change = self.token_mults.get(token)
|
|
|
|
|
if mult_change is not None:
|
|
|
|
|
mult *= mult_change
|
|
|
|
|
elif possible_matches is None:
|
|
|
|
|
remade_tokens.append(token)
|
|
|
|
|
multipliers.append(mult)
|
|
|
|
|
else:
|
|
|
|
|
found = False
|
|
|
|
|
for ids, word in possible_matches:
|
|
|
|
|
if tokens[i:i+len(ids)] == ids:
|
|
|
|
|
fixes.append((len(remade_tokens), word))
|
|
|
|
|
remade_tokens.append(777)
|
|
|
|
|
multipliers.append(mult)
|
|
|
|
|
i += len(ids) - 1
|
|
|
|
|
found = True
|
|
|
|
|
self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
|
|
|
|
|
@ -538,19 +564,32 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
if not found:
|
|
|
|
|
remade_tokens.append(token)
|
|
|
|
|
multipliers.append(mult)
|
|
|
|
|
|
|
|
|
|
i += 1
|
|
|
|
|
|
|
|
|
|
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
|
|
|
|
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
|
|
|
|
cache[tuple_tokens] = (remade_tokens, fixes)
|
|
|
|
|
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
|
|
|
|
|
|
|
|
|
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
|
|
|
|
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
|
|
|
|
|
|
|
|
|
remade_batch_tokens.append(remade_tokens)
|
|
|
|
|
self.embeddings.fixes.append(fixes)
|
|
|
|
|
batch_multipliers.append(multipliers)
|
|
|
|
|
|
|
|
|
|
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
|
|
|
|
|
outputs = self.wrapped.transformer(input_ids=tokens)
|
|
|
|
|
z = outputs.last_hidden_state
|
|
|
|
|
|
|
|
|
|
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
|
|
|
|
batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
|
|
|
|
|
original_mean = z.mean()
|
|
|
|
|
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
|
|
|
|
new_mean = z.mean()
|
|
|
|
|
z *= original_mean / new_mean
|
|
|
|
|
|
|
|
|
|
return z
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -562,22 +601,17 @@ class EmbeddingsWithFixes(nn.Module):
|
|
|
|
|
|
|
|
|
|
def forward(self, input_ids):
|
|
|
|
|
batch_fixes = self.embeddings.fixes
|
|
|
|
|
self.embeddings.fixes = []
|
|
|
|
|
self.embeddings.fixes = None
|
|
|
|
|
|
|
|
|
|
inputs_embeds = self.wrapped(input_ids)
|
|
|
|
|
|
|
|
|
|
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
|
|
|
|
for offset, word in fixes:
|
|
|
|
|
tensor[offset] = self.embeddings.word_embeddings[word]
|
|
|
|
|
|
|
|
|
|
return inputs_embeds
|
|
|
|
|
if batch_fixes is not None:
|
|
|
|
|
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
|
|
|
|
for offset, word in fixes:
|
|
|
|
|
tensor[offset] = self.embeddings.word_embeddings[word]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_learned_conditioning_with_embeddings(model, prompts):
|
|
|
|
|
if os.path.exists(cmd_opts.embeddings_dir):
|
|
|
|
|
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
|
|
|
|
|
|
|
|
|
|
return model.get_learned_conditioning(prompts)
|
|
|
|
|
return inputs_embeds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
|
|
|
|
|
@ -648,7 +682,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
|
|
|
|
|
return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
|
|
|
|
|
|
|
|
|
|
if os.path.exists(cmd_opts.embeddings_dir):
|
|
|
|
|
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
|
|
|
|
|
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
|
|
|
|
|
|
|
|
|
|
output_images = []
|
|
|
|
|
with torch.no_grad(), autocast("cuda"), model.ema_scope():
|
|
|
|
|
@ -661,8 +695,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
|
|
|
|
|
uc = model.get_learned_conditioning(len(prompts) * [""])
|
|
|
|
|
c = model.get_learned_conditioning(prompts)
|
|
|
|
|
|
|
|
|
|
if len(text_inversion_embeddings.used_custom_terms) > 0:
|
|
|
|
|
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in text_inversion_embeddings.used_custom_terms]))
|
|
|
|
|
if len(model_hijack.used_custom_terms) > 0:
|
|
|
|
|
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms]))
|
|
|
|
|
|
|
|
|
|
# we manually generate all input noises because each one should have a specific seed
|
|
|
|
|
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
|
|
|
|
|
@ -1060,10 +1094,9 @@ model = load_model_from_config(config, cmd_opts.ckpt)
|
|
|
|
|
|
|
|
|
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
model = (model if cmd_opts.no_half else model.half()).to(device)
|
|
|
|
|
text_inversion_embeddings = TextInversionEmbeddings()
|
|
|
|
|
|
|
|
|
|
if os.path.exists(cmd_opts.embeddings_dir):
|
|
|
|
|
text_inversion_embeddings.hijack(model)
|
|
|
|
|
model_hijack = StableDiffuionModelHijack()
|
|
|
|
|
model_hijack.hijack(model)
|
|
|
|
|
|
|
|
|
|
demo = gr.TabbedInterface(
|
|
|
|
|
interface_list=[x[0] for x in interfaces],
|
|
|
|
|
|