@ -199,7 +199,7 @@ class StableDiffusionProcessing():
def init ( self , all_prompts , all_seeds , all_subseeds ) :
pass
def sample ( self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength ):
def sample ( self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength , prompts ):
raise NotImplementedError ( )
def close ( self ) :
@ -521,11 +521,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared . state . job = f " Batch { n + 1 } out of { p . n_iter } "
with devices . autocast ( ) :
# Only Txt2Img needs an extra argument, n, when saving intermediate images pre highres fix.
if isinstance ( p , StableDiffusionProcessingTxt2Img ) :
samples_ddim = p . sample ( conditioning = c , unconditional_conditioning = uc , seeds = seeds , subseeds = subseeds , subseed_strength = p . subseed_strength , n = n )
else :
samples_ddim = p . sample ( conditioning = c , unconditional_conditioning = uc , seeds = seeds , subseeds = subseeds , subseed_strength = p . subseed_strength )
samples_ddim = p . sample ( conditioning = c , unconditional_conditioning = uc , seeds = seeds , subseeds = subseeds , subseed_strength = p . subseed_strength , prompts = prompts )
samples_ddim = samples_ddim . to ( devices . dtype_vae )
x_samples_ddim = decode_first_stage ( p . sd_model , samples_ddim )
@ -653,7 +649,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self . truncate_x = int ( self . firstphase_width - firstphase_width_truncated ) / / opt_f
self . truncate_y = int ( self . firstphase_height - firstphase_height_truncated ) / / opt_f
def sample ( self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength , n= 0 ) :
def sample ( self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength , prompts ) :
self . sampler = sd_samplers . create_sampler_with_index ( sd_samplers . samplers , self . sampler_index , self . sd_model )
if not self . enable_hr :
@ -666,9 +662,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = samples [ : , : , self . truncate_y / / 2 : samples . shape [ 2 ] - self . truncate_y / / 2 , self . truncate_x / / 2 : samples . shape [ 3 ] - self . truncate_x / / 2 ]
""" saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images """
def save_intermediate ( image , index ) :
if not opts . save or self . do_not_save_samples or not opts . save_images_before_highres_fix :
return
if not isinstance ( image , Image . Image ) :
image = sd_samplers . sample_to_image ( image , index )
images . save_image ( image , self . outpath_samples , " " , seeds [ index ] , prompts [ index ] , opts . samples_format , suffix = " -before-highres-fix " )
if opts . use_scale_latent_for_hires_fix :
samples = torch . nn . functional . interpolate ( samples , size = ( self . height / / opt_f , self . width / / opt_f ) , mode = " bilinear " )
for i in range ( samples . shape [ 0 ] ) :
save_intermediate ( samples , i )
else :
decoded_samples = decode_first_stage ( self . sd_model , samples )
lowres_samples = torch . clamp ( ( decoded_samples + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
@ -678,6 +686,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x_sample = 255. * np . moveaxis ( x_sample . cpu ( ) . numpy ( ) , 0 , 2 )
x_sample = x_sample . astype ( np . uint8 )
image = Image . fromarray ( x_sample )
save_intermediate ( image , i )
image = images . resize_image ( 0 , image , self . width , self . height )
image = np . array ( image ) . astype ( np . float32 ) / 255.0
image = np . moveaxis ( image , 2 , 0 )
@ -689,15 +700,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self . sd_model . get_first_stage_encoding ( self . sd_model . encode_first_stage ( decoded_samples ) )
# Save a copy of the image/s before doing highres fix, if applicable.
if opts . save and not self . do_not_save_samples and opts . save_images_before_highres_fix :
for i in range ( self . batch_size ) :
# This batch's ith image.
img = sd_samplers . sample_to_image ( samples , i )
# Index that accounts for both batch size and batch count.
ind = i + self . batch_size * n
images . save_image ( img , self . outpath_samples , " " , self . all_seeds [ ind ] , self . all_prompts [ ind ] , opts . samples_format , suffix = f " -before-highres-fix " )
shared . state . nextjob ( )
self . sampler = sd_samplers . create_sampler_with_index ( sd_samplers . samplers , self . sampler_index , self . sd_model )
@ -844,8 +846,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self . image_conditioning = self . img2img_image_conditioning ( image , self . init_latent , self . image_mask )
def sample ( self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength ) :
def sample ( self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength , prompts ) :
x = create_random_tensors ( [ opt_C , self . height / / opt_f , self . width / / opt_f ] , seeds = seeds , subseeds = subseeds , subseed_strength = self . subseed_strength , seed_resize_from_h = self . seed_resize_from_h , seed_resize_from_w = self . seed_resize_from_w , p = self )
samples = self . sampler . sample_img2img ( self , self . init_latent , x , conditioning , unconditional_conditioning , image_conditioning = self . image_conditioning )
@ -856,4 +857,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
del x
devices . torch_gc ( )
return samples
return samples