@ -40,10 +40,8 @@ samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
sampler_extra_params = {
' sample_euler ' : [ ' s_churn ' , ' s_tmin ' , ' s_tmax ' , ' s_noise ' ] ,
' sample_euler_ancestral ' : [ ' eta ' ] ,
' sample_heun ' : [ ' s_churn ' , ' s_tmin ' , ' s_tmax ' , ' s_noise ' ] ,
' sample_dpm_2 ' : [ ' s_churn ' , ' s_tmin ' , ' s_tmax ' , ' s_noise ' ] ,
' sample_dpm_2_ancestral ' : [ ' eta ' ] ,
}
def setup_img2img_steps ( p , steps = None ) :
@ -101,6 +99,8 @@ class VanillaStableDiffusionSampler:
self . init_latent = None
self . sampler_noises = None
self . step = 0
self . eta = None
self . default_eta = 0.0
def number_of_needed_noises ( self , p ) :
return 0
@ -123,20 +123,29 @@ class VanillaStableDiffusionSampler:
self . step + = 1
return res
def initialize ( self , p ) :
self . eta = p . eta or opts . eta_ddim
for fieldname in [ ' p_sample_ddim ' , ' p_sample_plms ' ] :
if hasattr ( self . sampler , fieldname ) :
setattr ( self . sampler , fieldname , self . p_sample_ddim_hook )
self . mask = p . mask if hasattr ( p , ' mask ' ) else None
self . nmask = p . nmask if hasattr ( p , ' nmask ' ) else None
def sample_img2img ( self , p , x , noise , conditioning , unconditional_conditioning , steps = None ) :
steps , t_enc = setup_img2img_steps ( p , steps )
# existing code fails with cetain step counts, like 9
try :
self . sampler . make_schedule ( ddim_num_steps = steps , ddim_eta = p . ddim_eta , ddim_discretize = p . ddim_discretize , verbose = False )
self . sampler . make_schedule ( ddim_num_steps = steps , ddim_eta = self . eta, ddim_discretize = p . ddim_discretize , verbose = False )
except Exception :
self . sampler . make_schedule ( ddim_num_steps = steps + 1 , ddim_eta = p . ddim_eta , ddim_discretize = p . ddim_discretize , verbose = False )
self . sampler . make_schedule ( ddim_num_steps = steps + 1 , ddim_eta = self . eta, ddim_discretize = p . ddim_discretize , verbose = False )
x1 = self . sampler . stochastic_encode ( x , torch . tensor ( [ t_enc ] * int ( x . shape [ 0 ] ) ) . to ( shared . device ) , noise = noise )
self . sampler . p_sample_ddim = self . p_sample_ddim_hook
self . mask = p . mask if hasattr ( p , ' mask ' ) else None
self . nmask = p . nmask if hasattr ( p , ' nmask ' ) else None
self . initialize ( p )
self . init_latent = x
self . step = 0
@ -145,11 +154,8 @@ class VanillaStableDiffusionSampler:
return samples
def sample ( self , p , x , conditioning , unconditional_conditioning , steps = None ) :
for fieldname in [ ' p_sample_ddim ' , ' p_sample_plms ' ] :
if hasattr ( self . sampler , fieldname ) :
setattr ( self . sampler , fieldname , self . p_sample_ddim_hook )
self . mask = None
self . nmask = None
self . initialize ( p )
self . init_latent = None
self . step = 0
@ -157,9 +163,9 @@ class VanillaStableDiffusionSampler:
# existing code fails with cetin step counts, like 9
try :
samples_ddim , _ = self . sampler . sample ( S = steps , conditioning = conditioning , batch_size = int ( x . shape [ 0 ] ) , shape = x [ 0 ] . shape , verbose = False , unconditional_guidance_scale = p . cfg_scale , unconditional_conditioning = unconditional_conditioning , x_T = x , eta = p . eta )
samples_ddim , _ = self . sampler . sample ( S = steps , conditioning = conditioning , batch_size = int ( x . shape [ 0 ] ) , shape = x [ 0 ] . shape , verbose = False , unconditional_guidance_scale = p . cfg_scale , unconditional_conditioning = unconditional_conditioning , x_T = x , eta = self . eta )
except Exception :
samples_ddim , _ = self . sampler . sample ( S = steps + 1 , conditioning = conditioning , batch_size = int ( x . shape [ 0 ] ) , shape = x [ 0 ] . shape , verbose = False , unconditional_guidance_scale = p . cfg_scale , unconditional_conditioning = unconditional_conditioning , x_T = x , eta = p . eta )
samples_ddim , _ = self . sampler . sample ( S = steps + 1 , conditioning = conditioning , batch_size = int ( x . shape [ 0 ] ) , shape = x [ 0 ] . shape , verbose = False , unconditional_guidance_scale = p . cfg_scale , unconditional_conditioning = unconditional_conditioning , x_T = x , eta = self . eta )
return samples_ddim
@ -237,6 +243,8 @@ class KDiffusionSampler:
self . sampler_noises = None
self . sampler_noise_index = 0
self . stop_at = None
self . eta = None
self . default_eta = 1.0
def callback_state ( self , d ) :
store_latent ( d [ " denoised " ] )
@ -255,22 +263,12 @@ class KDiffusionSampler:
self . sampler_noise_index + = 1
return res
def sample_img2img ( self , p , x , noise , conditioning , unconditional_conditioning , steps = None ) :
steps , t_enc = setup_img2img_steps ( p , steps )
sigmas = self . model_wrap . get_sigmas ( steps )
noise = noise * sigmas [ steps - t_enc - 1 ]
xi = x + noise
sigma_sched = sigmas [ steps - t_enc - 1 : ]
def initialize ( self , p ) :
self . model_wrap_cfg . mask = p . mask if hasattr ( p , ' mask ' ) else None
self . model_wrap_cfg . nmask = p . nmask if hasattr ( p , ' nmask ' ) else None
self . model_wrap_cfg . init_latent = x
self . model_wrap . step = 0
self . sampler_noise_index = 0
self . eta = p . eta or opts . eta_ancestral
if hasattr ( k_diffusion . sampling , ' trange ' ) :
k_diffusion . sampling . trange = lambda * args , * * kwargs : extended_trange ( self , * args , * * kwargs )
@ -283,6 +281,25 @@ class KDiffusionSampler:
if hasattr ( p , param_name ) and param_name in inspect . signature ( self . func ) . parameters :
extra_params_kwargs [ param_name ] = getattr ( p , param_name )
if ' eta ' in inspect . signature ( self . func ) . parameters :
extra_params_kwargs [ ' eta ' ] = self . eta
return extra_params_kwargs
def sample_img2img ( self , p , x , noise , conditioning , unconditional_conditioning , steps = None ) :
steps , t_enc = setup_img2img_steps ( p , steps )
sigmas = self . model_wrap . get_sigmas ( steps )
noise = noise * sigmas [ steps - t_enc - 1 ]
xi = x + noise
extra_params_kwargs = self . initialize ( p )
sigma_sched = sigmas [ steps - t_enc - 1 : ]
self . model_wrap_cfg . init_latent = x
return self . func ( self . model_wrap_cfg , xi , sigma_sched , extra_args = { ' cond ' : conditioning , ' uncond ' : unconditional_conditioning , ' cond_scale ' : p . cfg_scale } , disable = False , callback = self . callback_state , * * extra_params_kwargs )
def sample ( self , p , x , conditioning , unconditional_conditioning , steps = None ) :
@ -291,19 +308,7 @@ class KDiffusionSampler:
sigmas = self . model_wrap . get_sigmas ( steps )
x = x * sigmas [ 0 ]
self . model_wrap_cfg . step = 0
self . sampler_noise_index = 0
if hasattr ( k_diffusion . sampling , ' trange ' ) :
k_diffusion . sampling . trange = lambda * args , * * kwargs : extended_trange ( self , * args , * * kwargs )
if self . sampler_noises is not None :
k_diffusion . sampling . torch = TorchHijack ( self )
extra_params_kwargs = { }
for param_name in self . extra_params :
if hasattr ( p , param_name ) and param_name in inspect . signature ( self . func ) . parameters :
extra_params_kwargs [ param_name ] = getattr ( p , param_name )
extra_params_kwargs = self . initialize ( p )
samples = self . func ( self . model_wrap_cfg , x , sigmas , extra_args = { ' cond ' : conditioning , ' uncond ' : unconditional_conditioning , ' cond_scale ' : p . cfg_scale } , disable = False , callback = self . callback_state , * * extra_params_kwargs )