@ -9,23 +9,9 @@ import glob
from copy import deepcopy
from copy import deepcopy
model_dir = " Stable-diffusion "
vae_path = os . path . abspath ( os . path . join ( models_path , " VAE " ) )
model_path = os . path . abspath ( os . path . join ( models_path , model_dir ) )
vae_dir = " VAE "
vae_path = os . path . abspath ( os . path . join ( models_path , vae_dir ) )
vae_ignore_keys = { " model_ema.decay " , " model_ema.num_updates " }
vae_ignore_keys = { " model_ema.decay " , " model_ema.num_updates " }
vae_dict = { }
default_vae_dict = { " auto " : " auto " , " None " : None , None : None }
default_vae_list = [ " auto " , " None " ]
default_vae_values = [ default_vae_dict [ x ] for x in default_vae_list ]
vae_dict = dict ( default_vae_dict )
vae_list = list ( default_vae_list )
first_load = True
base_vae = None
base_vae = None
@ -64,100 +50,69 @@ def restore_base_vae(model):
def get_filename ( filepath ) :
def get_filename ( filepath ) :
return os . path . splitext( os . path . basename( filepath ) ) [ 0 ]
return os . path . basename( filepath )
def refresh_vae_list ( vae_path = vae_path , model_path = model_path ) :
def refresh_vae_list ( ) :
global vae_dict , vae_list
vae_dict . clear ( )
res = { }
candidate s = [
path s = [
* glob . iglob ( os . path . join ( model_path, ' **/*.vae.ckpt ' ) , recursive = True ) ,
os . path . join ( sd_models. model_path, ' **/*.vae.ckpt ' ) ,
* glob . iglob ( os . path . join ( model_path, ' **/*.vae.pt ' ) , recursive = True ) ,
os . path . join ( sd_models. model_path, ' **/*.vae.pt ' ) ,
* glob . iglob ( os . path . join ( model_path, ' **/*.vae.safetensors ' ) , recursive = True ) ,
os . path . join ( sd_models. model_path, ' **/*.vae.safetensors ' ) ,
* glob . iglob ( os . path . join ( vae_path , ' **/*.ckpt ' ) , recursive = True ) ,
os . path . join ( vae_path , ' **/*.ckpt ' ) ,
* glob . iglob ( os . path . join ( vae_path , ' **/*.pt ' ) , recursive = True ) ,
os . path . join ( vae_path , ' **/*.pt ' ) ,
* glob . iglob ( os . path . join ( vae_path , ' **/*.safetensors ' ) , recursive = True ) ,
os . path . join ( vae_path , ' **/*.safetensors ' ) ,
]
]
if shared . cmd_opts . vae_path is not None and os . path . isfile ( shared . cmd_opts . vae_path ) :
candidates . append ( shared . cmd_opts . vae_path )
if shared . cmd_opts . ckpt_dir is not None and os . path . isdir ( shared . cmd_opts . ckpt_dir ) :
paths + = [
os . path . join ( shared . cmd_opts . ckpt_dir , ' **/*.vae.ckpt ' ) ,
os . path . join ( shared . cmd_opts . ckpt_dir , ' **/*.vae.pt ' ) ,
os . path . join ( shared . cmd_opts . ckpt_dir , ' **/*.vae.safetensors ' ) ,
]
candidates = [ ]
for path in paths :
candidates + = glob . iglob ( path , recursive = True )
for filepath in candidates :
for filepath in candidates :
name = get_filename ( filepath )
name = get_filename ( filepath )
res [ name ] = filepath
vae_dict [ name ] = filepath
vae_list . clear ( )
vae_list . extend ( default_vae_list )
vae_list . extend ( list ( res . keys ( ) ) )
def find_vae_near_checkpoint ( checkpoint_file ) :
vae_dict . clear ( )
checkpoint_path = os . path . splitext ( checkpoint_file ) [ 0 ]
vae_dict . update ( res )
for vae_location in [ checkpoint_path + " .vae.pt " , checkpoint_path + " .vae.ckpt " , checkpoint_path + " .vae.safetensors " ] :
vae_dict . update ( default_vae_dict )
if os . path . isfile ( vae_location ) :
return vae_list
return vae_location
return None
def get_vae_from_settings ( vae_file = " auto " ) :
# else, we load from settings, if not set to be default
if vae_file == " auto " and shared . opts . sd_vae is not None :
def resolve_vae ( checkpoint_file ) :
# if saved VAE settings isn't recognized, fallback to auto
if shared . cmd_opts . vae_path is not None :
vae_file = vae_dict . get ( shared . opts . sd_vae , " auto " )
return shared . cmd_opts . vae_path , ' from commandline argument '
# if VAE selected but not found, fallback to auto
if vae_file not in default_vae_values and not os . path . isfile ( vae_file ) :
vae_near_checkpoint = find_vae_near_checkpoint ( checkpoint_file )
vae_file = " auto "
if vae_near_checkpoint is not None and ( shared . opts . sd_vae_as_default or shared . opts . sd_vae == " auto " ) :
print ( f " Selected VAE doesn ' t exist: { vae_file } " )
return vae_near_checkpoint , ' found near the checkpoint '
return vae_file
if shared . opts . sd_vae == " None " :
return None , None
def resolve_vae ( checkpoint_file = None , vae_file = " auto " ) :
global first_load , vae_dict , vae_list
vae_from_options = vae_dict . get ( shared . opts . sd_vae , None )
if vae_from_options is not None :
# if vae_file argument is provided, it takes priority, but not saved
return vae_from_options , ' specified in settings '
if vae_file and vae_file not in default_vae_list :
if not os . path . isfile ( vae_file ) :
if shared . opts . sd_vae != " Automatic " :
print ( f " VAE provided as function argument doesn ' t exist: { vae_file } " )
print ( f " Couldn ' t find VAE named { shared . opts . sd_vae } ; using None instead " )
vae_file = " auto "
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
return None , None
if first_load and shared . cmd_opts . vae_path is not None :
if os . path . isfile ( shared . cmd_opts . vae_path ) :
vae_file = shared . cmd_opts . vae_path
def load_vae ( model , vae_file = None , vae_source = " from unknown source " ) :
shared . opts . data [ ' sd_vae ' ] = get_filename ( vae_file )
global vae_dict , loaded_vae_file
else :
print ( f " VAE provided as command line argument doesn ' t exist: { vae_file } " )
# fallback to selector in settings, if vae selector not set to act as default fallback
if not shared . opts . sd_vae_as_default :
vae_file = get_vae_from_settings ( vae_file )
# vae-path cmd arg takes priority for auto
if vae_file == " auto " and shared . cmd_opts . vae_path is not None :
if os . path . isfile ( shared . cmd_opts . vae_path ) :
vae_file = shared . cmd_opts . vae_path
print ( f " Using VAE provided as command line argument: { vae_file } " )
# if still not found, try look for ".vae.pt" beside model
model_path = os . path . splitext ( checkpoint_file ) [ 0 ]
if vae_file == " auto " :
vae_file_try = model_path + " .vae.pt "
if os . path . isfile ( vae_file_try ) :
vae_file = vae_file_try
print ( f " Using VAE found similar to selected model: { vae_file } " )
# if still not found, try look for ".vae.ckpt" beside model
if vae_file == " auto " :
vae_file_try = model_path + " .vae.ckpt "
if os . path . isfile ( vae_file_try ) :
vae_file = vae_file_try
print ( f " Using VAE found similar to selected model: { vae_file } " )
# if still not found, try look for ".vae.safetensors" beside model
if vae_file == " auto " :
vae_file_try = model_path + " .vae.safetensors "
if os . path . isfile ( vae_file_try ) :
vae_file = vae_file_try
print ( f " Using VAE found similar to selected model: { vae_file } " )
# No more fallbacks for auto
if vae_file == " auto " :
vae_file = None
# Last check, just because
if vae_file and not os . path . exists ( vae_file ) :
vae_file = None
return vae_file
def load_vae ( model , vae_file = None ) :
global first_load , vae_dict , vae_list , loaded_vae_file
# save_settings = False
# save_settings = False
cache_enabled = shared . opts . sd_vae_checkpoint_cache > 0
cache_enabled = shared . opts . sd_vae_checkpoint_cache > 0
@ -165,12 +120,12 @@ def load_vae(model, vae_file=None):
if vae_file :
if vae_file :
if cache_enabled and vae_file in checkpoints_loaded :
if cache_enabled and vae_file in checkpoints_loaded :
# use vae checkpoint cache
# use vae checkpoint cache
print ( f " Loading VAE weights [ { get_filename ( vae_file ) } ] from cache " )
print ( f " Loading VAE weights { vae_source } : cached { get_filename ( vae_file ) } " )
store_base_vae ( model )
store_base_vae ( model )
_load_vae_dict ( model , checkpoints_loaded [ vae_file ] )
_load_vae_dict ( model , checkpoints_loaded [ vae_file ] )
else :
else :
assert os . path . isfile ( vae_file ) , f " VAE file doesn' t exist: { vae_file } "
assert os . path . isfile ( vae_file ) , f " VAE { vae_source } doesn' t exist: { vae_file } "
print ( f " Loading VAE weights from : { vae_file } " )
print ( f " Loading VAE weights { vae_source } : { vae_file } " )
store_base_vae ( model )
store_base_vae ( model )
vae_ckpt = sd_models . read_state_dict ( vae_file , map_location = shared . weight_load_location )
vae_ckpt = sd_models . read_state_dict ( vae_file , map_location = shared . weight_load_location )
@ -191,14 +146,12 @@ def load_vae(model, vae_file=None):
vae_opt = get_filename ( vae_file )
vae_opt = get_filename ( vae_file )
if vae_opt not in vae_dict :
if vae_opt not in vae_dict :
vae_dict [ vae_opt ] = vae_file
vae_dict [ vae_opt ] = vae_file
vae_list . append ( vae_opt )
elif loaded_vae_file :
elif loaded_vae_file :
restore_base_vae ( model )
restore_base_vae ( model )
loaded_vae_file = vae_file
loaded_vae_file = vae_file
first_load = False
# don't call this from outside
# don't call this from outside
def _load_vae_dict ( model , vae_dict_1 ) :
def _load_vae_dict ( model , vae_dict_1 ) :
@ -211,7 +164,10 @@ def clear_loaded_vae():
loaded_vae_file = None
loaded_vae_file = None
def reload_vae_weights ( sd_model = None , vae_file = " auto " ) :
unspecified = object ( )
def reload_vae_weights ( sd_model = None , vae_file = unspecified ) :
from modules import lowvram , devices , sd_hijack
from modules import lowvram , devices , sd_hijack
if not sd_model :
if not sd_model :
@ -219,7 +175,11 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
checkpoint_info = sd_model . sd_checkpoint_info
checkpoint_info = sd_model . sd_checkpoint_info
checkpoint_file = checkpoint_info . filename
checkpoint_file = checkpoint_info . filename
vae_file = resolve_vae ( checkpoint_file , vae_file = vae_file )
if vae_file == unspecified :
vae_file , vae_source = resolve_vae ( checkpoint_file )
else :
vae_source = " from function argument "
if loaded_vae_file == vae_file :
if loaded_vae_file == vae_file :
return
return
@ -231,7 +191,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
sd_hijack . model_hijack . undo_hijack ( sd_model )
sd_hijack . model_hijack . undo_hijack ( sd_model )
load_vae ( sd_model , vae_file )
load_vae ( sd_model , vae_file , vae_source )
sd_hijack . model_hijack . hijack ( sd_model )
sd_hijack . model_hijack . hijack ( sd_model )
script_callbacks . model_loaded_callback ( sd_model )
script_callbacks . model_loaded_callback ( sd_model )
@ -239,5 +199,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
if not shared . cmd_opts . lowvram and not shared . cmd_opts . medvram :
if not shared . cmd_opts . lowvram and not shared . cmd_opts . medvram :
sd_model . to ( devices . device )
sd_model . to ( devices . device )
print ( " VAE W eights loaded." )
print ( " VAE w eights loaded." )
return sd_model
return sd_model