@ -8,12 +8,12 @@ from omegaconf import OmegaConf
from PIL import Image
from itertools import islice
from einops import rearrange , repeat
from torchvision . utils import make_grid
from torch import autocast
from contextlib import contextmanager , nullcontext
import mimetypes
import random
import math
import csv
import k_diffusion as K
from ldm . util import instantiate_from_config
@ -28,6 +28,8 @@ mimetypes.add_type('application/javascript', '.js')
opt_C = 4
opt_f = 8
invalid_filename_chars = ' <>: " / \ |?* '
parser = argparse . ArgumentParser ( )
parser . add_argument ( " --outdir " , type = str , nargs = " ? " , help = " dir to write results to " , default = None )
parser . add_argument ( " --skip_grid " , action = ' store_true ' , help = " do not save a grid, only individual samples. Helpful when evaluating lots of samples " , )
@ -127,13 +129,14 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
model = model . half ( ) . to ( device )
def image_grid ( imgs , batch_size ):
def image_grid ( imgs , batch_size , round_down = False ):
if opt . n_rows > 0 :
rows = opt . n_rows
elif opt . n_rows == 0 :
rows = batch_size
else :
rows = round ( math . sqrt ( len ( imgs ) ) )
rows = math . sqrt ( len ( imgs ) )
rows = int ( rows ) if round_down else round ( rows )
cols = math . ceil ( len ( imgs ) / rows )
@ -146,7 +149,7 @@ def image_grid(imgs, batch_size):
return grid
def dream ( prompt : str , ddim_steps : int , sampler_name : str , use_GFPGAN : bool , ddim_eta: float , n_iter : int , n_samples : int , cfg_scale : float , seed : int , height : int , width : int ) :
def dream ( prompt : str , ddim_steps : int , sampler_name : str , use_GFPGAN : bool , prompt_matrix: bool , ddim_eta: float , n_iter : int , n_samples : int , cfg_scale : float , seed : int , height : int , width : int ) :
torch . cuda . empty_cache ( )
outpath = opt . outdir or " outputs/txt2img-samples "
@ -155,6 +158,7 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
seed = random . randrange ( 4294967294 )
seed = int ( seed )
keep_same_seed = False
is_PLMS = sampler_name == ' PLMS '
is_DDIM = sampler_name == ' DDIM '
@ -177,59 +181,99 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
batch_size = n_samples
assert prompt is not None
data = [ batch_size * [ prompt ] ]
prompts = batch_size * [ prompt ]
sample_path = os . path . join ( outpath , " samples " )
os . makedirs ( sample_path , exist_ok = True )
base_count = len ( os . listdir ( sample_path ) )
grid_count = len ( os . listdir ( outpath ) ) - 1
prompt_matrix_prompts = [ ]
comment = " "
if prompt_matrix :
keep_same_seed = True
comment = " Image prompts: \n \n "
items = prompt . split ( " | " )
combination_count = 2 * * ( len ( items ) - 1 )
for combination_num in range ( combination_count ) :
current = items [ 0 ]
label = ' A '
for n , text in enumerate ( items [ 1 : ] ) :
if combination_num & ( 2 * * n ) > 0 :
current + = ( " " if text . strip ( ) . startswith ( " , " ) else " , " ) + text
label + = chr ( ord ( ' B ' ) + n )
comment + = " - " + label + " \n "
prompt_matrix_prompts . append ( current )
n_iter = math . ceil ( len ( prompt_matrix_prompts ) / batch_size )
comment + = " \n where: \n "
for n , text in enumerate ( items ) :
comment + = " " + chr ( ord ( ' A ' ) + n ) + " = " + items [ n ] + " \n "
precision_scope = autocast if opt . precision == " autocast " else nullcontext
output_images = [ ]
with torch . no_grad ( ) , precision_scope ( " cuda " ) , model . ema_scope ( ) :
for n in range ( n_iter ) :
for batch_index , prompts in enumerate ( data ) :
uc = None
if cfg_scale != 1.0 :
uc = model . get_learned_conditioning ( batch_size * [ " " ] )
if isinstance ( prompts , tuple ) :
prompts = list ( prompts )
c = model . get_learned_conditioning ( prompts )
shape = [ opt_C , height / / opt_f , width / / opt_f ]
current_seed = seed + n * len ( data ) + batch_index
if prompt_matrix :
prompts = prompt_matrix_prompts [ n * batch_size : ( n + 1 ) * batch_size ]
uc = None
if cfg_scale != 1.0 :
uc = model . get_learned_conditioning ( len ( prompts ) * [ " " ] )
if isinstance ( prompts , tuple ) :
prompts = list ( prompts )
c = model . get_learned_conditioning ( prompts )
shape = [ opt_C , height / / opt_f , width / / opt_f ]
batch_seed = seed if keep_same_seed else seed + n * len ( prompts )
# we manually generate all input noises because each one should have a specific seed
xs = [ ]
for i in range ( len ( prompts ) ) :
current_seed = seed if keep_same_seed else batch_seed + i
torch . manual_seed ( current_seed )
xs . append ( torch . randn ( shape , device = device ) )
x = torch . stack ( xs )
if is_Kdif :
sigmas = model_wrap . get_sigmas ( ddim_steps )
x = torch . randn ( [ n_samples , * shape ] , device = device ) * sigmas [ 0 ] # for GPU draw
model_wrap_cfg = CFGDenoiser ( model_wrap )
samples_ddim = K . sampling . sample_lms ( model_wrap_cfg , x , sigmas , extra_args = { ' cond ' : c , ' uncond ' : uc , ' cond_scale ' : cfg_scale } , disable = False )
if is_Kdif :
sigmas = model_wrap . get_sigmas ( ddim_steps )
x = x * sigmas [ 0 ]
model_wrap_cfg = CFGDenoiser ( model_wrap )
samples_ddim = K . sampling . sample_lms ( model_wrap_cfg , x , sigmas , extra_args = { ' cond ' : c , ' uncond ' : uc , ' cond_scale ' : cfg_scale } , disable = False )
elif sampler is not None :
samples_ddim , _ = sampler . sample ( S = ddim_steps , conditioning = c , batch_size = n_samples , shape = shape , verbose = False , unconditional_guidance_scale = cfg_scale , unconditional_conditioning = uc , eta = ddim_eta , x_T = None )
elif sampler is not None :
samples_ddim , _ = sampler . sample ( S = ddim_steps , conditioning = c , batch_size = len ( prompts ) , shape = shape , verbose = False , unconditional_guidance_scale = cfg_scale , unconditional_conditioning = uc , eta = ddim_eta , x_T = x )
x_samples_ddim = model . decode_first_stage ( samples_ddim )
x_samples_ddim = torch . clamp ( ( x_samples_ddim + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
x_samples_ddim = model . decode_first_stage ( samples_ddim )
x_samples_ddim = torch . clamp ( ( x_samples_ddim + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
if not opt . skip_save or not opt . skip_grid :
for x_sample in x_samples_ddim :
x_sample = 255. * rearrange ( x_sample . cpu ( ) . numpy ( ) , ' c h w -> h w c ' )
x_sample = x_sample . astype ( np . uint8 )
if not opt . skip_save or not opt . skip_grid :
for i , x_sample in enumerate ( x_samples_ddim ) :
x_sample = 255. * rearrange ( x_sample . cpu ( ) . numpy ( ) , ' c h w -> h w c ' )
x_sample = x_sample . astype ( np . uint8 )
if use_GFPGAN and GFPGAN is not None :
cropped_faces , restored_faces , restored_img = GFPGAN . enhance ( x_sample , has_aligned = False , only_center_face = False , paste_back = True )
x_sample = restored_img
image = Image . fromarray ( x_sample )
filename = f " { base_count : 05 } - { seed if keep_same_seed else batch_seed + i } _ { prompts [ i ] . replace ( ' ' , ' _ ' ) . translate ( { ord ( x ) : ' ' for x in invalid_filename_chars } ) [ : 128 ] } .png "
image . save ( os . path . join ( sample_path , filename ) )
output_images . append ( image )
base_count + = 1
if use_GFPGAN and GFPGAN is not None :
cropped_faces , restored_faces , restored_img = GFPGAN . enhance ( x_sample , has_aligned = False , only_center_face = False , paste_back = True )
x_sample = restored_img
image = Image . fromarray ( x_sample )
image . save ( os . path . join ( sample_path , f " { base_count : 05 } - { current_seed } _ { prompt . replace ( ' ' , ' _ ' ) [ : 128 ] } .png " ) )
output_images . append ( image )
base_count + = 1
if not opt . skip_grid :
# additionally, save as grid
grid = image_grid ( output_images , batch_size )
grid = image_grid ( output_images , batch_size , round_down = prompt_matrix )
grid . save ( os . path . join ( outpath , f ' grid- { grid_count : 04 } .png ' ) )
grid_count + = 1
@ -242,8 +286,49 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
Steps : { ddim_steps } , Sampler : { sampler_name } , CFG scale : { cfg_scale } , Seed : { seed } { ' , GFPGAN ' if use_GFPGAN and GFPGAN is not None else ' ' }
""" .strip()
if len ( comment ) > 0 :
info + = " \n \n " + comment
return output_images , seed , info
class Flagging ( gr . FlaggingCallback ) :
def setup ( self , components , flagging_dir : str ) :
pass
def flag ( self , flag_data , flag_option = None , flag_index = None , username = None ) - > int :
os . makedirs ( " log/images " , exist_ok = True )
# those must match the "dream" function
prompt , ddim_steps , sampler_name , use_GFPGAN , prompt_matrix , ddim_eta , n_iter , n_samples , cfg_scale , request_seed , height , width , images , seed , comment = flag_data
filenames = [ ]
with open ( " log/log.csv " , " a " , encoding = " utf8 " , newline = ' ' ) as file :
import time
import base64
at_start = file . tell ( ) == 0
writer = csv . writer ( file )
if at_start :
writer . writerow ( [ " prompt " , " seed " , " width " , " height " , " cfgs " , " steps " , " filename " ] )
filename_base = str ( int ( time . time ( ) * 1000 ) )
for i , filedata in enumerate ( images ) :
filename = " log/images/ " + filename_base + ( " " if len ( images ) == 1 else " - " + str ( i + 1 ) ) + " .png "
if filedata . startswith ( " data:image/png;base64, " ) :
filedata = filedata [ len ( " data:image/png;base64, " ) : ]
with open ( filename , " wb " ) as imgfile :
imgfile . write ( base64 . decodebytes ( filedata . encode ( ' utf-8 ' ) ) )
filenames . append ( filename )
writer . writerow ( [ prompt , seed , width , height , cfg_scale , ddim_steps , filenames [ 0 ] ] )
print ( " Logged: " , filenames [ 0 ] )
dream_interface = gr . Interface (
dream ,
@ -252,10 +337,11 @@ dream_interface = gr.Interface(
gr . Slider ( minimum = 1 , maximum = 150 , step = 1 , label = " Sampling Steps " , value = 50 ) ,
gr . Radio ( label = ' Sampling method ' , choices = [ " DDIM " , " PLMS " , " k-diffusion " ] , value = " k-diffusion " ) ,
gr . Checkbox ( label = ' Fix faces using GFPGAN ' , value = False , visible = GFPGAN is not None ) ,
gr . Checkbox ( label = ' Create prompt matrix (separate multiple prompts using |, and get all combinations of them) ' , value = False ) ,
gr . Slider ( minimum = 0.0 , maximum = 1.0 , step = 0.01 , label = " DDIM ETA " , value = 0.0 , visible = False ) ,
gr . Slider ( minimum = 1 , maximum = 16 , step = 1 , label = ' Sampling iterations ' , value = 1 ) ,
gr . Slider ( minimum = 1 , maximum = 4 , step = 1 , label = ' Samples per iteration ' , value = 1 ) ,
gr . Slider ( minimum = 1.0 , maximum = 15.0 , step = 0.5 , label = ' Classifier Free Guidance Scale ' , value = 7.0 ) ,
gr . Slider ( minimum = 1 , maximum = 16 , step = 1 , label = ' Batch count (how many batches of images to generate) ' , value = 1 ) ,
gr . Slider ( minimum = 1 , maximum = 4 , step = 1 , label = ' Batch size (how many images are in a batch; memory-hungry) ' , value = 1 ) ,
gr . Slider ( minimum = 1.0 , maximum = 15.0 , step = 0.5 , label = ' Classifier Free Guidance Scale (how strongly should the image follow the prompt) ' , value = 7.0 ) ,
gr . Number ( label = ' Seed ' , value = - 1 ) ,
gr . Slider ( minimum = 64 , maximum = 2048 , step = 64 , label = " Height " , value = 512 ) ,
gr . Slider ( minimum = 64 , maximum = 2048 , step = 64 , label = " Width " , value = 512 ) ,
@ -267,7 +353,7 @@ dream_interface = gr.Interface(
] ,
title = " Stable Diffusion Text-to-Image K " ,
description = " Generate images from text with Stable Diffusion (using K-LMS) " ,
allow_flagging= " never "
flagging_callback= Flagging ( )
)
@ -346,8 +432,8 @@ def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_e
x_sample = restored_img
image = Image . fromarray ( x_sample )
image . save ( os . path . join ( sample_path , f " { base_count : 05 } - { current_seed } _ { prompt . replace ( ' ' , ' _ ' ) . translate ( { ord ( x ) : ' ' for x in invalid_filename_chars } ) [ : 128 ] } .png " ) )
image . save ( os . path . join ( sample_path , f " { base_count : 05 } - { current_seed } _ { prompt . replace ( ' ' , ' _ ' ) [ : 128 ] } .png " ) )
output_images . append ( image )
base_count + = 1