@ -68,6 +68,7 @@ parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="em
parser . add_argument ( " --allow-code " , action = ' store_true ' , help = " allow custom script execution from webui " )
parser . add_argument ( " --medvram " , action = ' store_true ' , help = " enable stable diffusion model optimizations for sacrficing a little speed for low VRM usage " )
parser . add_argument ( " --lowvram " , action = ' store_true ' , help = " enable stable diffusion model optimizations for sacrficing a lot of speed for very low VRM usage " )
parser . add_argument ( " --always-batch-cond-uncond " , action = ' store_true ' , help = " a workaround test; may help with speed in you use --lowvram " )
parser . add_argument ( " --precision " , type = str , help = " evaluate at this precision " , choices = [ " full " , " autocast " ] , default = " autocast " )
parser . add_argument ( " --share " , action = ' store_true ' , help = " use share=True for gradio and make the UI accessible through their site (doesn ' t work for me but you might have better luck) " )
cmd_opts = parser . parse_args ( )
@ -75,9 +76,20 @@ cmd_opts = parser.parse_args()
cpu = torch . device ( " cpu " )
gpu = torch . device ( " cuda " )
device = gpu if torch . cuda . is_available ( ) else cpu
batch_cond_uncond = not ( cmd_opts . lowvram or cmd_opts . medvram )
batch_cond_uncond = cmd_opts . always_batch_cond_uncond or not ( cmd_opts . lowvram or cmd_opts . medvram )
queue_lock = threading . Lock ( )
class State :
interrupted = False
job = " "
def interrupt ( self ) :
self . interrupted = True
state = State ( )
if not cmd_opts . share :
# fix gradio phoning home
gradio . utils . version_check = lambda : None
@ -198,6 +210,7 @@ class Options:
" outdir_img2img_grids " : OptionInfo ( " outputs/img2img-grids " , ' Output dictectory for img2img grids ' ) ,
" save_to_dirs " : OptionInfo ( False , " When writing images/grids, create a directory with name derived from the prompt " ) ,
" save_to_dirs_prompt_len " : OptionInfo ( 10 , " When using above, how many words from prompt to put into directory name " , gr . Slider , { " minimum " : 1 , " maximum " : 32 , " step " : 1 } ) ,
" outdir_save " : OptionInfo ( " log/images " , " Directory for saving images using the Save button " ) ,
" samples_save " : OptionInfo ( True , " Save indiviual samples " ) ,
" samples_format " : OptionInfo ( ' png ' , ' File format for indiviual samples ' ) ,
" grid_save " : OptionInfo ( True , " Save image grids " ) ,
@ -400,8 +413,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
image . save ( f " { fullfn_without_extension } .jpg " , quality = opts . jpeg_quality , pnginfo = pnginfo )
def sanitize_filename_part ( text ) :
return text . replace ( ' ' , ' _ ' ) . translate ( { ord ( x ) : ' ' for x in invalid_filename_chars } ) [ : 128 ]
@ -410,6 +421,7 @@ def plaintext_to_html(text):
text = " " . join ( [ f " <p> { html . escape ( x ) } </p> \n " for x in text . split ( ' \n ' ) ] )
return text
def image_grid ( imgs , batch_size = 1 , rows = None ) :
if rows is None :
if opts . n_rows > 0 :
@ -652,18 +664,29 @@ def wrap_gradio_gpu_call(func):
return res
return f
return wrap_gradio_call( f)
def wrap_gradio_call ( func ) :
def f ( * args , * * kwargs ) :
t = time . perf_counter ( )
res = list ( func ( * args , * * kwargs ) )
try :
res = list ( func ( * args , * * kwargs ) )
except Exception as e :
print ( " Error completing request " , file = sys . stderr )
print ( " Arguments: " , args , kwargs , file = sys . stderr )
print ( traceback . format_exc ( ) , file = sys . stderr )
res = [ None , f " <div class= ' error ' > { plaintext_to_html ( type ( e ) . __name__ + ' : ' + str ( e ) ) } </div> " ]
elapsed = time . perf_counter ( ) - t
# last item is always HTML
res [ - 1 ] = res [ - 1 ] + f " <p class= ' performance ' >Time taken: { elapsed : .2f } s</p> "
state . interrupted = False
return tuple ( res )
return f
@ -883,7 +906,6 @@ class StableDiffusionProcessing:
self . extra_generation_params : dict = extra_generation_params
self . overlay_images = overlay_images
self . paste_to = None
self . progress_info = " "
def init ( self ) :
pass
@ -959,6 +981,15 @@ class CFGDenoiser(nn.Module):
return denoised
def extended_trange ( * args , * * kwargs ) :
for x in tqdm . trange ( * args , desc = state . job , * * kwargs ) :
if state . interrupted :
break
yield x
class KDiffusionSampler :
def __init__ ( self , funcname ) :
self . model_wrap = k_diffusion . external . CompVisDenoiser ( sd_model )
@ -980,7 +1011,7 @@ class KDiffusionSampler:
self . model_wrap_cfg . init_latent = p . init_latent
if hasattr ( k_diffusion . sampling , ' trange ' ) :
k_diffusion . sampling . trange = lambda * args , * * kwargs : tqdm. tqdm ( range ( * args ) , desc = p . progress_info , * * kwargs )
k_diffusion . sampling . trange = lambda * args , * * kwargs : extended_trange( * args , * * kwargs )
return self . func ( self . model_wrap_cfg , xi , sigma_sched , extra_args = { ' cond ' : conditioning , ' uncond ' : unconditional_conditioning , ' cond_scale ' : p . cfg_scale } , disable = False )
@ -989,13 +1020,36 @@ class KDiffusionSampler:
x = x * sigmas [ 0 ]
if hasattr ( k_diffusion . sampling , ' trange ' ) :
k_diffusion . sampling . trange = lambda * args , * * kwargs : tqdm. tqdm ( range ( * args ) , desc = p . progress_info , * * kwargs )
k_diffusion . sampling . trange = lambda * args , * * kwargs : extended_trange( * args , * * kwargs )
samples_ddim = self . func ( self . model_wrap_cfg , x , sigmas , extra_args = { ' cond ' : conditioning , ' uncond ' : unconditional_conditioning , ' cond_scale ' : p . cfg_scale } , disable = False )
return samples_ddim
Processed = namedtuple ( ' Processed ' , [ ' images ' , ' seed ' , ' info ' ] )
class Processed :
def __init__ ( self , p : StableDiffusionProcessing , images , seed , info ) :
self . images = images
self . prompt = p . prompt
self . seed = seed
self . info = info
self . width = p . width
self . height = p . height
self . sampler = samplers [ p . sampler_index ] . name
self . cfg_scale = p . cfg_scale
self . steps = p . steps
def js ( self ) :
obj = {
" prompt " : self . prompt ,
" seed " : int ( self . seed ) ,
" width " : self . width ,
" height " : self . height ,
" sampler " : self . sampler ,
" cfg_scale " : self . cfg_scale ,
" steps " : self . steps ,
}
return json . dumps ( obj )
def process_images ( p : StableDiffusionProcessing ) - > Processed :
@ -1063,6 +1117,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
p . init ( )
for n in range ( p . n_iter ) :
if state . interrupted :
break
prompts = all_prompts [ n * p . batch_size : ( n + 1 ) * p . batch_size ]
seeds = all_seeds [ n * p . batch_size : ( n + 1 ) * p . batch_size ]
@ -1075,7 +1132,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
# we manually generate all input noises because each one should have a specific seed
x = create_random_tensors ( [ opt_C , p . height / / opt_f , p . width / / opt_f ] , seeds = seeds )
p . progress_info = f " Batch { n + 1 } out of { p . n_iter } "
if p . n_iter > 0 :
state . job = f " Batch { n + 1 } out of { p . n_iter } "
samples_ddim = p . sample ( x = x , conditioning = c , unconditional_conditioning = uc )
x_samples_ddim = model . decode_first_stage ( samples_ddim )
@ -1137,7 +1196,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
save_image ( grid , p . outpath_grids , " grid " , seed , prompt , opts . grid_format , info = infotext ( ) , short_filename = not opts . grid_extended_filename )
torch_gc ( )
return Processed ( output_images, seed , infotext ( ) )
return Processed ( p, output_images, seed , infotext ( ) )
class StableDiffusionProcessingTxt2Img ( StableDiffusionProcessing ) :
@ -1188,52 +1247,47 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, u
module . display = display
exec ( compiled , module . __dict__ )
processed = Processed ( * display_result_data )
processed = Processed ( p , * display_result_data )
else :
processed = process_images ( p )
return processed . images , processed . seed , plaintext_to_html ( processed . info )
return processed . images , processed . js( ) , plaintext_to_html ( processed . info )
def save_files ( js_data , images ) :
import csv
class Flagging ( gr . FlaggingCallback ) :
os . makedirs ( opts . outdir_save , exist_ok = True )
def setup ( self , components , flagging_dir : str ) :
pass
def flag ( self , flag_data , flag_option = None , flag_index = None , username = None ) :
import csv
filenames = [ ]
os . makedirs ( " log/images " , exist_ok = True )
data = json . loads ( js_data )
# those must match the "txt2img" function
prompt , steps , sampler_index , use_gfpgan , prompt_matrix , n_iter , batch_size , cfg_scale , seed , height , width , code , images , seed , comment = flag_data
with open ( " log/log.csv " , " a " , encoding = " utf8 " , newline = ' ' ) as file :
import time
import base64
filenames = [ ]
at_start = file . tell ( ) == 0
writer = csv . writer ( file )
if at_start :
writer . writerow ( [ " prompt " , " seed " , " width " , " height " , " sampler " , " cfgs " , " steps " , " filename " ] )
with open ( " log/log.csv " , " a " , encoding = " utf8 " , newline = ' ' ) as file :
import time
import base64
filename_base = str ( int ( time . time ( ) * 1000 ) )
for i , filedata in enumerate ( images ) :
filename = filename_base + ( " " if len ( images ) == 1 else " - " + str ( i + 1 ) ) + " .png "
filepath = os . path . join ( opts . outdir_save , filename )
at_start = file . tell ( ) == 0
writer = csv . writer ( file )
if at_start :
writer . writerow ( [ " prompt " , " seed " , " width " , " height " , " cfgs " , " steps " , " filename " ] )
if filedata . startswith ( " data:image/png;base64, " ) :
filedata = filedata [ len ( " data:image/png;base64, " ) : ]
filename_base = str ( int ( time . time ( ) * 1000 ) )
for i , filedata in enumerate ( images ) :
filename = " log/images/ " + filename_base + ( " " if len ( images ) == 1 else " - " + str ( i + 1 ) ) + " .png "
with open ( filepath , " wb " ) as imgfile :
imgfile . write ( base64 . decodebytes ( filedata . encode ( ' utf-8 ' ) ) )
if filedata . startswith ( " data:image/png;base64, " ) :
filedata = filedata [ len ( " data:image/png;base64, " ) : ]
filenames . append ( filename )
with open ( filename , " wb " ) as imgfile :
imgfile . write ( base64 . decodebytes ( filedata . encode ( ' utf-8 ' ) ) )
writer . writerow ( [ data [ " prompt " ] , data [ " seed " ] , data [ " width " ] , data [ " height " ] , data [ " sampler " ] , data [ " cfg_scale " ] , data [ " steps " ] , filenames [ 0 ] ] )
filenames . append ( filename )
return ' ' , ' ' , plaintext_to_html ( f " Saved: { filenames [ 0 ] } " )
writer . writerow ( [ prompt , seed , width , height , cfg_scale , steps , filenames [ 0 ] ] )
print ( " Logged: " , filenames [ 0 ] )
with gr . Blocks ( analytics_enabled = False ) as txt2img_interface :
with gr . Row ( ) :
@ -1267,8 +1321,15 @@ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
with gr . Column ( variant = ' panel ' ) :
with gr . Group ( ) :
gallery = gr . Gallery ( label = ' Output ' )
output_seed = gr . Number ( label = ' Seed ' , visible = False )
with gr . Group ( ) :
with gr . Row ( ) :
interrupt = gr . Button ( ' Interrupt ' )
save = gr . Button ( ' Save ' )
with gr . Group ( ) :
html_info = gr . HTML ( )
generation_info = gr . Textbox ( visible = False )
txt2img_args = dict (
fn = wrap_gradio_gpu_call ( txt2img ) ,
@ -1289,7 +1350,7 @@ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
] ,
outputs = [
gallery ,
output_seed ,
generation_info ,
html_info
]
)
@ -1297,6 +1358,25 @@ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
prompt . submit ( * * txt2img_args )
submit . click ( * * txt2img_args )
interrupt . click (
fn = lambda : state . interrupt ( ) ,
inputs = [ ] ,
outputs = [ ] ,
)
save . click (
fn = wrap_gradio_call ( save_files ) ,
inputs = [
generation_info ,
gallery ,
] ,
outputs = [
html_info ,
html_info ,
html_info ,
]
)
def get_crop_region ( mask , pad = 0 ) :
h , w = mask . shape
@ -1508,6 +1588,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
p . batch_size = 1
p . do_not_save_grid = True
state . job = f " Batch { i + 1 } out of { n_iter } "
processed = process_images ( p )
if initial_seed is None :
@ -1523,13 +1604,13 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
save_image ( grid , p . outpath_grids , " grid " , initial_seed , prompt , opts . grid_format , info = info , short_filename = not opts . grid_extended_filename )
processed = Processed ( history, initial_seed , initial_info )
processed = Processed ( p, history, initial_seed , initial_info )
elif is_upscale :
initial_seed = None
initial_info = None
upscaler = sd_upscalers [ upscaler_name ]
upscaler = sd_upscalers . get ( upscaler_name , next ( iter ( sd_upscalers . values ( ) ) ) )
img = upscaler ( init_img )
torch_gc ( )
@ -1553,6 +1634,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
for i in range ( batch_count ) :
p . init_images = work [ i * p . batch_size : ( i + 1 ) * p . batch_size ]
state . job = f " Batch { i + 1 } out of { batch_count } "
processed = process_images ( p )
if initial_seed is None :
@ -1565,19 +1647,19 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
image_index = 0
for y , h , row in grid . tiles :
for tiledata in row :
tiledata [ 2 ] = work_results [ image_index ]
tiledata [ 2 ] = work_results [ image_index ] if image_index < len ( work_results ) else Image . new ( " RGB " , ( p . width , p . height ) )
image_index + = 1
combined_image = combine_grid ( grid )
save_image ( combined_image , p . outpath_grids , " grid " , initial_seed , prompt , opts . grid_format , info = initial_info , short_filename = not opts . grid_extended_filename )
processed = Processed ( [ combined_image ] , initial_seed , initial_info )
processed = Processed ( p , [ combined_image ] , initial_seed , initial_info )
else :
processed = process_images ( p )
return processed . images , processed . seed , plaintext_to_html ( processed . info )
return processed . images , processed . js( ) , plaintext_to_html ( processed . info )
sample_img2img = " assets/stable-samples/img2img/sketch-mountains-input.jpg "
@ -1609,8 +1691,8 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface:
inpaint_full_res = gr . Checkbox ( label = ' Inpaint at full resolution ' , value = True , visible = False )
with gr . Row ( ) :
sd_upscale_upscaler_name = gr . Radio ( label = ' Upscaler ' , choices = list ( sd_upscalers . keys ( ) ) , value = " RealESRGAN " )
sd_upscale_overlap = gr . Slider ( minimum = 0 , maximum = 256 , step = 16 , label = ' Tile overlap ' , value = 64 )
sd_upscale_upscaler_name = gr . Radio ( label = ' Upscaler ' , choices = list ( sd_upscalers . keys ( ) ) , value = list ( sd_upscalers . keys ( ) ) [ 0 ] , visible = False )
sd_upscale_overlap = gr . Slider ( minimum = 0 , maximum = 256 , step = 16 , label = ' Tile overlap ' , value = 64 , visible = False )
with gr . Row ( ) :
batch_count = gr . Slider ( minimum = 1 , maximum = cmd_opts . max_batch_count , step = 1 , label = ' Batch count ' , value = 1 )
@ -1629,8 +1711,15 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface:
with gr . Column ( variant = ' panel ' ) :
with gr . Group ( ) :
gallery = gr . Gallery ( label = ' Output ' )
output_seed = gr . Number ( label = ' Seed ' , visible = False )
with gr . Group ( ) :
with gr . Row ( ) :
interrupt = gr . Button ( ' Interrupt ' )
save = gr . Button ( ' Save ' )
with gr . Group ( ) :
html_info = gr . HTML ( )
generation_info = gr . Textbox ( visible = False )
def apply_mode ( mode ) :
is_classic = mode == 0
@ -1647,7 +1736,7 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface:
batch_count : gr . update ( visible = not is_upscale ) ,
batch_size : gr . update ( visible = not is_loopback ) ,
sd_upscale_upscaler_name : gr . update ( visible = is_upscale ) ,
sd_upscale_overlap : gr . update( visible = is_upscale ) ,
sd_upscale_overlap : gr . Slider. update( visible = is_upscale ) ,
inpaint_full_res : gr . update ( visible = is_inpaint ) ,
}
@ -1695,7 +1784,7 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface:
] ,
outputs = [
gallery ,
output_seed ,
generation_info ,
html_info
]
)
@ -1703,6 +1792,25 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface:
prompt . submit ( * * img2img_args )
submit . click ( * * img2img_args )
interrupt . click (
fn = lambda : state . interrupt ( ) ,
inputs = [ ] ,
outputs = [ ] ,
)
save . click (
fn = wrap_gradio_call ( save_files ) ,
inputs = [
generation_info ,
gallery ,
] ,
outputs = [
html_info ,
html_info ,
html_info ,
]
)
def upscale_with_realesrgan ( image , RealESRGAN_upscaling , RealESRGAN_model_index ) :
info = realesrgan_models [ RealESRGAN_model_index ]
@ -1744,7 +1852,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
save_image ( image , outpath , " " , None , ' ' , opts . samples_format , short_filename = True )
return image , 0 , ' '
return image , ' ' , ' '
extras_interface = gr . Interface (
@ -1757,7 +1865,7 @@ extras_interface = gr.Interface(
] ,
outputs = [
gr . Image ( label = " Result " ) ,
gr . Number( label = ' Seed ' , visible = False ) ,
gr . HTML( ) ,
gr . HTML ( ) ,
] ,
allow_flagging = " never " ,
@ -1779,7 +1887,7 @@ def run_pnginfo(image):
message = " Nothing found in the image. "
info = f " <div><p> { message } <p></div> "
return [ info ]
return ' ' , ' ' , info
pnginfo_interface = gr . Interface (
@ -1789,6 +1897,8 @@ pnginfo_interface = gr.Interface(
] ,
outputs = [
gr . HTML ( ) ,
gr . HTML ( ) ,
gr . HTML ( ) ,
] ,
allow_flagging = " never " ,
analytics_enabled = False ,
@ -1809,7 +1919,7 @@ def run_settings(*args):
opts . save ( config_filename )
return ' Settings saved. ' , ' '
return ' Settings saved. ' , ' ' , ' '
def create_setting_component ( key ) :
@ -1839,6 +1949,7 @@ settings_interface = gr.Interface(
outputs = [
gr . Textbox ( label = ' Result ' ) ,
gr . HTML ( ) ,
gr . HTML ( ) ,
] ,
title = None ,
description = None ,
@ -1863,17 +1974,18 @@ try:
except Exception :
pass
sd_config = OmegaConf . load ( cmd_opts . config )
sd_model = load_model_from_config ( sd_config , cmd_opts . ckpt )
sd_model = ( sd_model if cmd_opts . no_half else sd_model . half ( ) )
if False :
sd_config = OmegaConf . load ( cmd_opts . config )
sd_model = load_model_from_config ( sd_config , cmd_opts . ckpt )
sd_model = ( sd_model if cmd_opts . no_half else sd_model . half ( ) )
if cmd_opts . lowvram or cmd_opts . medvram :
setup_for_low_vram ( sd_model )
else :
sd_model = sd_model . to ( device )
if cmd_opts . lowvram or cmd_opts . medvram :
setup_for_low_vram ( sd_model )
else :
sd_model = sd_model . to ( device )
model_hijack = StableDiffusionModelHijack ( )
model_hijack . hijack ( sd_model )
model_hijack = StableDiffusionModelHijack ( )
model_hijack . hijack ( sd_model )
with open ( os . path . join ( script_path , " style.css " ) , " r " , encoding = " utf8 " ) as file :
css = file . read ( )