@@ -1,10 +1,7 @@
import re
from collections import namedtuple
import torch
from lark import Lark, Transformer, Visitor
import functools
import modules.shared as shared
import lark
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
@@ -14,25 +11,48 @@ import modules.shared as shared
# [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
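# i.e. get_learned_conditioning_prompt_schedules([prompt], 100) returns, for each input prompt,
# a list of [end_at_step, prompt_text] pairs like the one shown above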
schedule_parser = lark.Lark(r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasized | scheduled | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
        | "(" prompt ":" prompt ")"
        | "[" prompt "]"
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")
def get_learned_conditioning_prompt_schedules(prompts, steps):
    grammar = r"""
    start: prompt
    prompt: (emphasized | scheduled | weighted | plain)*
    !emphasized: "(" prompt ")"
            | "(" prompt ":" prompt ")"
            | "[" prompt "]"
    scheduled: "[" (prompt ":")? prompt ":" NUMBER "]"
    !weighted: "{" weighted_item ("|" weighted_item)* "}"
    !weighted_item: prompt (":" prompt)?
    plain: /([^\\\[\](){}:|]|\\.)+/
    %import common.SIGNED_NUMBER -> NUMBER
    """
    parser = Lark(grammar, parser='lalr')
    """
    >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
    >>> g("test")
    [[10, 'test']]
    >>> g("a [b:3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [b: 3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [[[b]]:2]")
    [[2, 'a '], [10, 'a [[b]]']]
    >>> g("[(a:2):3]")
    [[3, ''], [10, '(a:2)']]
    >>> g("a [b : c : 1] d")
    [[1, 'a b  d'], [10, 'a  c  d']]
    >>> g("a[b:[c:d:2]:1]e")
    [[1, 'abe'], [2, 'ace'], [10, 'ade']]
    >>> g("a [unbalanced")
    [[10, 'a [unbalanced']]
    >>> g("a [b:.5] c")
    [[5, 'a  c'], [10, 'a b c']]
    >>> g("a [{b|d{:.5] c")  # not handling this right now
    [[5, 'a  c'], [10, 'a {b|d{ c']]
    >>> g("((a][:b:c [d:3]")
    [[3, '((a][:b:c '], [10, '((a][:b:c d']]
    """
    def collect_steps(steps, tree):
        l = [steps]
        class CollectSteps(Visitor):
        class CollectSteps(lark.Visitor):
            def scheduled(self, tree):
                tree.children[-1] = float(tree.children[-1])
                if tree.children[-1] < 1:
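                    # a value below 1 is treated as a fraction of the total step count
                    # (e.g. [b:.5] with 10 steps switches at step 5 -- see the doctests above)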
@@ -43,13 +63,10 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
        return sorted(set(l))
    def at_step(step, tree):
        class AtStep(Transformer):
        class AtStep(lark.Transformer):
            def scheduled(self, args):
                if len(args) == 2:
                    before, after, when = (), *args
                else:
                    before, after, when = args
                yield before if step <= when else after
                before, after, _, when = args
                yield before or () if step <= when else after
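                # with the new grammar the rule always has four children: the optional "before"
                # prompt and WHITESPACE slots come through as None placeholders when absent
                # (lark's square-bracket optionals), so `before or ()` falls back to an empty
                # iterable for the short "[after:when]" form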
            def start(self, args):
                def flatten(x):
                    if type(x) == str:
@@ -57,16 +74,22 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
                    else:
                        for gen in x:
                            yield from flatten(gen)
                return ''.join(flatten(args[0]))
                return ''.join(flatten(args))
            def plain(self, args):
                yield args[0].value
            def __default__(self, data, children, meta):
                for child in children:
                    yield from child
        return AtStep().transform(tree)
    def get_schedule(prompt):
        tree = parser.parse(prompt)
        try:
            tree = schedule_parser.parse(prompt)
        except lark.exceptions.LarkError as e:
            if 0:
                import traceback
                traceback.print_exc()
            return [[steps, prompt]]
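        # otherwise produce one [end_at_step, prompt_text] pair per step boundary found in the tree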
        return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
    promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
@@ -77,8 +100,7 @@ ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at
ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
def get_learned_conditioning(prompts, steps):
def get_learned_conditioning(model, prompts, steps):
    res = []
    prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
@@ -92,7 +114,7 @@ def get_learned_conditioning(prompts, steps):
            continue
        texts = [x[1] for x in prompt_schedule]
        conds = shared.sd_model.get_learned_conditioning(texts)
        conds = model.get_learned_conditioning(texts)
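        # the model is now passed in by the caller instead of being read from modules.shared,
        # which is why the `import modules.shared` at the top of the file could be dropped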
        cond_schedule = []
        for i, (end_at_step, text) in enumerate(prompt_schedule):
@@ -105,12 +127,13 @@ def get_learned_conditioning(prompts, steps):
def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
    res = torch.zeros(c.shape, device=shared.device, dtype=next(shared.sd_model.parameters()).dtype)
    param = c.schedules[0][0].cond
    res = torch.zeros(c.shape, device=param.device, dtype=param.dtype)
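    # device and dtype are taken from the first scheduled conditioning tensor rather than
    # from shared.sd_model, so this function no longer depends on modules.shared either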
    for i, cond_schedule in enumerate(c.schedules):
        target_index = 0
        for curret_index, (end_at, cond) in enumerate(cond_schedule):
        for current, (end_at, cond) in enumerate(cond_schedule):
            if current_step <= end_at:
                target_index = curret_index
                target_index = current
                break
        res[i] = cond_schedule[target_index].cond
@@ -148,23 +171,26 @@ def parse_prompt_attention(text):
      \\ - literal character '\'
      anything else - just text
    Example:
    'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
    produces:
    [
        ['a ', 1.0],
        ['house', 1.5730000000000004],
        [' ', 1.1],
        ['on', 1.0],
        [' a ', 1.1],
        ['hill', 0.55],
        [', sun, ', 1.1],
        ['sky', 1.4641000000000006],
        ['.', 1.1]
    ]
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
"""
res = [ ]
@@ -206,4 +232,19 @@ def parse_prompt_attention(text):
    if len(res) == 0:
        res = [["", 1.0]]
    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1
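    # e.g. [['unnecessary', 1.1], ['parens', 1.1]] collapses into [['unnecessaryparens', 1.1]],
    # as exercised by the '(unnecessary)(parens)' doctest above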
    return res
if __name__ == " __main__ " :
import doctest
doctest . testmod ( optionflags = doctest . NORMALIZE_WHITESPACE )
else :
import torch # doctest faster
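# running the file directly executes the doctests above; torch (which is slow to import) is only
# loaded when the module is imported normally, which keeps the doctest run fast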