@ -54,4 +54,6 @@ if has_mps:
CondFunc ( ' torch.cumsum ' , cumsum_fix_func , None )
CondFunc ( ' torch.cumsum ' , cumsum_fix_func , None )
CondFunc ( ' torch.Tensor.cumsum ' , cumsum_fix_func , None )
CondFunc ( ' torch.Tensor.cumsum ' , cumsum_fix_func , None )
CondFunc ( ' torch.narrow ' , lambda orig_func , * args , * * kwargs : orig_func ( * args , * * kwargs ) . clone ( ) , None )
CondFunc ( ' torch.narrow ' , lambda orig_func , * args , * * kwargs : orig_func ( * args , * * kwargs ) . clone ( ) , None )
if version . parse ( torch . __version__ ) == version . parse ( " 2.0 " ) :
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
CondFunc ( ' torch.nn.functional.layer_norm ' , lambda orig_func , x , normalized_shape , weight , bias , eps , * * kwargs : orig_func ( x . float ( ) , normalized_shape , weight . float ( ) if weight is not None else None , bias . float ( ) if bias is not None else bias , eps ) . to ( x . dtype ) , lambda * args , * * kwargs : len ( args ) == 6 )