|
|
|
@ -9,7 +9,7 @@ from torch import einsum
|
|
|
|
from ldm.util import default
|
|
|
|
from ldm.util import default
|
|
|
|
from einops import rearrange
|
|
|
|
from einops import rearrange
|
|
|
|
|
|
|
|
|
|
|
|
from modules import shared
|
|
|
|
from modules import shared, errors
|
|
|
|
from modules.hypernetworks import hypernetwork
|
|
|
|
from modules.hypernetworks import hypernetwork
|
|
|
|
|
|
|
|
|
|
|
|
from .sub_quadratic_attention import efficient_dot_product_attention
|
|
|
|
from .sub_quadratic_attention import efficient_dot_product_attention
|
|
|
|
@ -279,6 +279,21 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_xformers_flash_attention_op(q, k, v):
|
|
|
|
|
|
|
|
if not shared.cmd_opts.xformers_flash_attention:
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
|
|
|
|
|
|
|
|
fw, bw = flash_attention_op
|
|
|
|
|
|
|
|
if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
|
|
|
|
|
|
|
|
return flash_attention_op
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
errors.display_once(e, "enabling flash attention")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def xformers_attention_forward(self, x, context=None, mask=None):
|
|
|
|
def xformers_attention_forward(self, x, context=None, mask=None):
|
|
|
|
h = self.heads
|
|
|
|
h = self.heads
|
|
|
|
q_in = self.to_q(x)
|
|
|
|
q_in = self.to_q(x)
|
|
|
|
@ -291,18 +306,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
|
|
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
|
|
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
|
|
|
del q_in, k_in, v_in
|
|
|
|
del q_in, k_in, v_in
|
|
|
|
|
|
|
|
|
|
|
|
if shared.cmd_opts.xformers_flash_attention:
|
|
|
|
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
|
|
|
|
op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
|
|
|
|
|
|
|
|
fw, bw = op
|
|
|
|
|
|
|
|
if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
|
|
|
|
|
|
|
|
# print('xformers_attention_forward', q.shape, k.shape, v.shape)
|
|
|
|
|
|
|
|
# Flash Attention is not availabe for the input arguments.
|
|
|
|
|
|
|
|
# Fallback to default xFormers' backend.
|
|
|
|
|
|
|
|
op = None
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
op = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
|
|
|
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
|
|
|
return self.to_out(out)
|
|
|
|
return self.to_out(out)
|
|
|
|
@ -377,17 +381,7 @@ def xformers_attnblock_forward(self, x):
|
|
|
|
q = q.contiguous()
|
|
|
|
q = q.contiguous()
|
|
|
|
k = k.contiguous()
|
|
|
|
k = k.contiguous()
|
|
|
|
v = v.contiguous()
|
|
|
|
v = v.contiguous()
|
|
|
|
if shared.cmd_opts.xformers_flash_attention:
|
|
|
|
out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
|
|
|
|
op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
|
|
|
|
|
|
|
|
fw, bw = op
|
|
|
|
|
|
|
|
if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)):
|
|
|
|
|
|
|
|
# print('xformers_attnblock_forward', q.shape, k.shape, v.shape)
|
|
|
|
|
|
|
|
# Flash Attention is not availabe for the input arguments.
|
|
|
|
|
|
|
|
# Fallback to default xFormers' backend.
|
|
|
|
|
|
|
|
op = None
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
op = None
|
|
|
|
|
|
|
|
out = xformers.ops.memory_efficient_attention(q, k, v, op=op)
|
|
|
|
|
|
|
|
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
|
|
|
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
|
|
|
out = self.proj_out(out)
|
|
|
|
out = self.proj_out(out)
|
|
|
|
return x + out
|
|
|
|
return x + out
|
|
|
|
|