@ -38,9 +38,9 @@ samplers = [
samplers_for_img2img = [ x for x in samplers if x . name != ' PLMS ' ]
def setup_img2img_steps ( p ):
if opts . img2img_fix_steps :
steps = int ( p . steps / min ( p . denoising_strength , 0.999 ) )
def setup_img2img_steps ( p , steps = None ):
if opts . img2img_fix_steps or steps is not None :
steps = int ( ( steps or p . steps ) / min ( p . denoising_strength , 0.999 ) ) if p . denoising_strength > 0 else 0
t_enc = p . steps - 1
else :
steps = p . steps
@ -115,8 +115,8 @@ class VanillaStableDiffusionSampler:
self . step + = 1
return res
def sample_img2img ( self , p , x , noise , conditioning , unconditional_conditioning ):
steps , t_enc = setup_img2img_steps ( p )
def sample_img2img ( self , p , x , noise , conditioning , unconditional_conditioning , steps = None ):
steps , t_enc = setup_img2img_steps ( p , steps )
# existing code fails with cetain step counts, like 9
try :
@ -127,16 +127,16 @@ class VanillaStableDiffusionSampler:
x1 = self . sampler . stochastic_encode ( x , torch . tensor ( [ t_enc ] * int ( x . shape [ 0 ] ) ) . to ( shared . device ) , noise = noise )
self . sampler . p_sample_ddim = self . p_sample_ddim_hook
self . mask = p . mask
self . nmask = p . nmask
self . init_latent = p. init_latent
self . mask = p . mask if hasattr ( p , ' mask ' ) else None
self . nmask = p . nmask if hasattr ( p , ' nmask ' ) else None
self . init_latent = x
self . step = 0
samples = self . sampler . decode ( x1 , conditioning , t_enc , unconditional_guidance_scale = p . cfg_scale , unconditional_conditioning = unconditional_conditioning )
return samples
def sample ( self , p , x , conditioning , unconditional_conditioning ):
def sample ( self , p , x , conditioning , unconditional_conditioning , steps = None ):
for fieldname in [ ' p_sample_ddim ' , ' p_sample_plms ' ] :
if hasattr ( self . sampler , fieldname ) :
setattr ( self . sampler , fieldname , self . p_sample_ddim_hook )
@ -145,11 +145,13 @@ class VanillaStableDiffusionSampler:
self . init_latent = None
self . step = 0
steps = steps or p . steps
# existing code fails with cetin step counts, like 9
try :
samples_ddim , _ = self . sampler . sample ( S = p. steps, conditioning = conditioning , batch_size = int ( x . shape [ 0 ] ) , shape = x [ 0 ] . shape , verbose = False , unconditional_guidance_scale = p . cfg_scale , unconditional_conditioning = unconditional_conditioning , x_T = x )
samples_ddim , _ = self . sampler . sample ( S = steps, conditioning = conditioning , batch_size = int ( x . shape [ 0 ] ) , shape = x [ 0 ] . shape , verbose = False , unconditional_guidance_scale = p . cfg_scale , unconditional_conditioning = unconditional_conditioning , x_T = x )
except Exception :
samples_ddim , _ = self . sampler . sample ( S = p. steps+ 1 , conditioning = conditioning , batch_size = int ( x . shape [ 0 ] ) , shape = x [ 0 ] . shape , verbose = False , unconditional_guidance_scale = p . cfg_scale , unconditional_conditioning = unconditional_conditioning , x_T = x )
samples_ddim , _ = self . sampler . sample ( S = steps+ 1 , conditioning = conditioning , batch_size = int ( x . shape [ 0 ] ) , shape = x [ 0 ] . shape , verbose = False , unconditional_guidance_scale = p . cfg_scale , unconditional_conditioning = unconditional_conditioning , x_T = x )
return samples_ddim
@ -186,7 +188,7 @@ class CFGDenoiser(torch.nn.Module):
return denoised
def extended_trange ( count, * args , * * kwargs ) :
def extended_trange ( sampler, count, * args , * * kwargs ) :
state . sampling_steps = count
state . sampling_step = 0
@ -194,6 +196,9 @@ def extended_trange(count, *args, **kwargs):
if state . interrupted :
break
if sampler . stop_at is not None and x > sampler . stop_at :
break
yield x
state . sampling_step + = 1
@ -222,6 +227,7 @@ class KDiffusionSampler:
self . model_wrap_cfg = CFGDenoiser ( self . model_wrap )
self . sampler_noises = None
self . sampler_noise_index = 0
self . stop_at = None
def callback_state ( self , d ) :
store_latent ( d [ " denoised " ] )
@ -240,8 +246,8 @@ class KDiffusionSampler:
self . sampler_noise_index + = 1
return res
def sample_img2img ( self , p , x , noise , conditioning , unconditional_conditioning ):
steps , t_enc = setup_img2img_steps ( p )
def sample_img2img ( self , p , x , noise , conditioning , unconditional_conditioning , steps = None ):
steps , t_enc = setup_img2img_steps ( p , steps )
sigmas = self . model_wrap . get_sigmas ( steps )
@ -251,33 +257,36 @@ class KDiffusionSampler:
sigma_sched = sigmas [ steps - t_enc - 1 : ]
self . model_wrap_cfg . mask = p . mask
self . model_wrap_cfg . nmask = p . nmask
self . model_wrap_cfg . init_latent = p. init_latent
self . model_wrap_cfg . mask = p . mask if hasattr ( p , ' mask ' ) else None
self . model_wrap_cfg . nmask = p . nmask if hasattr ( p , ' nmask ' ) else None
self . model_wrap_cfg . init_latent = x
self . model_wrap . step = 0
self . sampler_noise_index = 0
if hasattr ( k_diffusion . sampling , ' trange ' ) :
k_diffusion . sampling . trange = lambda * args , * * kwargs : extended_trange ( * args , * * kwargs )
k_diffusion . sampling . trange = lambda * args , * * kwargs : extended_trange ( self , * args , * * kwargs )
if self . sampler_noises is not None :
k_diffusion . sampling . torch = TorchHijack ( self )
return self . func ( self . model_wrap_cfg , xi , sigma_sched , extra_args = { ' cond ' : conditioning , ' uncond ' : unconditional_conditioning , ' cond_scale ' : p . cfg_scale } , disable = False , callback = self . callback_state )
def sample ( self , p , x , conditioning , unconditional_conditioning ) :
sigmas = self . model_wrap . get_sigmas ( p . steps )
def sample ( self , p , x , conditioning , unconditional_conditioning , steps = None ) :
steps = steps or p . steps
sigmas = self . model_wrap . get_sigmas ( steps )
x = x * sigmas [ 0 ]
self . model_wrap_cfg . step = 0
self . sampler_noise_index = 0
if hasattr ( k_diffusion . sampling , ' trange ' ) :
k_diffusion . sampling . trange = lambda * args , * * kwargs : extended_trange ( * args , * * kwargs )
k_diffusion . sampling . trange = lambda * args , * * kwargs : extended_trange ( self , * args , * * kwargs )
if self . sampler_noises is not None :
k_diffusion . sampling . torch = TorchHijack ( self )
samples_ddim = self . func ( self . model_wrap_cfg , x , sigmas , extra_args = { ' cond ' : conditioning , ' uncond ' : unconditional_conditioning , ' cond_scale ' : p . cfg_scale } , disable = False , callback = self . callback_state )
return samples_ddim
samples = self . func ( self . model_wrap_cfg , x , sigmas , extra_args = { ' cond ' : conditioning , ' uncond ' : unconditional_conditioning , ' cond_scale ' : p . cfg_scale } , disable = False , callback = self . callback_state )
return samples