@@ -12,13 +12,22 @@ from modules import shared
 def split_cross_attention_forward_v1(self, x, context=None, mask=None):
     h = self.heads
 
-    q = self.to_q(x)
+    q_in = self.to_q(x)
     context = default(context, x)
-    k = self.to_k(context)
-    v = self.to_v(context)
+
+    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+    if hypernetwork_layers is not None:
+        k_in = self.to_k(hypernetwork_layers[0](context))
+        v_in = self.to_v(hypernetwork_layers[1](context))
+    else:
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+
     del context, x
 
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+    del q_in, k_in, v_in
 
     r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
     for i in range(0, q.shape[0], 2):
@@ -31,6 +40,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
         r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
         del s2
 
+    del q, k, v
     r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
     del r1
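
Context for reviewers, not part of the patch: the first hunk routes the key/value projections through an optional hypernetwork before to_k/to_v, selected by the context's channel count (context.shape[2]), while del q_in, k_in, v_in and the later del q, k, v drop intermediates early to keep peak VRAM down in this split-attention path. Below is a minimal, self-contained sketch of that selection logic, using plain nn.Linear stand-ins for the hypernetwork modules; the dimensions and the hypernetwork_layers_by_dim name are illustrative assumptions, not taken from the patch.

import torch
import torch.nn as nn

# Stand-in projections; in the real module these are CrossAttention's to_k/to_v.
to_k = nn.Linear(768, 320, bias=False)
to_v = nn.Linear(768, 320, bias=False)

# The hypernetwork is modeled here as a dict keyed by context dim, mapping to a
# pair of modules that transform the context before the k/v projections.
hypernetwork_layers_by_dim = {768: (nn.Linear(768, 768), nn.Linear(768, 768))}

context = torch.randn(2, 77, 768)  # (batch, tokens, channels)
layers = hypernetwork_layers_by_dim.get(context.shape[2], None)
if layers is not None:
    k_in = to_k(layers[0](context))  # keys from hypernetwork-modulated context
    v_in = to_v(layers[1](context))  # values from hypernetwork-modulated context
else:
    k_in = to_k(context)             # fallback: unmodified context
    v_in = to_v(context)
print(k_in.shape, v_in.shape)  # torch.Size([2, 77, 320]) torch.Size([2, 77, 320])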