@ -568,6 +568,24 @@ def create_ui(wrap_gradio_gpu_call):
import modules . img2img
import modules . txt2img
def create_refresh_button ( refresh_component , refresh_method , refreshed_args , elem_id ) :
def refresh ( ) :
refresh_method ( )
args = refreshed_args ( ) if callable ( refreshed_args ) else refreshed_args
for k , v in args . items ( ) :
setattr ( refresh_component , k , v )
return gr . update ( * * ( args or { } ) )
refresh_button = gr . Button ( value = refresh_symbol , elem_id = elem_id )
refresh_button . click (
fn = refresh ,
inputs = [ ] ,
outputs = [ refresh_component ]
)
return refresh_button
with gr . Blocks ( analytics_enabled = False ) as txt2img_interface :
txt2img_prompt , roll , txt2img_prompt_style , txt2img_negative_prompt , txt2img_prompt_style2 , submit , _ , _ , txt2img_prompt_style_apply , txt2img_save_style , txt2img_paste , token_counter , token_button = create_toprow ( is_img2img = False )
dummy_component = gr . Label ( visible = False )
@ -1205,8 +1223,12 @@ def create_ui(wrap_gradio_gpu_call):
with gr . Tab ( label = " Train " ) :
gr . HTML ( value = " <p style= ' margin-bottom: 0.7em ' >Train an embedding; must specify a directory with a set of 1:1 ratio images</p> " )
train_embedding_name = gr . Dropdown ( label = ' Embedding ' , choices = sorted ( sd_hijack . model_hijack . embedding_db . word_embeddings . keys ( ) ) )
train_hypernetwork_name = gr . Dropdown ( label = ' Hypernetwork ' , choices = [ x for x in shared . hypernetworks . keys ( ) ] )
with gr . Row ( ) :
train_embedding_name = gr . Dropdown ( label = ' Embedding ' , choices = sorted ( sd_hijack . model_hijack . embedding_db . word_embeddings . keys ( ) ) )
create_refresh_button ( train_embedding_name , sd_hijack . model_hijack . embedding_db . load_textual_inversion_embeddings , lambda : { " choices " : sorted ( sd_hijack . model_hijack . embedding_db . word_embeddings . keys ( ) ) } , " refresh_train_embedding_name " )
with gr . Row ( ) :
train_hypernetwork_name = gr . Dropdown ( label = ' Hypernetwork ' , choices = [ x for x in shared . hypernetworks . keys ( ) ] )
create_refresh_button ( train_hypernetwork_name , shared . reload_hypernetworks , lambda : { " choices " : sorted ( [ x for x in shared . hypernetworks . keys ( ) ] ) } , " refresh_train_hypernetwork_name " )
learn_rate = gr . Textbox ( label = ' Learning rate ' , placeholder = " Learning rate " , value = " 0.005 " )
batch_size = gr . Number ( label = ' Batch size ' , value = 1 , precision = 0 )
dataset_directory = gr . Textbox ( label = ' Dataset directory ' , placeholder = " Path to directory with input images " )
@ -1357,26 +1379,11 @@ def create_ui(wrap_gradio_gpu_call):
if info . refresh is not None :
if is_quicksettings :
res = comp ( label = info . label , value = fun , * * ( args or { } ) )
refresh_button = gr. Button ( value = refresh_symbol , elem_id = " refresh_ " + key )
refresh_button = create_refresh_button( res , info . refresh , info . component_args , " refresh_ " + key )
else :
with gr . Row ( variant = " compact " ) :
res = comp ( label = info . label , value = fun , * * ( args or { } ) )
refresh_button = gr . Button ( value = refresh_symbol , elem_id = " refresh_ " + key )
def refresh ( ) :
info . refresh ( )
refreshed_args = info . component_args ( ) if callable ( info . component_args ) else info . component_args
for k , v in refreshed_args . items ( ) :
setattr ( res , k , v )
return gr . update ( * * ( refreshed_args or { } ) )
refresh_button . click (
fn = refresh ,
inputs = [ ] ,
outputs = [ res ] ,
)
refresh_button = create_refresh_button ( res , info . refresh , info . component_args , " refresh_ " + key )
else :
res = comp ( label = info . label , value = fun , * * ( args or { } ) )