@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config
from modules import shared , modelloader , devices , script_callbacks , sd_vae , sd_disable_initialization , errors , hashes
from modules import shared , modelloader , devices , script_callbacks , sd_vae , sd_disable_initialization , errors , hashes
from modules . paths import models_path
from modules . paths import models_path
from modules . sd_hijack_inpainting import do_inpainting_hijack , should_hijack_inpainting
from modules . sd_hijack_inpainting import do_inpainting_hijack , should_hijack_inpainting
from modules . sd_hijack_ip2p import should_hijack_ip2p
model_dir = " Stable-diffusion "
model_dir = " Stable-diffusion "
model_path = os . path . abspath ( os . path . join ( models_path , model_dir ) )
model_path = os . path . abspath ( os . path . join ( models_path , model_dir ) )
@ -365,6 +366,15 @@ def load_model(checkpoint_info=None):
sd_config . model . params . unet_config . params . in_channels = 9
sd_config . model . params . unet_config . params . in_channels = 9
sd_config . model . params . finetune_keys = None
sd_config . model . params . finetune_keys = None
if should_hijack_ip2p ( checkpoint_info ) :
sd_config . model . target = " modules.models.diffusion.ddpm_edit.LatentDiffusion "
sd_config . model . params . conditioning_key = " hybrid "
sd_config . model . params . first_stage_key = " edited "
sd_config . model . params . cond_stage_key = " edit "
sd_config . model . params . image_size = 16
sd_config . model . params . unet_config . params . in_channels = 8
sd_config . model . params . unet_config . params . out_channels = 4
if not hasattr ( sd_config . model . params , " use_ema " ) :
if not hasattr ( sd_config . model . params , " use_ema " ) :
sd_config . model . params . use_ema = False
sd_config . model . params . use_ema = False
@ -429,7 +439,7 @@ def reload_model_weights(sd_model=None, info=None):
checkpoint_config = find_checkpoint_config ( current_checkpoint_info )
checkpoint_config = find_checkpoint_config ( current_checkpoint_info )
if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config ( checkpoint_info ) or should_hijack_inpainting ( checkpoint_info ) != should_hijack_inpainting ( sd_model . sd_checkpoint_info ) :
if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config ( checkpoint_info ) or should_hijack_inpainting ( checkpoint_info ) != should_hijack_inpainting ( sd_model . sd_checkpoint_info ) or should_hijack_ip2p ( checkpoint_info ) != should_hijack_ip2p ( sd_model . sd_checkpoint_info ) :
del sd_model
del sd_model
checkpoints_loaded . clear ( )
checkpoints_loaded . clear ( )
load_model ( checkpoint_info )
load_model ( checkpoint_info )