@ -2,9 +2,10 @@ import sys, os, shlex
import contextlib
import torch
from modules import errors
from packaging import version
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
# check `getattr` and try it for compatibility
def has_mps ( ) - > bool :
if not getattr ( torch , ' has_mps ' , False ) :
@ -94,3 +95,28 @@ def autocast(disable=False):
return contextlib . nullcontext ( )
return torch . autocast ( " cuda " )
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
orig_tensor_to = torch.Tensor.to


def _targets_mps(target):
    """Return True if `target`, an argument given to `Tensor.to`, refers to an MPS device.

    Recognised forms are a `torch.device`, a device string such as "mps" or
    "mps:0", and a tensor (the `Tensor.to(other)` form). Anything else, for
    example a dtype or None, yields False.
    """
    if isinstance(target, torch.device):
        return target.type == 'mps'
    if isinstance(target, str):
        # Compare only the device-type part so that an index ("mps:0") still matches.
        return target.split(':', 1)[0] == 'mps'
    if isinstance(target, torch.Tensor):
        return target.device.type == 'mps'
    return False


def tensor_to_fix(self, *args, **kwargs):
    """Replacement for `torch.Tensor.to` that makes the source contiguous before a move to MPS.

    When `self` is not already on an MPS device and the call targets one (via
    the first positional argument, `device=` or `other=`), `self.contiguous()`
    is used as the source tensor. All arguments are then forwarded unchanged
    to the original `torch.Tensor.to`, whose result is returned.
    """
    destinations = (
        args[0] if args else None,
        kwargs.get('device'),
        kwargs.get('other'),
    )
    if self.device.type != 'mps' and any(_targets_mps(dest) for dest in destinations):
        self = self.contiguous()
    return orig_tensor_to(self, *args, **kwargs)
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
orig_layer_norm = torch.nn.functional.layer_norm


def layer_norm_fix(*args, **kwargs):
    """Replacement for `torch.nn.functional.layer_norm` that feeds MPS inputs in contiguous form.

    The input tensor is taken from the first positional argument or, when no
    positional arguments are given, from the `input=` keyword. If it is a
    tensor on an MPS device it is replaced by `input.contiguous()`. Every
    other argument is forwarded untouched to the original `layer_norm`, whose
    result is returned.
    """
    if args:
        first = args[0]
        if isinstance(first, torch.Tensor) and first.device.type == 'mps':
            args = (first.contiguous(), *args[1:])
    else:
        # The input may also be supplied by keyword; apply the same workaround.
        named = kwargs.get('input')
        if isinstance(named, torch.Tensor) and named.device.type == 'mps':
            kwargs['input'] = named.contiguous()
    return orig_layer_norm(*args, **kwargs)
# These patches are not needed on PyTorch 1.13, but that release is unfortunately
# slower and has regressions that prevent training from working, so older
# versions are still patched whenever MPS is present.
if has_mps():
    if version.parse(torch.__version__) < version.parse("1.13"):
        torch.nn.functional.layer_norm = layer_norm_fix
        torch.Tensor.to = tensor_to_fix