@ -2,8 +2,6 @@ import collections
import os . path
import sys
import gc
import time
from collections import namedtuple
import torch
import re
import safetensors . torch
@ -14,10 +12,10 @@ import ldm.modules.midas as midas
from ldm . util import instantiate_from_config
from modules import shared , modelloader , devices , script_callbacks , sd_vae , sd_disable_initialization , errors , hashes
from modules import shared , modelloader , devices , script_callbacks , sd_vae , sd_disable_initialization , errors , hashes , sd_models_config
from modules . paths import models_path
from modules . sd_hijack_inpainting import do_inpainting_hijack , should_hijack_inpainting
from modules . sd_hijack_ip2p import should_hijack_ip2p
from modules . sd_hijack_inpainting import do_inpainting_hijack
from modules . timer import Timer
model_dir = " Stable-diffusion "
model_path = os . path . abspath ( os . path . join ( models_path , model_dir ) )
@ -99,17 +97,6 @@ def checkpoint_tiles():
return sorted ( [ x . title for x in checkpoints_list . values ( ) ] , key = alphanumeric_key )
def find_checkpoint_config ( info ) :
if info is None :
return shared . cmd_opts . config
config = os . path . splitext ( info . filename ) [ 0 ] + " .yaml "
if os . path . exists ( config ) :
return config
return shared . cmd_opts . config
def list_models ( ) :
checkpoints_list . clear ( )
checkpoint_alisases . clear ( )
@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd):
def read_state_dict ( checkpoint_file , print_global_state = False , map_location = None ) :
_ , extension = os . path . splitext ( checkpoint_file )
if extension . lower ( ) == " .safetensors " :
device = map_location or shared . weight_load_location
if device is None :
device = devices . get_cuda_device_string ( ) if torch . cuda . is_available ( ) else " cpu "
device = map_location or shared . weight_load_location or devices . get_optimal_device_name ( )
pl_sd = safetensors . torch . load_file ( checkpoint_file , device = device )
else :
pl_sd = torch . load ( checkpoint_file , map_location = map_location or shared . weight_load_location )
@ -229,60 +214,74 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
return sd
def load_model_weights ( model , checkpoint_info : CheckpointInfo ) :
def get_checkpoint_state_dict ( checkpoint_info : CheckpointInfo , timer ) :
sd_model_hash = checkpoint_info . calculate_shorthash ( )
timer . record ( " calculate hash " )
if checkpoint_info in checkpoints_loaded :
# use checkpoint cache
print ( f " Loading weights [ { sd_model_hash } ] from cache " )
return checkpoints_loaded [ checkpoint_info ]
print ( f " Loading weights [ { sd_model_hash } ] from { checkpoint_info . filename } " )
res = read_state_dict ( checkpoint_info . filename )
timer . record ( " load weights from disk " )
return res
def load_model_weights ( model , checkpoint_info : CheckpointInfo , state_dict , timer ) :
title = checkpoint_info . title
sd_model_hash = checkpoint_info . calculate_shorthash ( )
timer . record ( " calculate hash " )
if checkpoint_info . title != title :
shared . opts . data [ " sd_model_checkpoint " ] = checkpoint_info . title
cache_enabled = shared . opts . sd_checkpoint_cache > 0
if state_dict is None :
state_dict = get_checkpoint_state_dict ( checkpoint_info , timer )
if cache_enabled and checkpoint_info in checkpoints_loaded :
# use checkpoint cache
print ( f " Loading weights [ { sd_model_hash } ] from cache " )
model . load_state_dict ( checkpoints_loaded [ checkpoint_info ] )
else :
# load from file
print ( f " Loading weights [ { sd_model_hash } ] from { checkpoint_info . filename } " )
model . load_state_dict ( state_dict , strict = False )
del state_dict
timer . record ( " apply weights to model " )
sd = read_state_dict ( checkpoint_info . filename )
model . load_state_dict ( sd , strict = False )
del sd
if cache_enabled :
# cache newly loaded model
checkpoints_loaded [ checkpoint_info ] = model . state_dict ( ) . copy ( )
if shared . opts . sd_checkpoint_cache > 0 :
# cache newly loaded model
checkpoints_loaded [ checkpoint_info ] = model . state_dict ( ) . copy ( )
if shared . cmd_opts . opt_channelslast :
model . to ( memory_format = torch . channels_last )
timer . record ( " apply channels_last " )
if shared . cmd_opts . opt_channelslast :
model . to ( memory_format = torch . channels_last )
if not shared . cmd_opts . no_half :
vae = model . first_stage_model
depth_model = getattr ( model , ' depth_model ' , None )
if not shared . cmd_opts . no_half :
vae = model . first_stage_model
depth_model = getattr ( model , ' depth_model ' , None )
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared . cmd_opts . no_half_vae :
model . first_stage_model = None
# with --upcast-sampling, don't convert the depth model weights to float16
if shared . cmd_opts . upcast_sampling and depth_model :
model . depth_model = None
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared . cmd_opts . no_half_vae :
model . first_stage_model = None
# with --upcast-sampling, don't convert the depth model weights to float16
if shared . cmd_opts . upcast_sampling and depth_model :
model . depth_model = None
model . half ( )
model . first_stage_model = vae
if depth_model :
model . depth_model = depth_model
model . half ( )
model . first_stage_model = vae
if depth_model :
model . depth_model = depth_model
timer . record ( " apply half() " )
devices . dtype = torch . float32 if shared . cmd_opts . no_half else torch . float16
devices . dtype_vae = torch . float32 if shared . cmd_opts . no_half or shared . cmd_opts . no_half_vae else torch . float16
devices . dtype_unet = model . model . diffusion_model . dtype
devices . unet_needs_upcast = shared . cmd_opts . upcast_sampling and devices . dtype == torch . float16 and devices . dtype_unet == torch . float16
devices . dtype = torch . float32 if shared . cmd_opts . no_half else torch . float16
devices . dtype_vae = torch . float32 if shared . cmd_opts . no_half or shared . cmd_opts . no_half_vae else torch . float16
devices . dtype_unet = model . model . diffusion_model . dtype
devices . unet_needs_upcast = shared . cmd_opts . upcast_sampling and devices . dtype == torch . float16 and devices . dtype_unet == torch . float16
model . first_stage_model . to ( devices . dtype_vae )
model . first_stage_model . to ( devices . dtype_vae )
timer . record ( " apply dtype to VAE " )
# clean up cache if limit is reached
if cache_enabled :
while len ( checkpoints_loaded ) > shared . opts . sd_checkpoint_cache + 1 : # we need to count the current model
checkpoints_loaded . popitem ( last = False ) # LRU
while len ( checkpoints_loaded ) > shared . opts . sd_checkpoint_cache :
checkpoints_loaded . popitem ( last = False )
model . sd_model_hash = sd_model_hash
model . sd_model_checkpoint = checkpoint_info . filename
@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
sd_vae . clear_loaded_vae ( )
vae_file , vae_source = sd_vae . resolve_vae ( checkpoint_info . filename )
sd_vae . load_vae ( model , vae_file , vae_source )
timer . record ( " load VAE " )
def enable_midas_autodownload ( ) :
@ -340,24 +340,20 @@ def enable_midas_autodownload():
midas . api . load_model = load_model_wrapper
class Timer :
def __init__ ( self ) :
self . start = time . time ( )
def repair_config ( sd_config ) :
if not hasattr ( sd_config . model . params , " use_ema " ) :
sd_config . model . params . use_ema = False
def elapsed ( self ) :
end = time . time ( )
res = end - self . start
self . start = end
return res
if shared . cmd_opts . no_half :
sd_config . model . params . unet_config . params . use_fp16 = False
elif shared . cmd_opts . upcast_sampling :
sd_config . model . params . unet_config . params . use_fp16 = True
def load_model ( checkpoint_info = None ):
def load_model ( checkpoint_info = None , already_loaded_state_dict = None , time_taken_to_load_state_dict = None ):
from modules import lowvram , sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint ( )
checkpoint_config = find_checkpoint_config ( checkpoint_info )
if checkpoint_config != shared . cmd_opts . config :
print ( f " Loading config from: { checkpoint_config } " )
if shared . sd_model :
sd_hijack . model_hijack . undo_hijack ( shared . sd_model )
@ -365,38 +361,27 @@ def load_model(checkpoint_info=None):
gc . collect ( )
devices . torch_gc ( )
sd_config = OmegaConf . load ( checkpoint_config )
if should_hijack_inpainting ( checkpoint_info ) :
# Hardcoded config for now...
sd_config . model . target = " ldm.models.diffusion.ddpm.LatentInpaintDiffusion "
sd_config . model . params . conditioning_key = " hybrid "
sd_config . model . params . unet_config . params . in_channels = 9
sd_config . model . params . finetune_keys = None
if should_hijack_ip2p ( checkpoint_info ) :
sd_config . model . target = " modules.models.diffusion.ddpm_edit.LatentDiffusion "
sd_config . model . params . conditioning_key = " hybrid "
sd_config . model . params . first_stage_key = " edited "
sd_config . model . params . cond_stage_key = " edit "
sd_config . model . params . image_size = 16
sd_config . model . params . unet_config . params . in_channels = 8
sd_config . model . params . unet_config . params . out_channels = 4
do_inpainting_hijack ( )
if not hasattr ( sd_config . model . params , " use_ema " ) :
sd_config . model . params . use_ema = False
timer = Timer ( )
do_inpainting_hijack ( )
if already_loaded_state_dict is not None :
state_dict = already_loaded_state_dict
else :
state_dict = get_checkpoint_state_dict ( checkpoint_info , timer )
if shared . cmd_opts . no_half :
sd_config . model . params . unet_config . params . use_fp16 = False
elif shared . cmd_opts . upcast_sampling :
sd_config . model . params . unet_config . params . use_fp16 = True
checkpoint_config = sd_models_config . find_checkpoint_config ( state_dict , checkpoint_info )
timer = Timer ( )
timer . record ( " find config " )
sd_model = None
sd_config = OmegaConf . load ( checkpoint_config )
repair_config ( sd_config )
timer . record ( " load config " )
print ( f " Creating model from config: { checkpoint_config } " )
sd_model = None
try :
with sd_disable_initialization . DisableInitialization ( ) :
sd_model = instantiate_from_config ( sd_config . model )
@ -407,29 +392,35 @@ def load_model(checkpoint_info=None):
print ( ' Failed to create model quickly; will retry using slow method. ' , file = sys . stderr )
sd_model = instantiate_from_config ( sd_config . model )
elapsed_create = timer . elapsed ( )
sd_model. used_config = checkpoint_config
load_model_weights( sd_model , checkpoint_info )
timer. record ( " create model " )
elapsed_load_weights = timer . elapsed ( )
load_model_weights( sd_model , checkpoint_info , state_dict , timer )
if shared . cmd_opts . lowvram or shared . cmd_opts . medvram :
lowvram . setup_for_low_vram ( sd_model , shared . cmd_opts . medvram )
else :
sd_model . to ( shared . device )
timer . record ( " move model to device " )
sd_hijack . model_hijack . hijack ( sd_model )
timer . record ( " hijack " )
sd_model . eval ( )
shared . sd_model = sd_model
sd_hijack . model_hijack . embedding_db . load_textual_inversion_embeddings ( force_reload = True ) # Reload embeddings after model load as they may or may not fit the model
timer . record ( " load textual inversion embeddings " )
script_callbacks . model_loaded_callback ( sd_model )
elapsed_the_rest = timer . elapsed ( )
timer. record ( " scripts callbacks " )
print ( f " Model loaded in { elapsed_create + elapsed_load_weights + elapsed_the_rest : .1f } s ( { elapsed_create : .1f } s create model, { elapsed_load_weights : .1f } s load weights) ." )
print ( f " Model loaded in { timer. summary ( ) } ." )
return sd_model
@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None):
if not sd_model :
sd_model = shared . sd_model
if sd_model is None : # previous model load failed
current_checkpoint_info = None
else :
@ -447,38 +439,44 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model . sd_model_checkpoint == checkpoint_info . filename :
return
checkpoint_config = find_checkpoint_config ( current_checkpoint_info )
if shared . cmd_opts . lowvram or shared . cmd_opts . medvram :
lowvram . send_everything_to_cpu ( )
else :
sd_model . to ( devices . cpu )
if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config ( checkpoint_info ) or should_hijack_inpainting ( checkpoint_info ) != should_hijack_inpainting ( sd_model . sd_checkpoint_info ) or should_hijack_ip2p ( checkpoint_info ) != should_hijack_ip2p ( sd_model . sd_checkpoint_info ) :
del sd_model
checkpoints_loaded . clear ( )
load_model ( checkpoint_info )
return shared . sd_model
sd_hijack . model_hijack . undo_hijack ( sd_model )
if shared . cmd_opts . lowvram or shared . cmd_opts . medvram :
lowvram . send_everything_to_cpu ( )
else :
sd_model . to ( devices . cpu )
timer = Timer ( )
s d_hijack. model_hijack . undo_hijack ( sd_model )
state_dict = get_checkpoint_state_dict ( checkpoint_info , timer )
timer = Timer ( )
checkpoint_config = sd_models_config . find_checkpoint_config ( state_dict , checkpoint_info )
timer . record ( " find config " )
if sd_model is None or checkpoint_config != sd_model . used_config :
del sd_model
checkpoints_loaded . clear ( )
load_model ( checkpoint_info , already_loaded_state_dict = state_dict , time_taken_to_load_state_dict = timer . records [ " load weights from disk " ] )
return shared . sd_model
try :
load_model_weights ( sd_model , checkpoint_info )
load_model_weights ( sd_model , checkpoint_info , state_dict , timer )
except Exception as e :
print ( " Failed to load checkpoint, restoring previous " )
load_model_weights ( sd_model , current_checkpoint_info )
load_model_weights ( sd_model , current_checkpoint_info , None , timer )
raise
finally :
sd_hijack . model_hijack . hijack ( sd_model )
timer . record ( " hijack " )
script_callbacks . model_loaded_callback ( sd_model )
timer . record ( " script callbacks " )
if not shared . cmd_opts . lowvram and not shared . cmd_opts . medvram :
sd_model . to ( devices . device )
timer . record ( " move model to device " )
elapsed = timer . elapsed ( )
print ( f " Weights loaded in { elapsed : .1f } s. " )
print ( f " Weights loaded in { timer . summary ( ) } . " )
return sd_model