@@ -11,25 +11,41 @@ from omegaconf import OmegaConf
 
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.util import instantiate_from_config, ismap
+from modules import shared, sd_hijack
 
 warnings.filterwarnings("ignore", category=UserWarning)
 
+cached_ldsr_model: torch.nn.Module = None
+
 
 # Create LDSR Class
 class LDSR:
     def load_model_from_config(self, half_attention):
-        print(f"Loading model from {self.modelPath}")
-        pl_sd = torch.load(self.modelPath, map_location="cpu")
-        sd = pl_sd["state_dict"]
-        config = OmegaConf.load(self.yamlPath)
-        config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
-        model = instantiate_from_config(config.model)
-        model.load_state_dict(sd, strict=False)
-        model.cuda()
-        if half_attention:
-            model = model.half()
-
-        model.eval()
+        global cached_ldsr_model
+
+        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
print ( f " Loading model from cache " )
+            model: torch.nn.Module = cached_ldsr_model
+        else:
+            print(f"Loading model from {self.modelPath}")
+            pl_sd = torch.load(self.modelPath, map_location="cpu")
+            sd = pl_sd["state_dict"]
+            config = OmegaConf.load(self.yamlPath)
+            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
+            model: torch.nn.Module = instantiate_from_config(config.model)
+            model.load_state_dict(sd, strict=False)
+            model = model.to(shared.device)
+            if half_attention:
+                model = model.half()
+            if shared.cmd_opts.opt_channelslast:
+                model = model.to(memory_format=torch.channels_last)
+
+            sd_hijack.model_hijack.hijack(model)  # apply optimization
+            model.eval()
+
+            if shared.opts.ldsr_cached:
+                cached_ldsr_model = model
+
         return {"model": model}
 
     def __init__(self, model_path, yaml_path):
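For context, the change above swaps an unconditional disk load for a module-level cache gated by the shared.opts.ldsr_cached option: on a cache hit the previously built model is reused, otherwise the checkpoint is loaded, moved to shared.device, optionally converted to half precision and channels-last memory format, passed through sd_hijack for optimization, and stored back into the cache. Below is a minimal, self-contained sketch of the same caching pattern; the names Opts, opts, and load_expensive_model are hypothetical stand-ins for the webui internals, not real APIs.

import torch

class Opts:
    ldsr_cached = True  # hypothetical stand-in for shared.opts.ldsr_cached

opts = Opts()
cached_model: torch.nn.Module = None  # module-level cache, like cached_ldsr_model

def load_expensive_model() -> torch.nn.Module:
    # Stand-in for the slow torch.load / instantiate_from_config path.
    return torch.nn.Linear(4, 4)

def get_model() -> torch.nn.Module:
    global cached_model
    if opts.ldsr_cached and cached_model is not None:
        return cached_model  # cache hit: skip the expensive load
    model = load_expensive_model()
    model.eval()
    if opts.ldsr_cached:
        cached_model = model  # keep the instance alive for the next call
    return model

first = get_model()
second = get_model()
assert first is second  # the second call reuses the cached instance

One consequence of this design, visible in the diff as well: a cached model keeps whatever dtype and memory format it was given on first load, since the half_attention and opt_channelslast conversions sit in the else branch. Changing those settings later only takes effect after the cache is cleared or ldsr_cached is disabled.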