@@ -4,6 +4,7 @@ import sys
 import gc
 from collections import namedtuple
 import torch
+from safetensors.torch import load_file
 import re
 from omegaconf import OmegaConf
@@ -16,9 +17,10 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
 
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
+CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config', 'exttype'])
 checkpoints_list = {}
 checkpoints_loaded = collections.OrderedDict()
+checkpoint_types = {'.ckpt': 'pickle', '.safetensors': 'safetensors'}
 
 try:
     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
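The new `exttype` field carries the checkpoint's file extension so that `checkpoint_types` can later pick the matching loader. A minimal sketch of how the two fit together, with a made-up filename, hash, and config for illustration:

```python
import os
from collections import namedtuple

CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config', 'exttype'])
checkpoint_types = {'.ckpt': 'pickle', '.safetensors': 'safetensors'}

# os.path.splitext keeps the leading dot, matching the dict keys above
_, ext = os.path.splitext("models/Stable-diffusion/example.safetensors")  # hypothetical file
info = CheckpointInfo("example.safetensors", "example", "deadbeef", "example", "v1-inference.yaml", ext)
print(checkpoint_types[info.exttype])  # -> safetensors
```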
@@ -45,7 +47,7 @@ def checkpoint_tiles():
 
 def list_models():
     checkpoints_list.clear()
-    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
+    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
 
     def modeltitle(path, shorthash):
         abspath = os.path.abspath(path)
@ -60,15 +62,15 @@ def list_models():
if name . startswith ( " \\ " ) or name . startswith ( " / " ) :
name = name [ 1 : ]
shortname = os . path . splitext ( name . replace ( " / " , " _ " ) . replace ( " \\ " , " _ " ) ) [ 0 ]
shortname , ext = os . path . splitext ( name . replace ( " / " , " _ " ) . replace ( " \\ " , " _ " ) )
return f ' { name } [ { shorthash} ] ' , shortname
return f ' { name } [ { checkpoint_types[ ext ] } ] [ { shorthash} ] ' , shortname
cmd_ckpt = shared . cmd_opts . ckpt
if os . path . exists ( cmd_ckpt ) :
h = model_hash ( cmd_ckpt )
title , short_model_name = modeltitle ( cmd_ckpt , h )
checkpoints_list [ title ] = CheckpointInfo ( cmd_ckpt , title , h , short_model_name , shared . cmd_opts . config )
checkpoints_list [ title ] = CheckpointInfo ( cmd_ckpt , title , h , short_model_name , shared . cmd_opts . config , ' ' )
shared . opts . data [ ' sd_model_checkpoint ' ] = title
elif cmd_ckpt is not None and cmd_ckpt != shared . default_sd_model_file :
print ( f " Checkpoint in --ckpt argument not found (Possible it was moved to { model_path } : { cmd_ckpt } " , file = sys . stderr )
@@ -76,12 +78,12 @@ def list_models():
         h = model_hash(filename)
         title, short_model_name = modeltitle(filename, h)
-        basename, _ = os.path.splitext(filename)
+        basename, ext = os.path.splitext(filename)
         config = basename + ".yaml"
         if not os.path.exists(config):
             config = shared.cmd_opts.config
 
-        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
+        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config, ext)
 
 
 def get_closet_checkpoint_match(searchString):
@@ -173,7 +175,13 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         # load from file
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
-        pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
+        if checkpoint_types.get(checkpoint_info.exttype) == 'safetensors':
+            # safely load weights
+            # TODO: safetensors supports zero copy fast load to gpu, see issue #684
+            pl_sd = load_file(checkpoint_file, device=shared.weight_load_location)
+        else:
+            pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
+
         if "global_step" in pl_sd:
             print(f"Global Step: {pl_sd['global_step']}")