|
|
|
@ -15,14 +15,9 @@ import torch
|
|
|
|
from torch import Tensor
|
|
|
|
from torch import Tensor
|
|
|
|
from torch.utils.checkpoint import checkpoint
|
|
|
|
from torch.utils.checkpoint import checkpoint
|
|
|
|
import math
|
|
|
|
import math
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
from typing import Protocol
|
|
|
|
|
|
|
|
except:
|
|
|
|
|
|
|
|
from typing_extensions import Protocol
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, NamedTuple, List
|
|
|
|
from typing import Optional, NamedTuple, List
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def narrow_trunc(
|
|
|
|
def narrow_trunc(
|
|
|
|
input: Tensor,
|
|
|
|
input: Tensor,
|
|
|
|
dim: int,
|
|
|
|
dim: int,
|
|
|
|
@ -31,12 +26,14 @@ def narrow_trunc(
|
|
|
|
) -> Tensor:
|
|
|
|
) -> Tensor:
|
|
|
|
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
|
|
|
|
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AttnChunk(NamedTuple):
|
|
|
|
class AttnChunk(NamedTuple):
|
|
|
|
exp_values: Tensor
|
|
|
|
exp_values: Tensor
|
|
|
|
exp_weights_sum: Tensor
|
|
|
|
exp_weights_sum: Tensor
|
|
|
|
max_score: Tensor
|
|
|
|
max_score: Tensor
|
|
|
|
|
|
|
|
|
|
|
|
class SummarizeChunk(Protocol):
|
|
|
|
|
|
|
|
|
|
|
|
class SummarizeChunk:
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def __call__(
|
|
|
|
def __call__(
|
|
|
|
query: Tensor,
|
|
|
|
query: Tensor,
|
|
|
|
@ -44,7 +41,8 @@ class SummarizeChunk(Protocol):
|
|
|
|
value: Tensor,
|
|
|
|
value: Tensor,
|
|
|
|
) -> AttnChunk: ...
|
|
|
|
) -> AttnChunk: ...
|
|
|
|
|
|
|
|
|
|
|
|
class ComputeQueryChunkAttn(Protocol):
|
|
|
|
|
|
|
|
|
|
|
|
class ComputeQueryChunkAttn:
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def __call__(
|
|
|
|
def __call__(
|
|
|
|
query: Tensor,
|
|
|
|
query: Tensor,
|
|
|
|
@ -52,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol):
|
|
|
|
value: Tensor,
|
|
|
|
value: Tensor,
|
|
|
|
) -> Tensor: ...
|
|
|
|
) -> Tensor: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _summarize_chunk(
|
|
|
|
def _summarize_chunk(
|
|
|
|
query: Tensor,
|
|
|
|
query: Tensor,
|
|
|
|
key: Tensor,
|
|
|
|
key: Tensor,
|
|
|
|
@ -72,6 +71,7 @@ def _summarize_chunk(
|
|
|
|
max_score = max_score.squeeze(-1)
|
|
|
|
max_score = max_score.squeeze(-1)
|
|
|
|
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
|
|
|
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _query_chunk_attention(
|
|
|
|
def _query_chunk_attention(
|
|
|
|
query: Tensor,
|
|
|
|
query: Tensor,
|
|
|
|
key: Tensor,
|
|
|
|
key: Tensor,
|
|
|
|
@ -112,6 +112,7 @@ def _query_chunk_attention(
|
|
|
|
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
|
|
|
|
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
|
|
|
|
return all_values / all_weights
|
|
|
|
return all_values / all_weights
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: refactor CrossAttention#get_attention_scores to share code with this
|
|
|
|
# TODO: refactor CrossAttention#get_attention_scores to share code with this
|
|
|
|
def _get_attention_scores_no_kv_chunking(
|
|
|
|
def _get_attention_scores_no_kv_chunking(
|
|
|
|
query: Tensor,
|
|
|
|
query: Tensor,
|
|
|
|
@ -131,10 +132,12 @@ def _get_attention_scores_no_kv_chunking(
|
|
|
|
hidden_states_slice = torch.bmm(attn_probs, value)
|
|
|
|
hidden_states_slice = torch.bmm(attn_probs, value)
|
|
|
|
return hidden_states_slice
|
|
|
|
return hidden_states_slice
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ScannedChunk(NamedTuple):
|
|
|
|
class ScannedChunk(NamedTuple):
|
|
|
|
chunk_idx: int
|
|
|
|
chunk_idx: int
|
|
|
|
attn_chunk: AttnChunk
|
|
|
|
attn_chunk: AttnChunk
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def efficient_dot_product_attention(
|
|
|
|
def efficient_dot_product_attention(
|
|
|
|
query: Tensor,
|
|
|
|
query: Tensor,
|
|
|
|
key: Tensor,
|
|
|
|
key: Tensor,
|
|
|
|
|