@ -1,5 +1,6 @@
from collections import namedtuple
from copy import copy
from itertools import permutations
import random
from PIL import Image
@ -29,6 +30,31 @@ def apply_prompt(p, x, xs):
p . negative_prompt = p . negative_prompt . replace ( xs [ 0 ] , x )
def apply_order ( p , x , xs ) :
token_order = [ ]
# Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
for token in x :
token_order . append ( ( p . prompt . find ( token ) , token ) )
token_order . sort ( key = lambda t : t [ 0 ] )
prompt_parts = [ ]
# Split the prompt up, taking out the tokens
for _ , token in token_order :
n = p . prompt . find ( token )
prompt_parts . append ( p . prompt [ 0 : n ] )
p . prompt = p . prompt [ n + len ( token ) : ]
# Rebuild the prompt with the tokens in the order we want
prompt_tmp = " "
for idx , part in enumerate ( prompt_parts ) :
prompt_tmp + = part
prompt_tmp + = x [ idx ]
p . prompt = prompt_tmp + p . prompt
samplers_dict = { }
for i , sampler in enumerate ( modules . sd_samplers . samplers ) :
samplers_dict [ sampler . name . lower ( ) ] = i
@ -60,16 +86,26 @@ def format_value_add_label(p, opt, x):
def format_value ( p , opt , x ) :
if type ( x ) == float :
x = round ( x , 8 )
return x
def format_value_join_list ( p , opt , x ) :
return " , " . join ( x )
def do_nothing ( p , x , xs ) :
pass
def format_nothing ( p , opt , x ) :
return " "
def str_permutations ( x ) :
""" dummy function for specifying it in AxisOption ' s type when you want to get a list of permutations """
return x
AxisOption = namedtuple ( " AxisOption " , [ " label " , " type " , " apply " , " format_value " ] )
AxisOptionImg2Img = namedtuple ( " AxisOptionImg2Img " , [ " label " , " type " , " apply " , " format_value " ] )
@ -82,6 +118,7 @@ axis_options = [
AxisOption ( " Steps " , int , apply_field ( " steps " ) , format_value_add_label ) ,
AxisOption ( " CFG Scale " , float , apply_field ( " cfg_scale " ) , format_value_add_label ) ,
AxisOption ( " Prompt S/R " , str , apply_prompt , format_value ) ,
AxisOption ( " Prompt order " , str_permutations , apply_order , format_value_join_list ) ,
AxisOption ( " Sampler " , str , apply_sampler , format_value ) ,
AxisOption ( " Checkpoint name " , str , apply_checkpoint , format_value ) ,
AxisOption ( " Sigma Churn " , float , apply_field ( " s_churn " ) , format_value_add_label ) ,
@ -131,6 +168,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d
re_range_count = re . compile ( r " \ s*([+-]? \ s* \ d+) \ s*- \ s*([+-]? \ s* \ d+)(?: \ s* \ [( \ d+) \ s* \ ])? \ s* " )
re_range_count_float = re . compile ( r " \ s*([+-]? \ s* \ d+(?:. \ d*)?) \ s*- \ s*([+-]? \ s* \ d+(?:. \ d*)?)(?: \ s* \ [( \ d+(?:. \ d*)?) \ s* \ ])? \ s* " )
class Script ( scripts . Script ) :
def title ( self ) :
return " X/Y plot "
@ -206,6 +244,8 @@ class Script(scripts.Script):
valslist_ext . append ( val )
valslist = valslist_ext
elif opt . type == str_permutations :
valslist = list ( permutations ( valslist ) )
valslist = [ opt . type ( x ) for x in valslist ]