@ -1,20 +1,11 @@
import re
import re
from collections import namedtuple
from collections import namedtuple
import torch
import torch
from lark import Lark , Transformer , Visitor
import functools
import modules . shared as shared
import modules . shared as shared
re_prompt = re . compile ( r '''
( . * ? )
\[
( [ ^ ] : ] + ) :
( ? : ( [ ^ ] : ] * ) : ) ?
( [ 0 - 9 ] * \. ? [ 0 - 9 ] + )
]
|
( . + )
''' , re.X)
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
@ -25,61 +16,57 @@ re_prompt = re.compile(r'''
def get_learned_conditioning_prompt_schedules ( prompts , steps ) :
def get_learned_conditioning_prompt_schedules ( prompts , steps ) :
res = [ ]
grammar = r """
cache = { }
start : prompt
prompt : ( emphasized | scheduled | weighted | plain ) *
for prompt in prompts :
! emphasized : " ( " prompt " ) "
prompt_schedule : list [ list [ str | int ] ] = [ [ steps , " " ] ]
| " ( " prompt " : " prompt " ) "
| " [ " prompt " ] "
cached = cache . get ( prompt , None )
scheduled : " [ " ( prompt " : " ) ? prompt " : " NUMBER " ] "
if cached is not None :
! weighted : " { " weighted_item ( " | " weighted_item ) * " } "
res . append ( cached )
! weighted_item : prompt ( " : " prompt ) ?
continue
plain : / ( [ ^ \\\[ \] ( ) { } : | ] | \\. ) + /
% import common . SIGNED_NUMBER - > NUMBER
for m in re_prompt . finditer ( prompt ) :
"""
plaintext = m . group ( 1 ) if m . group ( 5 ) is None else m . group ( 5 )
parser = Lark ( grammar , parser = ' lalr ' )
concept_from = m . group ( 2 )
def collect_steps ( steps , tree ) :
concept_to = m . group ( 3 )
l = [ steps ]
if concept_to is None :
class CollectSteps ( Visitor ) :
concept_to = concept_from
def scheduled ( self , tree ) :
concept_from = " "
tree . children [ - 1 ] = float ( tree . children [ - 1 ] )
swap_position = float ( m . group ( 4 ) ) if m . group ( 4 ) is not None else None
if tree . children [ - 1 ] < 1 :
tree . children [ - 1 ] * = steps
if swap_position is not None :
tree . children [ - 1 ] = min ( steps , int ( tree . children [ - 1 ] ) )
if swap_position < 1 :
l . append ( tree . children [ - 1 ] )
swap_position = swap_position * steps
CollectSteps ( ) . visit ( tree )
swap_position = int ( min ( swap_position , steps ) )
return sorted ( set ( l ) )
def at_step ( step , tree ) :
swap_index = None
class AtStep ( Transformer ) :
found_exact_index = False
def scheduled ( self , args ) :
for i in range ( len ( prompt_schedule ) ) :
if len ( args ) == 2 :
end_step = prompt_schedule [ i ] [ 0 ]
before , after , when = ( ) , * args
prompt_schedule [ i ] [ 1 ] + = plaintext
else :
before , after , when = args
if swap_position is not None and swap_index is None :
yield before if step < = when else after
if swap_position == end_step :
def start ( self , args ) :
swap_index = i
def flatten ( x ) :
found_exact_index = True
if type ( x ) == str :
yield x
if swap_position < end_step :
else :
swap_index = i
for gen in x :
yield from flatten ( gen )
if swap_index is not None :
return ' ' . join ( flatten ( args [ 0 ] ) )
if not found_exact_index :
def plain ( self , args ) :
prompt_schedule . insert ( swap_index , [ swap_position , prompt_schedule [ swap_index ] [ 1 ] ] )
yield args [ 0 ] . value
def __default__ ( self , data , children , meta ) :
for i in range ( len ( prompt_schedule ) ) :
for child in children :
end_step = prompt_schedule [ i ] [ 0 ]
yield from child
must_replace = swap_position < end_step
return AtStep ( ) . transform ( tree )
@functools.cache
prompt_schedule [ i ] [ 1 ] + = concept_to if must_replace else concept_from
def get_schedule ( prompt ) :
tree = parser . parse ( prompt )
res . append ( prompt_schedule )
return [ [ t , at_step ( t , tree ) ] for t in collect_steps ( steps , tree ) ]
cache [ prompt ] = prompt_schedule
return [ get_schedule ( prompt ) for prompt in prompts ]
#for t in prompt_schedule:
# print(t)
return res
ScheduledPromptConditioning = namedtuple ( " ScheduledPromptConditioning " , [ " end_at_step " , " cond " ] )
ScheduledPromptConditioning = namedtuple ( " ScheduledPromptConditioning " , [ " end_at_step " , " cond " ] )