Add workaround for MPS layer_norm on PyTorch 2.0
On PyTorch 2.0, with MPS layer_norm only accepts float32 inputs. This was fixed shortly after 2.0 was finalized so the workaround can be applied with an exact version match.master
parent
c5142e2fbe
commit
27fe3eb6a9
Loading…
Reference in New Issue