@ -5,6 +5,9 @@ import uvicorn
from threading import Lock
from gradio . processing_utils import encode_pil_to_base64 , decode_base64_to_file , decode_base64_to_image
from fastapi import APIRouter , Depends , FastAPI , HTTPException
from fastapi . security import HTTPBasic , HTTPBasicCredentials
from secrets import compare_digest
import modules . shared as shared
from modules import sd_samplers
from modules . api . models import *
@ -61,30 +64,48 @@ def encode_pil_to_base64(image):
class Api :
def __init__ ( self , app : FastAPI , queue_lock : Lock ) :
if shared . cmd_opts . api_auth :
self . credenticals = dict ( )
for auth in shared . cmd_opts . api_auth . split ( " , " ) :
user , password = auth . split ( " : " )
self . credenticals [ user ] = password
self . router = APIRouter ( )
self . app = app
self . queue_lock = queue_lock
self . app . add_api_route ( " /sdapi/v1/txt2img " , self . text2imgapi , methods = [ " POST " ] , response_model = TextToImageResponse )
self . app . add_api_route ( " /sdapi/v1/img2img " , self . img2imgapi , methods = [ " POST " ] , response_model = ImageToImageResponse )
self . app . add_api_route ( " /sdapi/v1/extra-single-image " , self . extras_single_image_api , methods = [ " POST " ] , response_model = ExtrasSingleImageResponse )
self . app . add_api_route ( " /sdapi/v1/extra-batch-images " , self . extras_batch_images_api , methods = [ " POST " ] , response_model = ExtrasBatchImagesResponse )
self . app . add_api_route ( " /sdapi/v1/png-info " , self . pnginfoapi , methods = [ " POST " ] , response_model = PNGInfoResponse )
self . app . add_api_route ( " /sdapi/v1/progress " , self . progressapi , methods = [ " GET " ] , response_model = ProgressResponse )
self . app . add_api_route ( " /sdapi/v1/interrogate " , self . interrogateapi , methods = [ " POST " ] )
self . app . add_api_route ( " /sdapi/v1/interrupt " , self . interruptapi , methods = [ " POST " ] )
self . app . add_api_route ( " /sdapi/v1/skip " , self . skip , methods = [ " POST " ] )
self . app . add_api_route ( " /sdapi/v1/options " , self . get_config , methods = [ " GET " ] , response_model = OptionsModel )
self . app . add_api_route ( " /sdapi/v1/options " , self . set_config , methods = [ " POST " ] )
self . app . add_api_route ( " /sdapi/v1/cmd-flags " , self . get_cmd_flags , methods = [ " GET " ] , response_model = FlagsModel )
self . app . add_api_route ( " /sdapi/v1/samplers " , self . get_samplers , methods = [ " GET " ] , response_model = List [ SamplerItem ] )
self . app . add_api_route ( " /sdapi/v1/upscalers " , self . get_upscalers , methods = [ " GET " ] , response_model = List [ UpscalerItem ] )
self . app . add_api_route ( " /sdapi/v1/sd-models " , self . get_sd_models , methods = [ " GET " ] , response_model = List [ SDModelItem ] )
self . app . add_api_route ( " /sdapi/v1/hypernetworks " , self . get_hypernetworks , methods = [ " GET " ] , response_model = List [ HypernetworkItem ] )
self . app . add_api_route ( " /sdapi/v1/face-restorers " , self . get_face_restorers , methods = [ " GET " ] , response_model = List [ FaceRestorerItem ] )
self . app . add_api_route ( " /sdapi/v1/realesrgan-models " , self . get_realesrgan_models , methods = [ " GET " ] , response_model = List [ RealesrganItem ] )
self . app . add_api_route ( " /sdapi/v1/prompt-styles " , self . get_promp_styles , methods = [ " GET " ] , response_model = List [ PromptStyleItem ] )
self . app . add_api_route ( " /sdapi/v1/artist-categories " , self . get_artists_categories , methods = [ " GET " ] , response_model = List [ str ] )
self . app . add_api_route ( " /sdapi/v1/artists " , self . get_artists , methods = [ " GET " ] , response_model = List [ ArtistItem ] )
self . add_api_route ( " /sdapi/v1/txt2img " , self . text2imgapi , methods = [ " POST " ] , response_model = TextToImageResponse )
self . add_api_route ( " /sdapi/v1/img2img " , self . img2imgapi , methods = [ " POST " ] , response_model = ImageToImageResponse )
self . add_api_route ( " /sdapi/v1/extra-single-image " , self . extras_single_image_api , methods = [ " POST " ] , response_model = ExtrasSingleImageResponse )
self . add_api_route ( " /sdapi/v1/extra-batch-images " , self . extras_batch_images_api , methods = [ " POST " ] , response_model = ExtrasBatchImagesResponse )
self . add_api_route ( " /sdapi/v1/png-info " , self . pnginfoapi , methods = [ " POST " ] , response_model = PNGInfoResponse )
self . add_api_route ( " /sdapi/v1/progress " , self . progressapi , methods = [ " GET " ] , response_model = ProgressResponse )
self . add_api_route ( " /sdapi/v1/interrogate " , self . interrogateapi , methods = [ " POST " ] )
self . add_api_route ( " /sdapi/v1/interrupt " , self . interruptapi , methods = [ " POST " ] )
self . add_api_route ( " /sdapi/v1/skip " , self . skip , methods = [ " POST " ] )
self . add_api_route ( " /sdapi/v1/options " , self . get_config , methods = [ " GET " ] , response_model = OptionsModel )
self . add_api_route ( " /sdapi/v1/options " , self . set_config , methods = [ " POST " ] )
self . add_api_route ( " /sdapi/v1/cmd-flags " , self . get_cmd_flags , methods = [ " GET " ] , response_model = FlagsModel )
self . add_api_route ( " /sdapi/v1/samplers " , self . get_samplers , methods = [ " GET " ] , response_model = List [ SamplerItem ] )
self . add_api_route ( " /sdapi/v1/upscalers " , self . get_upscalers , methods = [ " GET " ] , response_model = List [ UpscalerItem ] )
self . add_api_route ( " /sdapi/v1/sd-models " , self . get_sd_models , methods = [ " GET " ] , response_model = List [ SDModelItem ] )
self . add_api_route ( " /sdapi/v1/hypernetworks " , self . get_hypernetworks , methods = [ " GET " ] , response_model = List [ HypernetworkItem ] )
self . add_api_route ( " /sdapi/v1/face-restorers " , self . get_face_restorers , methods = [ " GET " ] , response_model = List [ FaceRestorerItem ] )
self . add_api_route ( " /sdapi/v1/realesrgan-models " , self . get_realesrgan_models , methods = [ " GET " ] , response_model = List [ RealesrganItem ] )
self . add_api_route ( " /sdapi/v1/prompt-styles " , self . get_promp_styles , methods = [ " GET " ] , response_model = List [ PromptStyleItem ] )
self . add_api_route ( " /sdapi/v1/artist-categories " , self . get_artists_categories , methods = [ " GET " ] , response_model = List [ str ] )
self . add_api_route ( " /sdapi/v1/artists " , self . get_artists , methods = [ " GET " ] , response_model = List [ ArtistItem ] )
def add_api_route ( self , path : str , endpoint , * * kwargs ) :
if shared . cmd_opts . api_auth :
return self . app . add_api_route ( path , endpoint , dependencies = [ Depends ( self . auth ) ] , * * kwargs )
return self . app . add_api_route ( path , endpoint , * * kwargs )
def auth ( self , credenticals : HTTPBasicCredentials = Depends ( HTTPBasic ( ) ) ) :
if credenticals . username in self . credenticals :
if compare_digest ( credenticals . password , self . credenticals [ credenticals . username ] ) :
return True
raise HTTPException ( status_code = 401 , detail = " Incorrect username or password " , headers = { " WWW-Authenticate " : " Basic " } )
def text2imgapi ( self , txt2imgreq : StableDiffusionTxt2ImgProcessingAPI ) :
populate = txt2imgreq . copy ( update = { # Override __init__ params