@@ -14,17 +14,56 @@ import ldm.modules.midas as midas
 from ldm.util import instantiate_from_config
 
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
 from modules.paths import models_path
 from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
 
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
 checkpoints_list = {}
+checkpoint_alisases = {}
 checkpoints_loaded = collections.OrderedDict()
 
+
+class CheckpointInfo:
+    def __init__(self, filename):
+        self.filename = filename
+        abspath = os.path.abspath(filename)
+
+        if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
+            name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
+        elif abspath.startswith(model_path):
+            name = abspath.replace(model_path, '')
+        else:
+            name = os.path.basename(filename)
+
+        if name.startswith("\\") or name.startswith("/"):
+            name = name[1:]
+
+        self.title = name
+        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
+        self.hash = model_hash(filename)
+        self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]']
+        self.shorthash = None
+        self.sha256 = None
+
+    def register(self):
+        checkpoints_list[self.title] = self
+        for id in self.ids:
+            checkpoint_alisases[id] = self
+
+    def calculate_shorthash(self):
+        self.sha256 = hashes.sha256(self.filename, self.title)
+        self.shorthash = self.sha256[0:10]
+
+        if self.shorthash not in self.ids:
+            self.ids += [self.shorthash, self.sha256]
+            self.register()
+
+        return self.shorthash
+
+
 try:
     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
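# --- Illustration (not part of the patch) ---
# A minimal sketch of how the new CheckpointInfo registry is meant to be used,
# assuming this file is modules/sd_models.py, that the webui's modules are
# importable, and that a checkpoint exists at the hypothetical path below.
# register() indexes the object under every entry in self.ids, so the title,
# the model name, the legacy hash and "title [hash]" all resolve to the same
# CheckpointInfo; calculate_shorthash() later adds the sha256-based ids too.
from modules.sd_models import CheckpointInfo, checkpoint_alisases

info = CheckpointInfo("models/Stable-diffusion/v1-5-pruned.ckpt")  # hypothetical path
info.register()

assert checkpoint_alisases[info.title] is info
assert checkpoint_alisases[info.model_name] is info
assert checkpoint_alisases[info.hash] is info  # legacy 8-character hash

# computed lazily (typically on first weight load); also registers the
# 10-character sha256 prefix and the full sha256 as additional aliases
shorthash = info.calculate_shorthash()
# --- end illustration ---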
@@ -44,9 +83,13 @@ def setup_model():
 def checkpoint_tiles():
-    convert = lambda name: int(name) if name.isdigit() else name.lower()
-    alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
-    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+    def convert(name):
+        return int(name) if name.isdigit() else name.lower()
+
+    def alphanumeric_key(key):
+        return [convert(c) for c in re.split('([0-9]+)', key)]
+
+    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
 
 
 def find_checkpoint_config(info):
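# --- Illustration (not part of the patch) ---
# The lambdas above are only rewritten as named functions; behaviour is
# unchanged. alphanumeric_key() natural-sorts checkpoint titles, so numeric
# runs compare as numbers rather than character by character:
import re

def convert(name):
    return int(name) if name.isdigit() else name.lower()

def alphanumeric_key(key):
    return [convert(c) for c in re.split('([0-9]+)', key)]

alphanumeric_key("model-10.ckpt")  # -> ['model-', 10, '.ckpt']
sorted(["model-10.ckpt", "model-2.ckpt"], key=alphanumeric_key)
# -> ['model-2.ckpt', 'model-10.ckpt']  (a plain string sort would put 10 first)
# --- end illustration ---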
@@ -62,48 +105,38 @@ def find_checkpoint_config(info):
 def list_models():
     checkpoints_list.clear()
+    checkpoint_alisases.clear()
     model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
 
-    def modeltitle(path, shorthash):
-        abspath = os.path.abspath(path)
-
-        if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
-            name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
-        elif abspath.startswith(model_path):
-            name = abspath.replace(model_path, '')
-        else:
-            name = os.path.basename(path)
-
-        if name.startswith("\\") or name.startswith("/"):
-            name = name[1:]
-
-        shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
-
-        return f'{name} [{shorthash}]', shortname
-
     cmd_ckpt = shared.cmd_opts.ckpt
     if os.path.exists(cmd_ckpt):
-        h = model_hash(cmd_ckpt)
-        title, short_model_name = modeltitle(cmd_ckpt, h)
+        checkpoint_info = CheckpointInfo(cmd_ckpt)
+        checkpoint_info.register()
 
-        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
-        shared.opts.data['sd_model_checkpoint'] = title
+        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
     elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
         print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
+
     for filename in model_list:
-        h = model_hash(filename)
-        title, short_model_name = modeltitle(filename, h)
-
-        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
+        checkpoint_info = CheckpointInfo(filename)
+        checkpoint_info.register()
-def get_closet_checkpoint_match(searchString):
-    applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key=lambda x: len(x.title))
-    if len(applicable) > 0:
-        return applicable[0]
+def get_closet_checkpoint_match(search_string):
+    checkpoint_info = checkpoint_alisases.get(search_string, None)
+    if checkpoint_info is not None:
+        return checkpoint_info
+
+    found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
+    if found:
+        return found[0]
+
     return None
 def model_hash(filename):
+    """old hash that only looks at a small part of the file and is prone to collisions"""
+
     try:
         with open(filename, "rb") as file:
             import hashlib
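# --- Illustration (not part of the patch) ---
# A hedged sketch of the reworked lookup order in get_closet_checkpoint_match(),
# assuming list_models() has already populated the registry and that a
# checkpoint titled "v1-5-pruned.ckpt" exists (the name is hypothetical).
# An exact alias hit (title, model name, legacy hash, or "title [hash]") is
# returned immediately; otherwise the shortest title containing the search
# string wins; if nothing matches, the function returns None.
from modules.sd_models import list_models, get_closet_checkpoint_match

list_models()
get_closet_checkpoint_match("v1-5-pruned.ckpt")  # exact title -> its CheckpointInfo
get_closet_checkpoint_match("v1-5")              # substring -> shortest matching title
get_closet_checkpoint_match("no-such-model")     # -> None
# --- end illustration ---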
@@ -119,7 +152,7 @@ def model_hash(filename):
 def select_checkpoint():
     model_checkpoint = shared.opts.sd_model_checkpoint
-    checkpoint_info = checkpoints_list.get(model_checkpoint, None)
+    checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
     if checkpoint_info is not None:
         return checkpoint_info
@@ -189,9 +222,8 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
     return sd
 
 
-def load_model_weights(model, checkpoint_info, vae_file="auto"):
-    checkpoint_file = checkpoint_info.filename
-    sd_model_hash = checkpoint_info.hash
+def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"):
+    sd_model_hash = checkpoint_info.calculate_shorthash()
 
     cache_enabled = shared.opts.sd_checkpoint_cache > 0
@@ -201,9 +233,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         model.load_state_dict(checkpoints_loaded[checkpoint_info])
     else:
         # load from file
-        print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
+        print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
 
-        sd = read_state_dict(checkpoint_file)
+        sd = read_state_dict(checkpoint_info.filename)
         model.load_state_dict(sd, strict=False)
         del sd
@@ -235,14 +267,14 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
             checkpoints_loaded.popitem(last=False)  # LRU
 
     model.sd_model_hash = sd_model_hash
-    model.sd_model_checkpoint = checkpoint_file
+    model.sd_model_checkpoint = checkpoint_info.filename
     model.sd_checkpoint_info = checkpoint_info
 
     model.logvar = model.logvar.to(devices.device)  # fix for training
 
     sd_vae.delete_base_vae()
     sd_vae.clear_loaded_vae()
-    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+    vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file)
     sd_vae.load_vae(model, vae_file)
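# --- Illustration (not part of the patch) ---
# A hedged sketch of what load_model_weights() now records on the model:
# sd_model_hash becomes the 10-character sha256 prefix returned by
# checkpoint_info.calculate_shorthash() (via the new hashes module) instead of
# the old collision-prone model_hash(), and the checkpoint path is read from
# checkpoint_info.filename rather than a local checkpoint_file variable.
# Assumes the webui's modules are importable and at least one checkpoint is
# installed; in normal use the webui calls this for you when loading a model.
from modules import sd_models

checkpoint_info = sd_models.select_checkpoint()  # resolved through checkpoint_alisases
print(checkpoint_info.filename)                  # full path to the .ckpt / .safetensors file
print(checkpoint_info.calculate_shorthash())     # 10-character sha256 prefix, hashed on demand
print(checkpoint_info.sha256)                    # full sha256, populated by the call above
# --- end illustration ---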