@@ -97,10 +97,26 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
 ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
-ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
 
 
 def get_learned_conditioning(model, prompts, steps):
+    """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
+    and the sampling step at which this condition is to be replaced by the next one.
+
+    Input:
+    (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
+
+    Output:
+    [
+        [
+            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
+        ],
+        [
+            ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
+            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
+        ]
+    ]
+    """
     res = []
 
     prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
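For orientation, a small illustrative sketch (not part of the diff) of the scheduling behaviour the new docstring describes; the exact return format of get_learned_conditioning_prompt_schedules is an assumption inferred from that docstring:

    # assumed output of the schedule parser for the docstring's second prompt:
    schedules = get_learned_conditioning_prompt_schedules(['a [blue:green:5] jeweled crown'], 20)
    # schedules[0] is assumed to look like:
    #   [[5, 'a blue jeweled crown'], [20, 'a green jeweled crown']]
    # i.e. the 'blue' text conditions steps 1-5 and the 'green' text steps 6-20,
    # which is why the second prompt yields two ScheduledPromptConditioning entries.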
@@ -123,13 +139,75 @@ def get_learned_conditioning(model, prompts, steps):
         cache[prompt] = cond_schedule
         res.append(cond_schedule)
 
-    return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res)
+    return res
+
+
+re_AND = re.compile(r"\bAND\b")
+re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?\s*(?:\d+|\d*\.\d+)?))?\s*$")
+
+
+def get_multicond_prompt_list(prompts):
+    res_indexes = []
+
+    prompt_flat_list = []
+    prompt_indexes = {}
+
+    for prompt in prompts:
+        subprompts = re_AND.split(prompt)
+
+        indexes = []
+        for subprompt in subprompts:
+            text, weight = re_weight.search(subprompt).groups()
+            weight = float(weight) if weight is not None else 1.0
+
+            index = prompt_indexes.get(text, None)
+            if index is None:
+                index = len(prompt_flat_list)
+                prompt_flat_list.append(text)
+                prompt_indexes[text] = index
+
+            indexes.append((index, weight))
+
+        res_indexes.append(indexes)
+
+    return res_indexes, prompt_flat_list, prompt_indexes
+
+
+class ComposableScheduledPromptConditioning:
+    def __init__(self, schedules, weight=1.0):
+        self.schedules: list[ScheduledPromptConditioning] = schedules
+        self.weight: float = weight
+
+
+class MulticondLearnedConditioning:
+    def __init__(self, shape, batch):
+        self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
+        self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
 
 
-def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
-    param = c.schedules[0][0].cond
-    res = torch.zeros(c.shape, device=param.device, dtype=param.dtype)
-    for i, cond_schedule in enumerate(c.schedules):
+def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
+    """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
+    For each prompt, the list is obtained by splitting the prompt using the AND separator.
+
+    https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
+    """
+
+    res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
+
+    learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
+
+    res = []
+    for indexes in res_indexes:
+        res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
+
+    return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
+
+
+def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
+    param = c[0][0].cond
+    res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+    for i, cond_schedule in enumerate(c):
         target_index = 0
         for current, (end_at, cond) in enumerate(cond_schedule):
             if current_step <= end_at:
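To make the new AND syntax concrete, a quick sketch of how the two regexes are expected to behave (sample prompt invented; not part of the diff):

    re_AND.split("a red crown AND a jeweled crown:0.5")
    # -> ['a red crown ', ' a jeweled crown:0.5']
    re_weight.search("a jeweled crown:0.5").groups()
    # -> ('a jeweled crown', '0.5')
    re_weight.search("a red crown").groups()
    # -> ('a red crown', None)   # weight defaults to 1.0 in get_multicond_prompt_list

get_multicond_prompt_list deduplicates subprompt texts through prompt_indexes, so each distinct text is encoded by the model only once even if it occurs in several prompts; get_multicond_learned_conditioning then rebuilds the per-prompt structure from the stored (index, weight) pairs.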
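Likewise, a sketch of the per-step lookup in the new reconstruct_cond_batch (cond names hypothetical):

    # with a schedule of [ScheduledPromptConditioning(end_at_step=5, cond=blue_cond),
    #                     ScheduledPromptConditioning(end_at_step=20, cond=green_cond)]:
    # current_step = 3  -> 3 <= 5 on the first entry,   so target_index = 0 (blue_cond)
    # current_step = 12 -> 12 > 5, 12 <= 20 on the second, so target_index = 1 (green_cond)

Each batch row can follow a different schedule, so this lookup runs once per row before the chosen cond is written into the preallocated res tensor.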
@@ -140,6 +218,30 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
     return res
 
 
+def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
+    param = c.batch[0][0].schedules[0].cond
+
+    tensors = []
+    conds_list = []
+
+    for batch_no, composable_prompts in enumerate(c.batch):
+        conds_for_batch = []
+
+        for cond_index, composable_prompt in enumerate(composable_prompts):
+            target_index = 0
+            for current, (end_at, cond) in enumerate(composable_prompt.schedules):
+                if current_step <= end_at:
+                    target_index = current
+                    break
+
+            conds_for_batch.append((len(tensors), composable_prompt.weight))
+            tensors.append(composable_prompt.schedules[target_index].cond)
+
+        conds_list.append(conds_for_batch)
+
+    return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+
+
 re_attention = re.compile(r"""
 \\\(|
 \\\)|