@@ -18,7 +18,6 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
 diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
 diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
 
-
 def apply_optimizations():
     undo_optimizations()
 
@@ -83,7 +82,7 @@ class StableDiffusionModelHijack:
             layer.padding_mode = 'circular' if enable else 'zeros'
 
     def tokenize(self, text):
-        max_length = self.clip.max_length - 2
+        max_length = opts.max_prompt_tokens - 2
         _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
 
         return remade_batch_tokens[0], token_count, max_length
@@ -94,7 +93,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         self.wrapped = wrapped
         self.hijack: StableDiffusionModelHijack = hijack
         self.tokenizer = wrapped.tokenizer
-        self.max_length = wrapped.max_length
         self.token_mults = {}
 
         tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
@@ -116,7 +114,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def tokenize_line(self, line, used_custom_terms, hijack_comments):
         id_start = self.wrapped.tokenizer.bos_token_id
         id_end = self.wrapped.tokenizer.eos_token_id
-        maxlen = self.wrapped.max_length
+        maxlen = opts.max_prompt_tokens
 
         if opts.enable_emphasis:
             parsed = prompt_parser.parse_prompt_attention(line)
@@ -191,7 +189,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def process_text_old(self, text):
         id_start = self.wrapped.tokenizer.bos_token_id
         id_end = self.wrapped.tokenizer.eos_token_id
-        maxlen = self.wrapped.max_length
+        maxlen = self.wrapped.max_length  # you get to stay at 77
         used_custom_terms = []
         remade_batch_tokens = []
         overflowing_words = []
@@ -268,8 +266,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         if len(used_custom_terms) > 0:
             self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
 
+        position_ids_array = [min(x, 75) for x in range(len(remade_batch_tokens[0]) - 1)] + [76]
+        position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1))
+
         tokens = torch.asarray(remade_batch_tokens).to(device)
-        outputs = self.wrapped.transformer(input_ids=tokens)
+        outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids)
         z = outputs.last_hidden_state
 
         # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
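
A minimal standalone sketch (not part of the patch above) of what the new position_ids do: CLIP's text encoder has only 77 learned position embeddings, so every position past index 74 is clamped to 75 and the final end-of-text token is pinned to 76, which keeps the indices valid even when remade_batch_tokens holds more than 77 tokens. The helper name clamped_position_ids below is hypothetical.

import torch


def clamped_position_ids(token_count):
    # Positions 0..74 pass through, later positions are clamped to 75,
    # and the final (end-of-text) token is pinned to position 76.
    ids = [min(x, 75) for x in range(token_count - 1)] + [76]
    return torch.asarray(ids).expand((1, -1))


print(clamped_position_ids(77).shape)     # torch.Size([1, 77])
print(clamped_position_ids(154)[0, -4:])  # tensor([75, 75, 75, 76])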