MPS torch.where() is giving objectively incorrect results, leading to critical calculation errors #122916
Comments
Hey, solved this :D! I'll put in a pull request with the fix very soon (sorry I'm a bit slow; I'm a pure mathematician and have never actually PR'ed before...). Basically, the issue is that in the MPSGraph scatter operation, the tensor that is written from, which in the case of this torch.where call (which turns out to be a torch.nonzero call) is a list of int32 coordinates, is secretly being cast to float32 behind the scenes. You'll notice that the outputs always have the top 24 bits correct! (And indeed, casting an int to a float starts rounding it past 2^24.) So all that is required is to split the coordinates tensor into two in the MPSGraph calls, one taken modulo 2^23, say, and one (integer-)divided by 2^23, scatter those, and then add them back up. I should have the fix for this requested very soon! All credit to @Jckwind for spreading the word about this (and for getting me up to speed)! Hopefully a number of these other MPS arithmetic issues are related; we shall see...
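For anyone following along, here is a minimal CPU-side sketch of the failure mode and of the split-and-recombine idea described above. It is illustrative only: the values are made up and the real code lives in the MPS backend, not in Python.

```python
import torch

# Index values above 2**24 cannot be represented exactly in float32 (24-bit
# mantissa), so a silent int32 -> float32 -> int32 round trip rounds them.
idx = torch.tensor([2**24 + 1, 2**25 + 3], dtype=torch.int32)
print(idx.to(torch.float32).to(torch.int32))   # rounded, no longer equal to idx

# Workaround sketch: split each index into a low and a high part, both small
# enough to survive a float32 round trip, then recombine afterwards.
lo = idx % 2**23
hi = idx // 2**23
recombined = hi.to(torch.float32).to(torch.int32) * 2**23 + lo.to(torch.float32).to(torch.int32)
print(torch.equal(recombined, idx))            # True: no precision lost
```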
Thanks @alpoge for the fix. We are looking into whether there is a more efficient way to do this that can use the full int32 index range.
Thank you! I was beginning to doubt my sanity.
Fixes Issue #122916. Resolves a correctness issue seen with large inputs to the MPS nonzero op by using a different scatter mode. The native nonzero op is still used with smaller inputs for better performance. Pull Request resolved: #126188 Approved by: https://github.com/kulinseth, https://github.com/malfet
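The size-based dispatch the PR describes looks roughly like the following Python-level sketch. The threshold value and the CPU round trip standing in for the "precision-safe path" are my own placeholders, not the actual Objective-C++ backend change.

```python
import torch

def nonzero_dispatch(t: torch.Tensor, native_threshold: int = 1 << 23) -> torch.Tensor:
    if t.numel() < native_threshold:
        # Small inputs: the faster native path is fine, since indices stay
        # well below the float32 precision limit of 2**24.
        return torch.nonzero(t)
    # Large inputs: stand-in for the precision-safe scatter-based path.
    return torch.nonzero(t.cpu()).to(t.device)
```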
This should be resolved by the PR linked above. Please reopen or file a new issue if further issues are observed.
Hi @skotapati.
🐛 Describe the bug
I think I have an example of how MPS can produce completely different results from CPU. Hopefully this example is simple enough to be clear and helpful. It may be related to a previous issue noted on this forum (#84936).
Notice how the first two examples, np.where and torch.where on the CPU tensor, give the same results, but when using the tensor converted to MPS we get different results.
If I've not made an obvious mistake, this is a clear example of how MPS completely ruins calculations: in this case the indices change, and all downstream calculations become meaningless.
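The original code snippet is not reproduced in this thread; a minimal sketch along the lines described above (the tensor size and the set indices are my own assumptions) would look something like this:

```python
import numpy as np
import torch

# Illustrative reproduction: a large array (past 2**24 elements) with a few
# nonzero entries, consistent with the float32 explanation in the comments.
x_np = np.zeros(2**25, dtype=np.float32)
x_np[[10, 2**24 + 7, 2**25 - 1]] = 1.0

np_idx = np.where(x_np > 0)[0]          # NumPy reference
x_cpu = torch.from_numpy(x_np)
cpu_idx = torch.where(x_cpu > 0)[0]     # matches NumPy

if torch.backends.mps.is_available():
    mps_idx = torch.where(x_cpu.to("mps") > 0)[0].cpu()
    print(np_idx, cpu_idx, mps_idx)     # mps_idx disagreed before the fix in #126188
```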
Versions
torch versions v2.2.1 and v2.2.0
cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr