|
|
|
@ -53,12 +53,10 @@ def torch_gc():
|
|
|
|
|
|
|
|
|
|
|
|
def enable_tf32():
|
|
|
|
def enable_tf32():
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
for devid in range(0,torch.cuda.device_count()):
|
|
|
|
if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
|
|
|
|
if torch.cuda.get_device_capability(devid) == (7, 5):
|
|
|
|
|
|
|
|
shd = True
|
|
|
|
|
|
|
|
if shd:
|
|
|
|
|
|
|
|
torch.backends.cudnn.benchmark = True
|
|
|
|
torch.backends.cudnn.benchmark = True
|
|
|
|
torch.backends.cudnn.enabled = True
|
|
|
|
torch.backends.cudnn.enabled = True
|
|
|
|
|
|
|
|
|
|
|
|
torch.backends.cuda.matmul.allow_tf32 = True
|
|
|
|
torch.backends.cuda.matmul.allow_tf32 = True
|
|
|
|
torch.backends.cudnn.allow_tf32 = True
|
|
|
|
torch.backends.cudnn.allow_tf32 = True
|
|
|
|
|
|
|
|
|
|
|
|
|