@ -4,7 +4,7 @@ import torch.nn as nn
import numpy as np
import gradio as gr
from omegaconf import OmegaConf
from PIL import Image , ImageFont , ImageDraw
from PIL import Image , ImageFont , ImageDraw , PngImagePlugin
from itertools import islice
from einops import rearrange , repeat
from torch import autocast
@ -12,6 +12,8 @@ from contextlib import contextmanager, nullcontext
import mimetypes
import random
import math
import html
import time
import k_diffusion as K
from ldm . util import instantiate_from_config
@ -49,8 +51,13 @@ parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=(
parser . add_argument ( " --no-verify-input " , action = ' store_true ' , help = " do not verify input to check if it ' s too long " )
parser . add_argument ( " --no-half " , action = ' store_true ' , help = " do not switch the model to 16-bit floats " )
parser . add_argument ( " --no-progressbar-hiding " , action = ' store_true ' , help = " do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser) " )
parser . add_argument ( " --max-batch-count " , type = int , default = 16 , help = " maximum batch count value for the UI " )
parser . add_argument ( " --grid-format " , type = str , default = ' png ' , help = " file format for saved grids; can be png or jpg " )
parser . add_argument ( " --max-batch-count " , type = int , default = 16 , help = " maximum batch count value for the UI " )
parser . add_argument ( " --save-format " , type = str , default = ' png ' , help = " file format for saved indiviual samples; can be png or jpg " )
parser . add_argument ( " --grid-format " , type = str , default = ' png ' , help = " file format for saved grids; can be png or jpg " )
parser . add_argument ( " --grid-extended-filename " , action = ' store_true ' , help = " save grid images to filenames with extended info: seed, prompt " )
parser . add_argument ( " --jpeg-quality " , type = int , default = 80 , help = " quality for saved jpeg images " )
parser . add_argument ( " --disable-pnginfo " , action = ' store_true ' , help = " disable saving text information about generation parameters as chunks to png files " )
parser . add_argument ( " --inversion " , action = ' store_true ' , help = " switch to stable inversion version; allows for uploading embeddings; this option should be used only with textual inversion repo " )
opt = parser . parse_args ( )
@ -130,6 +137,37 @@ def create_random_tensors(shape, seeds):
return x
def torch_gc ( ) :
torch . cuda . empty_cache ( )
torch . cuda . ipc_collect ( )
def sanitize_filename_part ( text ) :
return text . replace ( ' ' , ' _ ' ) . translate ( { ord ( x ) : ' ' for x in invalid_filename_chars } ) [ : 128 ]
def save_image ( image , path , basename , seed , prompt , extension , info = None , short_filename = False ) :
prompt = sanitize_filename_part ( prompt )
if short_filename :
filename = f " { basename } . { extension } "
else :
filename = f " { basename } - { seed } - { prompt [ : 128 ] } . { extension } "
if extension == ' png ' and not opt . disable_pnginfo :
pnginfo = PngImagePlugin . PngInfo ( )
pnginfo . add_text ( " parameters " , info )
else :
pnginfo = None
image . save ( os . path . join ( path , filename ) , quality = opt . jpeg_quality , pnginfo = pnginfo )
def plaintext_to_html ( text ) :
text = " " . join ( [ f " <p> { html . escape ( x ) } </p> \n " for x in text . split ( ' \n ' ) ] )
return text
def load_GFPGAN ( ) :
model_name = ' GFPGANv1.3 '
model_path = os . path . join ( GFPGAN_dir , ' experiments/pretrained_models ' , model_name + ' .pth ' )
@ -301,11 +339,25 @@ def check_prompt_length(prompt, comments):
comments . append ( f " Warning: too many input tokens; some ( { len ( overflowing_words ) } ) have been truncated: \n { overflowing_text } \n " )
def wrap_gradio_call ( func ) :
def f ( * p1 , * * p2 ) :
t = time . perf_counter ( )
res = list ( func ( * p1 , * * p2 ) )
elapsed = time . perf_counter ( ) - t
# last item is always HTML
res [ - 1 ] = res [ - 1 ] + f " <p class= ' performance ' >Time taken: { elapsed : .2f } s</p> "
return tuple ( res )
return f
def process_images ( outpath , func_init , func_sample , prompt , seed , sampler_name , batch_size , n_iter , steps , cfg_scale , width , height , prompt_matrix , use_GFPGAN , do_not_save_grid = False ) :
""" this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch """
assert prompt is not None
torch . cuda . empty_cache ( )
torch _gc ( )
if seed == - 1 :
seed = random . randrange ( 4294967294 )
@ -351,6 +403,11 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
all_prompts = batch_size * n_iter * [ prompt ]
all_seeds = [ seed + x for x in range ( len ( all_prompts ) ) ]
info = f """
{ prompt }
Steps : { steps } , Sampler : { sampler_name } , CFG scale : { cfg_scale } , Seed : { seed } { ' , GFPGAN ' if use_GFPGAN and GFPGAN is not None else ' ' }
""" .strip() + " " .join([ " \n \n " + x for x in comments])
precision_scope = autocast if opt . precision == " autocast " else nullcontext
output_images = [ ]
with torch . no_grad ( ) , precision_scope ( " cuda " ) , model . ema_scope ( ) :
@ -385,9 +442,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
x_sample = restored_img
image = Image . fromarray ( x_sample )
filename = f " { base_count : 05 } - { seeds [ i ] } _ { prompts [ i ] . replace ( ' ' , ' _ ' ) . translate ( { ord ( x ) : ' ' for x in invalid_filename_chars } ) [ : 128 ] } .png "
image . save ( os . path . join ( sample_path , filename ) )
save_image ( image , sample_path , f " { base_count : 05 } " , seeds [ i ] , prompts [ i ] , opt . save_format , info = info )
output_images . append ( image )
base_count + = 1
@ -406,17 +461,10 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
output_images . insert ( 0 , grid )
grid. save ( os . path . join ( outpath , f ' grid- { grid_count : 04 } . { opt . grid_format } ' ) )
save_image( grid , outpath , f " grid- { grid_count : 04 } " , seed , prompt , opt . grid_format , info = info , short_filename = not opt . grid_extended_filename )
grid_count + = 1
info = f """
{ prompt }
Steps : { steps } , Sampler : { sampler_name } , CFG scale : { cfg_scale } , Seed : { seed } { ' , GFPGAN ' if use_GFPGAN and GFPGAN is not None else ' ' }
""" .strip()
for comment in comments :
info + = " \n \n " + comment
torch_gc ( )
return output_images , seed , info
@ -465,7 +513,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, p
del sampler
return output_images , seed , info
return output_images , seed , plaintext_to_html( info)
class Flagging ( gr . FlaggingCallback ) :
@ -510,7 +558,7 @@ class Flagging(gr.FlaggingCallback):
txt2img_interface = gr . Interface (
txt2img,
wrap_gradio_call( txt2img) ,
inputs = [
gr . Textbox ( label = " Prompt " , placeholder = " A corgi wearing a top hat as an oil painting. " , lines = 1 ) ,
gr . Slider ( minimum = 1 , maximum = 150 , step = 1 , label = " Sampling Steps " , value = 50 ) ,
@ -529,7 +577,7 @@ txt2img_interface = gr.Interface(
outputs = [
gr . Gallery ( label = " Images " ) ,
gr . Number ( label = ' Seed ' ) ,
gr . Textbox( label = " Copy-paste generation parameters " ) ,
gr . HTML( ) ,
] ,
title = " Stable Diffusion Text-to-Image K " ,
description = " Generate images from text with Stable Diffusion (using K-LMS) " ,
@ -608,7 +656,8 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
grid_count = len ( os . listdir ( outpath ) ) - 1
grid = image_grid ( history , batch_size , force_n_rows = 1 )
grid . save ( os . path . join ( outpath , f ' grid- { grid_count : 04 } . { opt . grid_format } ' ) )
save_image ( grid , outpath , f " grid- { grid_count : 04 } " , initial_seed , prompt , opt . grid_format , info = info , short_filename = not opt . grid_extended_filename )
output_images = history
seed = initial_seed
@ -633,14 +682,14 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
del sampler
return output_images , seed , info
return output_images , seed , plaintext_to_html( info)
sample_img2img = " assets/stable-samples/img2img/sketch-mountains-input.jpg "
sample_img2img = sample_img2img if os . path . exists ( sample_img2img ) else None
img2img_interface = gr . Interface (
img2img,
wrap_gradio_call( img2img) ,
inputs = [
gr . Textbox ( placeholder = " A fantasy landscape, trending on artstation. " , lines = 1 ) ,
gr . Image ( value = sample_img2img , source = " upload " , interactive = True , type = " pil " ) ,
@ -661,7 +710,7 @@ img2img_interface = gr.Interface(
outputs = [
gr . Gallery ( ) ,
gr . Number ( label = ' Seed ' ) ,
gr . Textbox( label = " Copy-paste generation parameters " ) ,
gr . HTML( ) ,
] ,
title = " Stable Diffusion Image-to-Image " ,
description = " Generate images from images with Stable Diffusion " ,
@ -682,7 +731,7 @@ def run_GFPGAN(image, strength):
if strength < 1.0 :
res = Image . blend ( image , res , strength )
return res
return res , 0 , ' '
if GFPGAN is not None :
@ -694,6 +743,8 @@ if GFPGAN is not None:
] ,
outputs = [
gr . Image ( label = " Result " ) ,
gr . Number ( label = ' Seed ' , visible = False ) ,
gr . HTML ( ) ,
] ,
title = " GFPGAN " ,
description = " Fix faces on images " ,
@ -704,7 +755,10 @@ if GFPGAN is not None:
demo = gr . TabbedInterface (
interface_list = [ x [ 0 ] for x in interfaces ] ,
tab_names = [ x [ 1 ] for x in interfaces ] ,
css = ( " " if opt . no_progressbar_hiding else css_hide_progressbar )
css = ( " " if opt . no_progressbar_hiding else css_hide_progressbar ) + """
. output - html p { margin : 0 0.5 em ; }
. performance { font - size : 0.85 em ; color : #444; }
"""
)
demo . launch ( )