@ -1,6 +1,6 @@
import re
from collections import namedtuple
from typing import List
import lark
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
@ -175,15 +175,14 @@ def get_multicond_prompt_list(prompts):
class ComposableScheduledPromptConditioning :
def __init__ ( self , schedules , weight = 1.0 ) :
self . schedules = schedules # : list[ScheduledPromptConditioning]
self . schedules : List [ ScheduledPromptConditioning ] = schedules
self . weight : float = weight
class MulticondLearnedConditioning :
def __init__ ( self , shape , batch ) :
self . shape : tuple = shape # the shape field is needed to send this object to DDIM/PLMS
self . batch = batch # : list[list[ComposableScheduledPromptConditioning]]
self . batch : List [ List [ ComposableScheduledPromptConditioning ] ] = batch
def get_multicond_learned_conditioning ( model , prompts , steps ) - > MulticondLearnedConditioning :
""" same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
@ -203,7 +202,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
return MulticondLearnedConditioning ( shape = ( len ( prompts ) , ) , batch = res )
def reconstruct_cond_batch ( c , current_step ) : # c: list[list[ScheduledPromptConditioning]]
def reconstruct_cond_batch ( c : List [ List [ ScheduledPromptConditioning ] ] , current_step ) :
param = c [ 0 ] [ 0 ] . cond
res = torch . zeros ( ( len ( c ) , ) + param . shape , device = param . device , dtype = param . dtype )
for i , cond_schedule in enumerate ( c ) :