@@ -41,90 +41,6 @@ sampler_extra_params = {
    'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}

class CFGDenoiserEdit(torch.nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for the stable diffusion model (specifically for the unet)
    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
    instead of one. Originally, the second prompt was just an empty string, but we use a non-empty
    negative prompt.
    """
    def __init__(self, model):
        super().__init__()
        self.inner_model = model
        self.mask = None
        self.nmask = None
        self.init_latent = None
        self.step = 0

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale, image_cfg_scale):
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                out_cond, out_img_cond, out_uncond = x_out.chunk(3)
                denoised[i] = out_uncond[cond_index] + cond_scale * (out_cond[cond_index] - out_img_cond[cond_index]) + image_cfg_scale * (out_img_cond[cond_index] - out_uncond[cond_index])

        return denoised

    def forward(self, x, sigma, uncond, cond, cond_scale, image_cond, image_cfg_scale):
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        batch_size = len(conds_list)
        repeats = [len(conds_list[i]) for i in range(batch_size)]

        x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
        sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
        image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])

        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
        cfg_denoiser_callback(denoiser_params)
        x_in = denoiser_params.x
        image_cond_in = denoiser_params.image_cond
        sigma_in = denoiser_params.sigma

        if tensor.shape[1] == uncond.shape[1]:
            cond_in = torch.cat([tensor, uncond, uncond])

            if shared.batch_cond_uncond:
                x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
            else:
                x_out = torch.zeros_like(x_in)
                for batch_offset in range(0, x_out.shape[0], batch_size):
                    a = batch_offset
                    b = a + batch_size
                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
        else:
            x_out = torch.zeros_like(x_in)
            batch_size = batch_size * 2 if shared.batch_cond_uncond else batch_size
            for batch_offset in range(0, tensor.shape[0], batch_size):
                a = batch_offset
                b = min(a + batch_size, tensor.shape[0])
                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": torch.cat([tensor[a:b]], uncond), "c_concat": [image_cond_in[a:b]]})

            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})

        devices.test_for_nans(x_out, "unet")

        if opts.live_preview_content == "Prompt":
            sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
        elif opts.live_preview_content == "Negative prompt":
            sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])

        denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, image_cfg_scale)

        if self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised

        self.step += 1

        return denoised


class CFGDenoiser(torch.nn.Module):
    """
@@ -141,6 +57,7 @@ class CFGDenoiser(torch.nn.Module):
        self.nmask = None
        self.init_latent = None
        self.step = 0
        self.image_cfg_scale = None

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        denoised_uncond = x_out[-uncond.shape[0]:]
@@ -152,19 +69,36 @@ class CFGDenoiser(torch.nn.Module):
        return denoised

    def combine_denoised_for_edit_model(self, x_out, cond_scale):
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)

        return denoised
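    # The combination above is the InstructPix2Pix style dual guidance:
    #   denoised = uncond + cond_scale * (cond - img_cond) + image_cfg_scale * (img_cond - uncond)
    # At image_cfg_scale == 1.0 this collapses to img_cond + cond_scale * (cond - img_cond),
    # i.e. ordinary CFG with the image-conditioned prediction standing in for uncond, which is
    # what the non-edit path computes anyway (its negative-prompt entry still carries the image
    # conditioning). That is why forward() below treats that case as a non-edit model.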

    def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException
        # at self.image_cfg_scale == 1.0 the results produced for the edit model are the same as with normal sampling,
        # so is_edit_model is set to False to support AND composition.
        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0

        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

        batch_size = len(conds_list)
        repeats = [len(conds_list[i]) for i in range(batch_size)]

        x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
        sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
        image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
        if not is_edit_model:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
        else:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
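
        # Batch layout note: for a regular model the batch is [per-prompt conds..., uncond];
        # for an edit model two extra copies of x are appended so that x_out can be split by
        # chunk(3) in combine_denoised_for_edit_model() into (text + image cond, image cond
        # with the negative-prompt text, fully unconditional); the last copy is paired with
        # zeroed image conditioning.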

        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
        cfg_denoiser_callback(denoiser_params)
@@ -173,7 +107,10 @@ class CFGDenoiser(torch.nn.Module):
        sigma_in = denoiser_params.sigma

        if tensor.shape[1] == uncond.shape[1]:
            cond_in = torch.cat([tensor, uncond])
            if not is_edit_model:
                cond_in = torch.cat([tensor, uncond])
            else:
                cond_in = torch.cat([tensor, uncond, uncond])
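            # For the edit model the negative-prompt embedding is repeated so that the middle,
            # image-conditioned group in the batch built above runs with uncond text.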

            if shared.batch_cond_uncond:
                x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
@@ -189,7 +126,13 @@ class CFGDenoiser(torch.nn.Module):
            for batch_offset in range(0, tensor.shape[0], batch_size):
                a = batch_offset
                b = min(a + batch_size, tensor.shape[0])
                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})

                if not is_edit_model:
                    c_crossattn = [tensor[a:b]]
                else:
                    c_crossattn = torch.cat([tensor[a:b]], uncond)

                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})

            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
@@ -200,7 +143,10 @@ class CFGDenoiser(torch.nn.Module):
        elif opts.live_preview_content == "Negative prompt":
            sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])

        denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
        if not is_edit_model:
            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
        else:
            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)

        if self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised
@@ -280,12 +226,10 @@ class KDiffusionSampler:
        return p.steps

    def initialize(self, p):
        if shared.sd_model.cond_stage_key == "edit" and getattr(p, 'image_cfg_scale', None) != 1:
            self.model_wrap_cfg = CFGDenoiserEdit(self.model_wrap)
        self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
        self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
        self.model_wrap_cfg.step = 0
        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
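        # image_cfg_scale is read from the processing object when present and left as None
        # otherwise, in which case CFGDenoiser.forward() never takes the edit-model path.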
        self.eta = p.eta if p.eta is not None else opts.eta_ancestral

        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
@@ -355,9 +299,6 @@ class KDiffusionSampler:
            'cond_scale': p.cfg_scale,
        }

        if hasattr(p, 'image_cfg_scale') and p.image_cfg_scale != 1 and p.image_cfg_scale != None:
            extra_args['image_cfg_scale'] = p.image_cfg_scale

        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

        return samples