@ -129,6 +129,73 @@ class StableDiffusionProcessing():
self . all_seeds = None
self . all_subseeds = None
def txt2img_image_conditioning(self, x, width=None, height=None):
    """Build the image-conditioning tensor for a txt2img sampling pass.

    Args:
        x: Tensor whose batch size (``x.shape[0]``), dtype and device are
            used for the returned conditioning.
        width: Pixel width of the blank conditioning image; falls back to
            ``self.width`` when falsy.
        height: Pixel height of the blank conditioning image; falls back to
            ``self.height`` when falsy.

    Returns:
        A ``(batch, 5, 1, 1)`` zero tensor when the sampler's
        ``conditioning_key`` is neither ``'hybrid'`` nor ``'concat'``.
        Otherwise the first-stage encoding of an all-zero image with one
        extra leading channel filled with 1.0, cast to ``x.dtype``.
    """
    if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
        # Dummy zero conditioning if we're not using inpainting model.
        # Still takes up a bit of memory, but no encoder call.
        # A 1x1 image is used since only the batch size is expected to matter.
        return torch.zeros(
            x.shape[0], 5, 1, 1,
            dtype=x.dtype,
            device=x.device,
        )

    height = height or self.height
    width = width or self.width

    # The "masked-image" in this case will just be all zeros since the entire image is masked.
    blank_image = torch.zeros(x.shape[0], 3, height, width, device=x.device)
    image_conditioning = self.sd_model.get_first_stage_encoding(
        self.sd_model.encode_first_stage(blank_image)
    )

    # Prepend the fake all-ones mask as an extra channel: the pad tuple is
    # ordered last-dim first, so (1, 0) applies to the channel dimension.
    image_conditioning = torch.nn.functional.pad(
        image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0
    )
    return image_conditioning.to(x.dtype)
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
    """Build the image-conditioning tensor for an img2img / inpainting pass.

    Args:
        source_image: Image tensor that gets masked and encoded; its last
            two dimensions give the mask size when no mask is supplied.
        latent_image: Latent tensor; supplies the batch size, dtype and
            device of the dummy conditioning and the spatial size the mask
            is resized to.
        image_mask: Optional mask. A tensor is used as-is; any other value
            is converted via ``.convert("L")`` (presumably a PIL image),
            scaled to [0, 1] and rounded. ``None`` means an all-ones mask.

    Returns:
        A ``(batch, 5, 1, 1)`` zero tensor when the sampler's
        ``conditioning_key`` is neither ``'hybrid'`` nor ``'concat'``.
        Otherwise the resized mask concatenated (along dim 1) with the
        first-stage encoding of the masked source image, moved to
        ``shared.device`` and cast to ``self.sd_model.dtype``.
    """
    if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
        # Dummy zero conditioning if we're not using inpainting model.
        return torch.zeros(
            latent_image.shape[0], 5, 1, 1,
            dtype=latent_image.dtype,
            device=latent_image.device,
        )

    # Handle the different mask inputs
    if image_mask is None:
        conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
    elif torch.is_tensor(image_mask):
        conditioning_mask = image_mask
    else:
        conditioning_mask = np.array(image_mask.convert("L"))
        conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
        conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

        # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
        conditioning_mask = torch.round(conditioning_mask)

    # Match both device and dtype of the source image: torch.lerp needs its
    # start and end tensors to share a dtype, and a float32 mask would
    # otherwise promote the masked image away from a half-precision source.
    conditioning_mask = conditioning_mask.to(
        device=source_image.device, dtype=source_image.dtype
    )

    # Create another latent image, this time with a masked version of the original input.
    # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
    conditioning_image = torch.lerp(
        source_image,
        source_image * (1.0 - conditioning_mask),
        getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight),
    )

    # Encode the new masked image using first stage of network.
    conditioning_image = self.sd_model.get_first_stage_encoding(
        self.sd_model.encode_first_stage(conditioning_image)
    )

    # Create the concatenated conditioning tensor to be fed to `c_concat`
    conditioning_mask = torch.nn.functional.interpolate(
        conditioning_mask, size=latent_image.shape[-2:]
    )
    conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
    image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
    return image_conditioning.to(shared.device).type(self.sd_model.dtype)
def init(self, all_prompts, all_seeds, all_subseeds):
    """No-op initialisation step; ignores its arguments and returns None.

    NOTE(review): presumably a hook that subclasses override — confirm
    against the rest of the file.
    """
    return None
@ -571,37 +638,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self . truncate_x = int ( self . firstphase_width - firstphase_width_truncated ) / / opt_f
self . truncate_y = int ( self . firstphase_height - firstphase_height_truncated ) / / opt_f
def create_dummy_mask(self, x, width=None, height=None):
    """Return a placeholder image-conditioning tensor for sampling.

    Args:
        x: Tensor whose batch size (``x.shape[0]``), dtype and device are
            used for the returned conditioning.
        width: Pixel width of the blank image to encode; falls back to
            ``self.width`` when falsy.
        height: Pixel height of the blank image to encode; falls back to
            ``self.height`` when falsy.

    Returns:
        For a sampler whose ``conditioning_key`` is ``'hybrid'`` or
        ``'concat'``: the first-stage encoding of an all-zero image with an
        extra leading channel of 1.0, cast to ``x.dtype``. Otherwise a
        ``(batch, 5, 1, 1)`` zero tensor.
    """
    if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
        # Dummy zero conditioning if we're not using inpainting model.
        # Still takes up a bit of memory, but no encoder call.
        # A 1x1 image is used since only the batch size is expected to matter.
        return torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)

    height = height or self.height
    width = width or self.width

    # The "masked-image" in this case will just be all zeros since the entire image is masked.
    blank_image = torch.zeros(x.shape[0], 3, height, width, device=x.device)
    encoded = self.sd_model.get_first_stage_encoding(
        self.sd_model.encode_first_stage(blank_image)
    )

    # Prepend the fake all-ones mask as an extra channel (pad tuple is
    # ordered last-dim first, so (1, 0) lands on the channel dimension).
    encoded = torch.nn.functional.pad(encoded, (0, 0, 0, 0, 1, 0), value=1.0)
    return encoded.to(x.dtype)
def sample ( self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength ) :
self . sampler = sd_samplers . create_sampler_with_index ( sd_samplers . samplers , self . sampler_index , self . sd_model )
if not self . enable_hr :
x = create_random_tensors ( [ opt_C , self . height / / opt_f , self . width / / opt_f ] , seeds = seeds , subseeds = subseeds , subseed_strength = self . subseed_strength , seed_resize_from_h = self . seed_resize_from_h , seed_resize_from_w = self . seed_resize_from_w , p = self )
samples = self . sampler . sample ( self , x , conditioning , unconditional_conditioning , image_conditioning = self . create_dummy_mask ( x ) )
samples = self . sampler . sample ( self , x , conditioning , unconditional_conditioning , image_conditioning = self . txt2img_image_conditioning ( x ) )
return samples
x = create_random_tensors ( [ opt_C , self . firstphase_height / / opt_f , self . firstphase_width / / opt_f ] , seeds = seeds , subseeds = subseeds , subseed_strength = self . subseed_strength , seed_resize_from_h = self . seed_resize_from_h , seed_resize_from_w = self . seed_resize_from_w , p = self )
samples = self . sampler . sample ( self , x , conditioning , unconditional_conditioning , image_conditioning = self . create_dummy_mask ( x , self . firstphase_width , self . firstphase_height ) )
samples = self . sampler . sample ( self , x , conditioning , unconditional_conditioning , image_conditioning = self . txt2img_image_conditioning ( x , self . firstphase_width , self . firstphase_height ) )
samples = samples [ : , : , self . truncate_y / / 2 : samples . shape [ 2 ] - self . truncate_y / / 2 , self . truncate_x / / 2 : samples . shape [ 3 ] - self . truncate_x / / 2 ]
@ -638,7 +684,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices . torch_gc ( )
samples = self . sampler . sample_img2img ( self , samples , noise , conditioning , unconditional_conditioning , steps = self . steps , image_conditioning = self . create_dummy_mask ( samples ) )
image_conditioning = self . img2img_image_conditioning (
decoded_samples ,
samples ,
decoded_samples . new_ones ( decoded_samples . shape [ 0 ] , 1 , decoded_samples . shape [ 2 ] , decoded_samples . shape [ 3 ] )
)
samples = self . sampler . sample_img2img ( self , samples , noise , conditioning , unconditional_conditioning , steps = self . steps , image_conditioning = image_conditioning )
return samples
@ -770,33 +821,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self . inpainting_fill == 3 :
self . init_latent = self . init_latent * self . mask
if self . sampler . conditioning_key in { ' hybrid ' , ' concat ' } :
if self . image_mask is not None :
conditioning_mask = np . array ( self . image_mask . convert ( " L " ) )
conditioning_mask = conditioning_mask . astype ( np . float32 ) / 255.0
conditioning_mask = torch . from_numpy ( conditioning_mask [ None , None ] )
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch . round ( conditioning_mask )
else :
conditioning_mask = torch . ones ( 1 , 1 , * image . shape [ - 2 : ] )
# Create another latent image, this time with a masked version of the original input.
conditioning_mask = conditioning_mask . to ( image . device )
conditioning_image = image * ( 1.0 - conditioning_mask )
conditioning_image = self . sd_model . get_first_stage_encoding ( self . sd_model . encode_first_stage ( conditioning_image ) )
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch . nn . functional . interpolate ( conditioning_mask , size = self . init_latent . shape [ - 2 : ] )
conditioning_mask = conditioning_mask . expand ( conditioning_image . shape [ 0 ] , - 1 , - 1 , - 1 )
self . image_conditioning = torch . cat ( [ conditioning_mask , conditioning_image ] , dim = 1 )
self . image_conditioning = self . image_conditioning . to ( shared . device ) . type ( self . sd_model . dtype )
else :
self . image_conditioning = torch . zeros (
self . init_latent . shape [ 0 ] , 5 , 1 , 1 ,
dtype = self . init_latent . dtype ,
device = self . init_latent . device
)
self . image_conditioning = self . img2img_image_conditioning ( image , self . init_latent , self . image_mask )
def sample ( self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength ) :