@@ -106,6 +106,33 @@ def autocast(disable=False):
     return torch.autocast("cuda")
 
 
+class NansException(Exception):
+    pass
+
+
+def test_for_nans(x, where):
+    from modules import shared
+
+    if not torch.all(torch.isnan(x)).item():
+        return
+
+    if where == "unet":
+        message = "A tensor with all NaNs was produced in Unet."
+
+        if not shared.cmd_opts.no_half:
+            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."
+
+    elif where == "vae":
+        message = "A tensor with all NaNs was produced in VAE."
+
+        if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
+            message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
+    else:
+        message = "A tensor with all NaNs was produced."
+
+    raise NansException(message)
+
+
 # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
 orig_tensor_to = torch.Tensor.to
 def tensor_to_fix(self, *args, **kwargs):
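Note that the check above only fires when *every* element of `x` is NaN, which the error messages attribute to a precision failure rather than a single bad value; a tensor with even one finite entry passes. A minimal standalone sketch of that semantics (the stand-in tensors here are made up for illustration, not part of this diff):

```python
import torch

# test_for_nans returns early unless the whole tensor is NaN.
all_nan = torch.full((2, 2), float("nan"))
partly_nan = torch.tensor([float("nan"), 1.0, 2.0])

print(torch.all(torch.isnan(all_nan)).item())     # True  -> NansException would be raised
print(torch.all(torch.isnan(partly_nan)).item())  # False -> test_for_nans returns early
```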
@@ -156,3 +183,4 @@ if has_mps():
         torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
         orig_narrow = torch.narrow
         torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
+
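Both MPS workarounds rely on the same monkey-patch pattern: keep a reference to the original torch function, then rebind the name to a wrapper. A self-contained sketch of the `torch.narrow` wrap shown in the hunk above (it runs on any backend, though the `.clone()` only matters as an MPS workaround):

```python
import torch

# Keep the original, then rebind torch.narrow to a wrapper that clones
# the result, so callers get a materialized tensor instead of a view.
orig_narrow = torch.narrow
torch.narrow = lambda *args, **kwargs: (orig_narrow(*args, **kwargs).clone())

t = torch.arange(6.0)
s = torch.narrow(t, 0, 1, 3)  # wrapped call returns a copy, not a view
s[0] = 99.0                   # mutating the copy leaves t untouched
print(t)                      # tensor([0., 1., 2., 3., 4., 5.])
```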