@@ -23,6 +23,8 @@ class Embedding:
         self.vec = vec
         self.name = name
         self.step = step
+        self.shape = None
+        self.vectors = 0
         self.cached_checksum = None
         self.sd_checkpoint = None
         self.sd_checkpoint_name = None
@@ -57,8 +59,10 @@ class EmbeddingDatabase:
     def __init__(self, embeddings_dir):
         self.ids_lookup = {}
         self.word_embeddings = {}
+        self.skipped_embeddings = []
         self.dir_mtime = None
         self.embeddings_dir = embeddings_dir
+        self.expected_shape = -1

     def register_embedding(self, embedding, model):
@@ -75,14 +79,35 @@ class EmbeddingDatabase:

         return embedding

-    def load_textual_inversion_embeddings(self):
+    def get_expected_shape(self):
+        expected_shape = -1  # initialize with unknown
+        idx = torch.tensor(0).to(shared.device)
+        if expected_shape == -1:
+            try:  # matches sd15 signature
+                first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx)
+                expected_shape = first_embedding.shape[0]
+            except Exception:
+                pass
+        if expected_shape == -1:
+            try:  # matches sd20 signature
+                first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx)
+                expected_shape = first_embedding.shape[0]
+            except Exception:
+                pass
+        if expected_shape == -1:
+            print('Could not determine expected embeddings shape from model')
+        return expected_shape
+
+    def load_textual_inversion_embeddings(self, force_reload=False):
         mt = os.path.getmtime(self.embeddings_dir)
-        if self.dir_mtime is not None and mt <= self.dir_mtime:
+        if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
            return

         self.dir_mtime = mt
         self.ids_lookup.clear()
         self.word_embeddings.clear()
+        self.skipped_embeddings = []
+        self.expected_shape = self.get_expected_shape()

         def process_file(path, filename):
             name = os.path.splitext(filename)[0]
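For context: `get_expected_shape()` probes the loaded model by embedding a single dummy token id and reading the length of the resulting vector, which is 768 for SD 1.x's CLIP ViT-L/14 text encoder and 1024 for SD 2.x's OpenCLIP encoder. Here is a minimal standalone sketch of the same probe, using a plain `nn.Embedding` in place of the model's wrapped `token_embedding`; the helper name `probe_expected_shape` and the stand-in layers are ours, not part of the patch:

```python
import torch

def probe_expected_shape(token_embedding: torch.nn.Embedding) -> int:
    # Embed one dummy token id and return the vector length,
    # mirroring what get_expected_shape() reads off the real model.
    idx = torch.tensor(0)
    first_embedding = token_embedding(idx)  # shape: (dim,)
    return first_embedding.shape[0]

sd15_like = torch.nn.Embedding(49408, 768)   # CLIP ViT-L/14 text encoder width
sd20_like = torch.nn.Embedding(49408, 1024)  # OpenCLIP ViT-H/14 text encoder width
print(probe_expected_shape(sd15_like))  # 768
print(probe_expected_shape(sd20_like))  # 1024
```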
@@ -122,7 +147,14 @@ class EmbeddingDatabase:
             embedding.step = data.get('step', None)
             embedding.sd_checkpoint = data.get('sd_checkpoint', None)
             embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
-            self.register_embedding(embedding, shared.sd_model)
+            embedding.vectors = vec.shape[0]
+            embedding.shape = vec.shape[-1]
+
+            if (self.expected_shape == -1) or (self.expected_shape == embedding.shape):
+                self.register_embedding(embedding, shared.sd_model)
+            else:
+                self.skipped_embeddings.append(name)
+                # print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name=name, shape=embedding.shape, expected=self.expected_shape))

         for fn in os.listdir(self.embeddings_dir):
             try:
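The filter above relies on each loaded embedding tensor having shape `(vectors, dim)`: `dim` must match the model's expected shape, or the embedding is recorded as skipped instead of registered. A self-contained sketch of that check, with made-up embedding names and zero tensors standing in for loaded files:

```python
import torch

expected_shape = 768  # what get_expected_shape() would return for an SD 1.x model
registered, skipped = [], []

# Fake "loaded files": tensors of shape (vectors, dim); the names are invented.
loaded = {
    "style-sd1": torch.zeros(4, 768),   # four 768-dim vectors: compatible
    "style-sd2": torch.zeros(2, 1024),  # 1024-dim vectors: wrong model family
}

for name, vec in loaded.items():
    vectors, shape = vec.shape[0], vec.shape[-1]
    if expected_shape == -1 or expected_shape == shape:
        registered.append(name)
    else:
        skipped.append(name)

print(registered)  # ['style-sd1']
print(skipped)     # ['style-sd2']
```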
@@ -137,8 +169,9 @@ class EmbeddingDatabase:
                 print(traceback.format_exc(), file=sys.stderr)
                 continue

-        print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
-        print("Embeddings:", ', '.join(self.word_embeddings.keys()))
+        print("Textual inversion embeddings {num} loaded: {val}".format(num=len(self.word_embeddings), val=', '.join(self.word_embeddings.keys())))
+        if len(self.skipped_embeddings) > 0:
+            print("Textual inversion embeddings {num} skipped: {val}".format(num=len(self.skipped_embeddings), val=', '.join(self.skipped_embeddings)))

     def find_embedding_at_position(self, tokens, offset):
         token = tokens[offset]
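The new `force_reload` parameter lets a caller bypass the directory-mtime cache. A hypothetical call site (`db` stands for the `EmbeddingDatabase` instance; it is not named in the patch):

```python
# After a checkpoint switch the expected shape can change even though the
# embeddings directory's mtime has not, so skip the mtime check and reload.
db.load_textual_inversion_embeddings(force_reload=True)
```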