@@ -9,7 +9,7 @@ from torch import einsum
 from ldm.util import default
 from einops import rearrange
 
-from modules import shared, errors
+from modules import shared, errors, devices
 from modules.hypernetworks import hypernetwork
 
 from .sub_quadratic_attention import efficient_dot_product_attention
@@ -52,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
     del q_in, k_in, v_in
 
-    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
-    for i in range(0, q.shape[0], 2):
-        end = i + 2
-        s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
-        s1 *= self.scale
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k, v = q.float(), k.float(), v.float()
 
-        s2 = s1.softmax(dim=-1)
-        del s1
+    with devices.without_autocast(disable=not shared.opts.upcast_attn):
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[0], 2):
+            end = i + 2
+            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+            s1 *= self.scale
+
+            s2 = s1.softmax(dim=-1)
+            del s1
 
-        r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
-        del s2
-
-    del q, k, v
+            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+            del s2
+
+        del q, k, v
+    r1 = r1.to(dtype)
+
     r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
     del r1
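For reference, every forward touched by this patch computes the same quantity, softmax(q·kᵀ·scale)·v; the loop above just does it two (b h) rows at a time to bound memory. A minimal, self-contained sketch using the same einsum patterns, with made-up shapes:

import torch
from torch import einsum

q = torch.randn(2 * 8, 64, 40)   # (batch*heads, q_tokens, head_dim), hypothetical sizes
k = torch.randn(2 * 8, 77, 40)   # (batch*heads, k_tokens, head_dim)
v = torch.randn(2 * 8, 77, 40)
scale = q.shape[-1] ** -0.5      # stand-in for self.scale

s1 = einsum('b i d, b j d -> b i j', q, k) * scale   # attention scores
s2 = s1.softmax(dim=-1)                              # attention weights
r1 = einsum('b i j, b j d -> b i d', s2, v)          # weighted sum of values
print(r1.shape)                                      # torch.Size([16, 64, 40])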
@@ -82,45 +89,52 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     k_in = self.to_k(context_k)
     v_in = self.to_v(context_v)
 
-    k_in *= self.scale
-
-    del context, x
-
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-    del q_in, k_in, v_in
-
-    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-
-    mem_free_total = get_available_vram()
-
-    gb = 1024 ** 3
-    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
-    modifier = 3 if q.element_size() == 2 else 2.5
-    mem_required = tensor_size * modifier
-
-    steps = 1
-
-    if mem_required > mem_free_total:
-        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
-        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
-        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+    dtype = q_in.dtype
+    if shared.opts.upcast_attn:
+        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
 
-    if steps > 64:
-        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
-        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-                           f'Need: {mem_required / 64 / gb:0.1f}GB free, Have: {mem_free_total / gb:0.1f}GB free')
-
-    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
-    for i in range(0, q.shape[1], slice_size):
-        end = i + slice_size
-        s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
-
-        s2 = s1.softmax(dim=-1, dtype=q.dtype)
-        del s1
-
-        r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-        del s2
+    with devices.without_autocast(disable=not shared.opts.upcast_attn):
+        k_in = k_in * self.scale
+
+        del context, x
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+        del q_in, k_in, v_in
+
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+
+        mem_free_total = get_available_vram()
+
+        gb = 1024 ** 3
+        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+        modifier = 3 if q.element_size() == 2 else 2.5
+        mem_required = tensor_size * modifier
+
+        steps = 1
+
+        if mem_required > mem_free_total:
+            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
+            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+
+        if steps > 64:
+            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have: {mem_free_total / gb:0.1f}GB free')
+
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
+
+            s2 = s1.softmax(dim=-1, dtype=q.dtype)
+            del s1
+
+            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+            del s2
 
-    del q, k, v
+        del q, k, v
+    r1 = r1.to(dtype)
+
     r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
     del r1
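As a side note, the slicing heuristic re-indented above splits the score tensor along the query-token axis until its estimated size fits in free VRAM. A small worked example of that arithmetic with made-up numbers (get_available_vram() is replaced by a constant here):

import math

gb = 1024 ** 3
q_shape = (16, 4096, 40)            # (batch*heads, q_tokens, head_dim), hypothetical
k_tokens = 4096
element_size = 2                    # bytes per element for float16
tensor_size = q_shape[0] * q_shape[1] * k_tokens * element_size
modifier = 3 if element_size == 2 else 2.5
mem_required = tensor_size * modifier
mem_free_total = gb // 2            # pretend only 0.5 GB is free

steps = 1
if mem_required > mem_free_total:
    steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))  # round up to a power of two

slice_size = q_shape[1] // steps if (q_shape[1] % steps) == 0 else q_shape[1]
print(f"need {mem_required / gb:0.1f}GB, have {mem_free_total / gb:0.1f}GB, "
      f"steps={steps}, slice_size={slice_size}")  # steps=4, slice_size=1024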
@@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
     context = default(context, x)
 
     context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
-    k = self.to_k(context_k) * self.scale
+    k = self.to_k(context_k)
     v = self.to_v(context_v)
     del context, context_k, context_v, x
 
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-    r = einsum_op(q, k, v)
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
+
+    with devices.without_autocast(disable=not shared.opts.upcast_attn):
+        k = k * self.scale
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+        r = einsum_op(q, k, v)
+    r = r.to(dtype)
     return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
 
 # -- End of code from https://github.com/invoke-ai/InvokeAI --
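Note: devices.without_autocast(...) is what keeps the float32 upcast from being silently downcast again inside the block. A standalone sketch of the pattern, under the assumption that the helper simply turns autocast off (or is a no-op when disabled); the real implementation lives in modules/devices.py:

import contextlib
import torch

def without_autocast(disable=False):
    # Assumed behaviour: no-op unless autocast is active and disabling was not requested.
    if disable or not torch.is_autocast_enabled():
        return contextlib.nullcontext()
    return torch.autocast('cuda', enabled=False)

q = torch.randn(16, 64, 40, dtype=torch.float16)
dtype = q.dtype
upcast = True                                  # stands in for shared.opts.upcast_attn
if upcast:
    q = q.float()                              # do the attention math in float32
with without_autocast(disable=not upcast):
    r = q.softmax(dim=-1)                      # stays float32 inside the block
r = r.to(dtype)                                # hand the caller back its original dtype
print(r.dtype)                                 # torch.float16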
@@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
     k = k.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
     v = v.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
 
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k = q.float(), k.float()
+
     x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
 
+    x = x.to(dtype)
+
     x = x.unflatten(0, (-1, h)).transpose(1, 2).flatten(start_dim=2)
 
     out_proj, dropout = self.to_out
@@ -268,15 +296,16 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
         query_chunk_size = q_tokens
         kv_chunk_size = k_tokens
 
-    return efficient_dot_product_attention(
-        q,
-        k,
-        v,
-        query_chunk_size=q_chunk_size,
-        kv_chunk_size=kv_chunk_size,
-        kv_chunk_size_min=kv_chunk_size_min,
-        use_checkpoint=use_checkpoint,
-    )
+    with devices.without_autocast(disable=q.dtype == v.dtype):
+        return efficient_dot_product_attention(
+            q,
+            k,
+            v,
+            query_chunk_size=q_chunk_size,
+            kv_chunk_size=kv_chunk_size,
+            kv_chunk_size_min=kv_chunk_size_min,
+            use_checkpoint=use_checkpoint,
+        )
 
 
 def get_xformers_flash_attention_op(q, k, v):
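sub_quad_attention feeds efficient_dot_product_attention, which avoids materialising the full score matrix by working on query chunks (and on key/value chunks with checkpointing). A much-simplified sketch of the query-chunking idea only, with hypothetical shapes:

import torch

def chunked_attention(q, k, v, query_chunk_size=1024):
    # Process queries in slices so only a (chunk x k_tokens) score block exists at once.
    scale = q.shape[-1] ** -0.5
    out = q.new_zeros(q.shape[0], q.shape[1], v.shape[2])
    for i in range(0, q.shape[1], query_chunk_size):
        chunk = q[:, i:i + query_chunk_size]
        scores = (chunk @ k.transpose(1, 2)) * scale
        out[:, i:i + query_chunk_size] = scores.softmax(dim=-1) @ v
    return out

q = torch.randn(16, 4096, 40)
k = torch.randn(16, 77, 40)
v = torch.randn(16, 77, 40)
print(chunked_attention(q, k, v).shape)        # torch.Size([16, 4096, 40])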
@@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None):
     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
     del q_in, k_in, v_in
 
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k = q.float(), k.float()
+
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
 
+    out = out.to(dtype)
+
     out = rearrange(out, 'b n h d -> b n (h d)', h=h)
     return self.to_out(out)
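For context, the rearrange calls around memory_efficient_attention only split and re-merge the head dimension; xformers itself is not needed to check those shapes:

import torch
from einops import rearrange

h = 8                                    # attention heads, hypothetical
x = torch.randn(2, 77, 8 * 40)           # (batch, tokens, heads*head_dim)
q = rearrange(x, 'b n (h d) -> b n h d', h=h)
print(q.shape)                           # torch.Size([2, 77, 8, 40])
out = rearrange(q, 'b n h d -> b n (h d)', h=h)
print(out.shape)                         # torch.Size([2, 77, 320])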
@@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x):
         v = self.v(h_)
         b, c, h, w = q.shape
         q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+        dtype = q.dtype
+        if shared.opts.upcast_attn:
+            q, k = q.float(), k.float()
         q = q.contiguous()
         k = k.contiguous()
         v = v.contiguous()
         out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
+        out = out.to(dtype)
         out = rearrange(out, 'b (h w) c -> b c h w', h=h)
         out = self.proj_out(out)
         return x + out
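And the attention-block path flattens the spatial grid into tokens before attending, then restores (b, c, h, w). A sketch of that shape round trip with hypothetical sizes, where plain softmax attention stands in for xformers.ops.memory_efficient_attention:

import torch
from einops import rearrange

b, c, h, w = 1, 512, 32, 32
q = torch.randn(b, c, h, w)
k = torch.randn(b, c, h, w)
v = torch.randn(b, c, h, w)

q, k, v = (rearrange(t, 'b c h w -> b (h w) c').contiguous() for t in (q, k, v))
scores = (q @ k.transpose(1, 2)) * (q.shape[-1] ** -0.5)
out = scores.softmax(dim=-1) @ v
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
print(out.shape)                         # torch.Size([1, 512, 32, 32])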