@ -2,14 +2,17 @@ import base64
import io
import time
import uvicorn
from gradio . processing_utils import decode_base64_to_file , decode_base64_to_image
from fastapi import APIRouter , Depends , HTTPException
from threading import Lock
from gradio . processing_utils import encode_pil_to_base64 , decode_base64_to_file , decode_base64_to_image
from fastapi import APIRouter , Depends , FastAPI , HTTPException
import modules . shared as shared
from modules . api . models import *
from modules . processing import StableDiffusionProcessingTxt2Img , StableDiffusionProcessingImg2Img , process_images
from modules . sd_samplers import all_samplers , sample_to_image , samples_to_image_grid
from modules . sd_samplers import all_samplers
from modules . extras import run_extras , run_pnginfo
from modules . sd_models import checkpoints_list
from modules . realesrgan_model import get_realesrgan_models
from typing import List
def upscaler_to_index ( name : str ) :
try :
@ -37,7 +40,7 @@ def encode_pil_to_base64(image):
class Api :
def __init__ ( self , app , queue_l ock) :
def __init__ ( self , app : FastAPI , queue_l ock: L ock) :
self . router = APIRouter ( )
self . app = app
self . queue_lock = queue_lock
@ -48,6 +51,19 @@ class Api:
self . app . add_api_route ( " /sdapi/v1/png-info " , self . pnginfoapi , methods = [ " POST " ] , response_model = PNGInfoResponse )
self . app . add_api_route ( " /sdapi/v1/progress " , self . progressapi , methods = [ " GET " ] , response_model = ProgressResponse )
self . app . add_api_route ( " /sdapi/v1/interrupt " , self . interruptapi , methods = [ " POST " ] )
self . app . add_api_route ( " /sdapi/v1/options " , self . get_config , methods = [ " GET " ] , response_model = OptionsModel )
self . app . add_api_route ( " /sdapi/v1/options " , self . set_config , methods = [ " POST " ] )
self . app . add_api_route ( " /sdapi/v1/cmd-flags " , self . get_cmd_flags , methods = [ " GET " ] , response_model = FlagsModel )
self . app . add_api_route ( " /sdapi/v1/info " , self . get_info , methods = [ " GET " ] )
self . app . add_api_route ( " /sdapi/v1/samplers " , self . get_samplers , methods = [ " GET " ] , response_model = List [ SamplerItem ] )
self . app . add_api_route ( " /sdapi/v1/upscalers " , self . get_upscalers , methods = [ " GET " ] , response_model = List [ UpscalerItem ] )
self . app . add_api_route ( " /sdapi/v1/sd-models " , self . get_sd_models , methods = [ " GET " ] , response_model = List [ SDModelItem ] )
self . app . add_api_route ( " /sdapi/v1/hypernetworks " , self . get_hypernetworks , methods = [ " GET " ] , response_model = List [ HypernetworkItem ] )
self . app . add_api_route ( " /sdapi/v1/face-restorers " , self . get_face_restorers , methods = [ " GET " ] , response_model = List [ FaceRestorerItem ] )
self . app . add_api_route ( " /sdapi/v1/realesrgan-models " , self . get_realesrgan_models , methods = [ " GET " ] , response_model = List [ RealesrganItem ] )
self . app . add_api_route ( " /sdapi/v1/prompt-styles " , self . get_promp_styles , methods = [ " GET " ] , response_model = List [ PromptStyleItem ] )
self . app . add_api_route ( " /sdapi/v1/artist-categories " , self . get_artists_categories , methods = [ " GET " ] , response_model = List [ str ] )
self . app . add_api_route ( " /sdapi/v1/artists " , self . get_artists , methods = [ " GET " ] , response_model = List [ ArtistItem ] )
def text2imgapi ( self , txt2imgreq : StableDiffusionTxt2ImgProcessingAPI ) :
sampler_index = sampler_to_index ( txt2imgreq . sampler_index )
@ -190,6 +206,77 @@ class Api:
shared . state . interrupt ( )
return { }
def get_config ( self ) :
options = { }
for key in shared . opts . data . keys ( ) :
metadata = shared . opts . data_labels . get ( key )
if ( metadata is not None ) :
options . update ( { key : shared . opts . data . get ( key , shared . opts . data_labels . get ( key ) . default ) } )
else :
options . update ( { key : shared . opts . data . get ( key , None ) } )
return options
def set_config ( self , req : OptionsModel ) :
reqDict = vars ( req )
for o in reqDict :
setattr ( shared . opts , o , reqDict [ o ] )
shared . opts . save ( shared . config_filename )
return
def get_cmd_flags ( self ) :
return vars ( shared . cmd_opts )
def get_info ( self ) :
return {
" hypernetworks " : [ { " name " : name , " path " : shared . hypernetworks [ name ] } for name in shared . hypernetworks ] ,
" face_restorers " : [ { " name " : x . name ( ) , " cmd_dir " : getattr ( x , " cmd_dir " , None ) } for x in shared . face_restorers ] ,
" realesrgan_models " : [ { " name " : x . name , " path " : x . data_path , " scale " : x . scale } for x in get_realesrgan_models ( None ) ] ,
" promp_styles " : [ shared . prompt_styles . styles [ k ] for k in shared . prompt_styles . styles ] ,
" artists_categories " : shared . artist_db . cats ,
# "artists": [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
}
def get_samplers ( self ) :
return [ { " name " : sampler [ 0 ] , " aliases " : sampler [ 2 ] , " options " : sampler [ 3 ] } for sampler in all_samplers ]
def get_upscalers ( self ) :
upscalers = [ ]
for upscaler in shared . sd_upscalers :
u = upscaler . scaler
upscalers . append ( { " name " : u . name , " model_name " : u . model_name , " model_path " : u . model_path , " model_url " : u . model_url } )
return upscalers
def get_sd_models ( self ) :
return [ { " title " : x . title , " model_name " : x . model_name , " hash " : x . hash , " filename " : x . filename , " config " : x . config } for x in checkpoints_list . values ( ) ]
def get_hypernetworks ( self ) :
return [ { " name " : name , " path " : shared . hypernetworks [ name ] } for name in shared . hypernetworks ]
def get_face_restorers ( self ) :
return [ { " name " : x . name ( ) , " cmd_dir " : getattr ( x , " cmd_dir " , None ) } for x in shared . face_restorers ]
def get_realesrgan_models ( self ) :
return [ { " name " : x . name , " path " : x . data_path , " scale " : x . scale } for x in get_realesrgan_models ( None ) ]
def get_promp_styles ( self ) :
styleList = [ ]
for k in shared . prompt_styles . styles :
style = shared . prompt_styles . styles [ k ]
styleList . append ( { " name " : style [ 0 ] , " prompt " : style [ 1 ] , " negative_prompr " : style [ 2 ] } )
return styleList
def get_artists_categories ( self ) :
return shared . artist_db . cats
def get_artists ( self ) :
return [ { " name " : x [ 0 ] , " score " : x [ 1 ] , " category " : x [ 2 ] } for x in shared . artist_db . artists ]
def launch ( self , server_name , port ) :
self . app . include_router ( self . router )