@@ -180,6 +180,7 @@ class StableDiffusionModelHijack:
     dir_mtime = None
     layers = None
     circular_enabled = False
+    clip = None
 
     def load_textual_inversion_embeddings(self, dirname, model):
         mt = os.path.getmtime(dirname)
@@ -242,6 +243,7 @@ class StableDiffusionModelHijack:
 
         model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
         m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+        self.clip = m.cond_stage_model
 
         if cmd_opts.opt_split_attention_v1:
             ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
@@ -268,6 +270,10 @@ class StableDiffusionModelHijack:
         for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
             layer.padding_mode = 'circular' if enable else 'zeros'
 
+    def tokenize(self, text):
+        max_length = self.clip.max_length - 2
+        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
+        return remade_batch_tokens[0], token_count, max_length
 
 class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def __init__(self, wrapped, hijack):
@@ -294,14 +300,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             if mult != 1.0:
                 self.token_mults[ident] = mult
 
-    def forward(self, text):
-        self.hijack.fixes = []
-        self.hijack.comments = []
-        remade_batch_tokens = []
+    def process_text(self, text):
         id_start = self.wrapped.tokenizer.bos_token_id
         id_end = self.wrapped.tokenizer.eos_token_id
         maxlen = self.wrapped.max_length
         used_custom_terms = []
+        remade_batch_tokens = []
+        overflowing_words = []
+        hijack_comments = []
+        hijack_fixes = []
+        token_count = 0
 
         cache = {}
         batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
@@ -353,9 +361,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                     ovf = remade_tokens[maxlen - 2:]
                     overflowing_words = [vocab.get(int(x), "") for x in ovf]
                     overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
-
-                    self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
+                    hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+                token_count = len(remade_tokens)
                 remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
                 remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
                 cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
@@ -364,8 +371,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
 
             remade_batch_tokens.append(remade_tokens)
-            self.hijack.fixes.append(fixes)
+            hijack_fixes.append(fixes)
             batch_multipliers.append(multipliers)
+        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+    def forward(self, text):
+        batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+        self.hijack.fixes = hijack_fixes
+        self.hijack.comments = hijack_comments
 
         if len(used_custom_terms) > 0:
             self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
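
Note: the net effect of this patch is that tokenization is factored out of
forward() into process_text(), so callers can count prompt tokens without
running the embedding pass. A minimal sketch of a consumer, assuming only the
tokenize() API added above (the helper function itself is hypothetical and not
part of the patch):

    def update_token_counter(hijack, prompt):
        # `hijack` is assumed to be the StableDiffusionModelHijack instance
        # whose `clip` attribute was set in hijack_model (see this diff).
        # tokenize() returns (tokens, token_count, max_length); token_count
        # may exceed max_length, in which case process_text() records a
        # truncation warning in hijack_comments.
        tokens, token_count, max_length = hijack.tokenize(prompt)
        return f"{token_count}/{max_length}"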