@@ -1,7 +1,14 @@
import math
import torch
from torch import einsum

try:
    import xformers.ops
    import functorch
    xformers._is_functorch_available = True
    shared.xformers_available = True
except Exception:
    print('Cannot find xformers, defaulting to split attention. Try setting --xformers in your webui-user file if you wish to install it.')
    shared.xformers_available = False  # flag checked when choosing the attention implementation

from ldm.util import default
from einops import rearrange
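# Illustration only (not part of this diff): a minimal sketch of how the
# availability flag set above might be used to pick which forward to hijack.
# The apply_optimizations() name and the CrossAttention assignment are
# assumptions; this patch does not show that wiring.
import ldm.modules.attention


def apply_optimizations():
    if shared.xformers_available:
        ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
    else:
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward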
@@ -115,6 +122,25 @@ def split_cross_attention_forward(self, x, context=None, mask=None):

    return self.to_out(r2)


def xformers_attention_forward(self, x, context=None, mask=None):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    # If a hypernetwork is selected and has layers for this context dimension,
    # run the context through them before projecting to keys/values.
    hypernetwork = shared.selected_hypernetwork()
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
    if hypernetwork_layers is not None:
        k_in = self.to_k(hypernetwork_layers[0](context))
        v_in = self.to_v(hypernetwork_layers[1](context))
    else:
        k_in = self.to_k(context)
        v_in = self.to_v(context)

    # Split heads into their own dimension: (batch, tokens, heads, head_dim).
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)
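# Illustration only (not part of this diff): memory_efficient_attention consumes
# the (batch, tokens, heads, head_dim) layout produced by the rearrange above and
# returns a tensor of the same shape. The sizes below are arbitrary.
def _xformers_layout_example():
    q = torch.randn(2, 4096, 8, 40, device='cuda', dtype=torch.float16)
    k = torch.randn(2, 4096, 8, 40, device='cuda', dtype=torch.float16)
    v = torch.randn(2, 4096, 8, 40, device='cuda', dtype=torch.float16)
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
    assert out.shape == (2, 4096, 8, 40)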


def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
@@ -177,3 +203,13 @@ def cross_attention_attnblock_forward(self, x):
    h3 += x

    return h3


def xformers_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape

    # memory_efficient_attention expects token-major inputs rather than conv
    # feature maps, so flatten the spatial dimensions to (batch, h*w, channels)
    # before attending and restore the original layout afterwards.
    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c').contiguous(), (q, k, v))
    out = xformers.ops.memory_efficient_attention(q, k, v)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)

    out = self.proj_out(out)
    return x + out
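

# Illustration only (not part of this diff): with the spatial flatten used in
# xformers_attnblock_forward, the attention inputs are 3-d
# (batch, h*w tokens, channels), which memory_efficient_attention treats as
# single-head attention. The feature-map size below is arbitrary.
def _xformers_attnblock_layout_example():
    feat = torch.randn(1, 64, 32, 32, device='cuda', dtype=torch.float16)  # b c h w
    tokens = rearrange(feat, 'b c h w -> b (h w) c').contiguous()          # b (h w) c
    out = xformers.ops.memory_efficient_attention(tokens, tokens, tokens)
    assert out.shape == (1, 1024, 64)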