|
|
|
|
@ -15,7 +15,13 @@ import torch
|
|
|
|
|
from torch import Tensor
|
|
|
|
|
from torch.utils.checkpoint import checkpoint
|
|
|
|
|
import math
|
|
|
|
|
from typing import Optional, NamedTuple, Protocol, List
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
from typing import Protocol
|
|
|
|
|
except:
|
|
|
|
|
from typing_extensions import Protocol
|
|
|
|
|
|
|
|
|
|
from typing import Optional, NamedTuple, List
|
|
|
|
|
|
|
|
|
|
def narrow_trunc(
|
|
|
|
|
input: Tensor,
|
|
|
|
|
|