@ -1,12 +1,40 @@
import time
import uvicorn
from gradio . processing_utils import encode_pil_to_base64 , decode_base64_to_file , decode_base64_to_image
from fastapi import APIRouter , HTTPException
from fastapi import APIRouter , Depends, HTTPException
import modules . shared as shared
from modules import devices
from modules . api . models import *
from modules . processing import StableDiffusionProcessingTxt2Img , StableDiffusionProcessingImg2Img , process_images
from modules . sd_samplers import all_samplers
from modules . extras import run_extras , run_pnginfo
# copy from wrap_gradio_gpu_call of webui.py
# because queue lock will be acquired in api handlers
# and time start needs to be set
# the function has been modified into two parts
def before_gpu_call ( ) :
devices . torch_gc ( )
shared . state . sampling_step = 0
shared . state . job_count = - 1
shared . state . job_no = 0
shared . state . job_timestamp = shared . state . get_job_timestamp ( )
shared . state . current_latent = None
shared . state . current_image = None
shared . state . current_image_sampling_step = 0
shared . state . skipped = False
shared . state . interrupted = False
shared . state . textinfo = None
shared . state . time_start = time . time ( )
def after_gpu_call ( ) :
shared . state . job = " "
shared . state . job_count = 0
devices . torch_gc ( )
def upscaler_to_index ( name : str ) :
try :
return [ x . name . lower ( ) for x in shared . sd_upscalers ] . index ( name . lower ( ) )
@ -33,15 +61,16 @@ class Api:
self . app . add_api_route ( " /sdapi/v1/extra-single-image " , self . extras_single_image_api , methods = [ " POST " ] , response_model = ExtrasSingleImageResponse )
self . app . add_api_route ( " /sdapi/v1/extra-batch-images " , self . extras_batch_images_api , methods = [ " POST " ] , response_model = ExtrasBatchImagesResponse )
self . app . add_api_route ( " /sdapi/v1/png-info " , self . pnginfoapi , methods = [ " POST " ] , response_model = PNGInfoResponse )
self . app . add_api_route ( " /sdapi/v1/progress " , self . progressapi , methods = [ " GET " ] , response_model = ProgressResponse )
def text2imgapi ( self , txt2imgreq : StableDiffusionTxt2ImgProcessingAPI ) :
sampler_index = sampler_to_index ( txt2imgreq . sampler_index )
if sampler_index is None :
raise HTTPException ( status_code = 404 , detail = " Sampler not found " )
raise HTTPException ( status_code = 404 , detail = " Sampler not found " )
populate = txt2imgreq . copy ( update = { # Override __init__ params
" sd_model " : shared . sd_model ,
" sd_model " : shared . sd_model ,
" sampler_index " : sampler_index [ 0 ] ,
" do_not_save_samples " : True ,
" do_not_save_grid " : True
@ -49,34 +78,36 @@ class Api:
)
p = StableDiffusionProcessingTxt2Img ( * * vars ( populate ) )
# Override object param
before_gpu_call ( )
with self . queue_lock :
processed = process_images ( p )
after_gpu_call ( )
b64images = list ( map ( encode_pil_to_base64 , processed . images ) )
return TextToImageResponse ( images = b64images , parameters = vars ( txt2imgreq ) , info = processed . js ( ) )
def img2imgapi ( self , img2imgreq : StableDiffusionImg2ImgProcessingAPI ) :
sampler_index = sampler_to_index ( img2imgreq . sampler_index )
if sampler_index is None :
raise HTTPException ( status_code = 404 , detail = " Sampler not found " )
raise HTTPException ( status_code = 404 , detail = " Sampler not found " )
init_images = img2imgreq . init_images
if init_images is None :
raise HTTPException ( status_code = 404 , detail = " Init image not found " )
raise HTTPException ( status_code = 404 , detail = " Init image not found " )
mask = img2imgreq . mask
if mask :
mask = decode_base64_to_image ( mask )
populate = img2imgreq . copy ( update = { # Override __init__ params
" sd_model " : shared . sd_model ,
" sd_model " : shared . sd_model ,
" sampler_index " : sampler_index [ 0 ] ,
" do_not_save_samples " : True ,
" do_not_save_grid " : True ,
" do_not_save_grid " : True ,
" mask " : mask
}
)
@ -89,15 +120,17 @@ class Api:
p . init_images = imgs
# Override object param
before_gpu_call ( )
with self . queue_lock :
processed = process_images ( p )
after_gpu_call ( )
b64images = list ( map ( encode_pil_to_base64 , processed . images ) )
if ( not img2imgreq . include_init_images ) :
img2imgreq . init_images = None
img2imgreq . mask = None
return ImageToImageResponse ( images = b64images , parameters = vars ( img2imgreq ) , info = processed . js ( ) )
def extras_single_image_api ( self , req : ExtrasSingleImageRequest ) :
@ -125,7 +158,7 @@ class Api:
result = run_extras ( extras_mode = 1 , image = " " , input_dir = " " , output_dir = " " , * * reqDict )
return ExtrasBatchImagesResponse ( images = list ( map ( encode_pil_to_base64 , result [ 0 ] ) ) , html_info = result [ 1 ] )
def pnginfoapi ( self , req : PNGInfoRequest ) :
if ( not req . image . strip ( ) ) :
return PNGInfoResponse ( info = " " )
@ -134,6 +167,32 @@ class Api:
return PNGInfoResponse ( info = result [ 1 ] )
def progressapi ( self , req : ProgressRequest = Depends ( ) ) :
# copy from check_progress_call of ui.py
if shared . state . job_count == 0 :
return ProgressResponse ( progress = 0 , eta_relative = 0 , state = shared . state . dict ( ) )
# avoid dividing zero
progress = 0.01
if shared . state . job_count > 0 :
progress + = shared . state . job_no / shared . state . job_count
if shared . state . sampling_steps > 0 :
progress + = 1 / shared . state . job_count * shared . state . sampling_step / shared . state . sampling_steps
time_since_start = time . time ( ) - shared . state . time_start
eta = ( time_since_start / progress )
eta_relative = eta - time_since_start
progress = min ( progress , 1 )
current_image = None
if shared . state . current_image and not req . skip_current_image :
current_image = encode_pil_to_base64 ( shared . state . current_image )
return ProgressResponse ( progress = progress , eta_relative = eta_relative , state = shared . state . dict ( ) , current_image = current_image )
def launch ( self , server_name , port ) :
self . app . include_router ( self . router )
uvicorn . run ( self . app , host = server_name , port = port )