@@ -181,7 +181,7 @@ def einsum_op_cuda(q, k, v):
     mem_free_torch = mem_reserved - mem_active
     mem_free_total = mem_free_cuda + mem_free_torch
     # Divide factor of safety as there's copying and fragmentation
-    return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
 
 def einsum_op(q, k, v):
     if q.device.type == 'cuda':
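For context, here is a minimal sketch of how the patched einsum_op_cuda reads after this hunk. The change only drops the stray self., since einsum_op_tensor_mem is a module-level function rather than a method. The memory-query prologue (torch.cuda.memory_stats and torch.cuda.mem_get_info feeding mem_active, mem_reserved and mem_free_cuda) sits above the visible context lines and is an assumption here, not part of this diff; einsum_op_tensor_mem is defined elsewhere in the same file.

import torch

def einsum_op_cuda(q, k, v):
    # Assumed prologue: ask the CUDA allocator how much memory is usable on q's device.
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    # Memory reserved by PyTorch but not currently in use can be reclaimed.
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide factor of safety as there's copying and fragmentation;
    # the budget is passed in MiB (1 << 20 bytes) to the tiling helper.
    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))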