
MPS torch.where() is giving objectively incorrect results, leading to critical calculation errors #122916


Closed
aradley opened this issue Mar 28, 2024 · 7 comments
Labels
module: correctness (silent) - issue that returns an incorrect result silently
module: mps - Related to Apple Metal Performance Shaders framework
module: 64-bit - Problems related to incorrectly using 32-bit integers when 64-bit is needed (e.g., 8G tensors)
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@aradley
aradley commented Mar 28, 2024

🐛 Describe the bug

I think I have an example of how MPS can get completely different results from CPU. Hopefully the simplicity of this example will be clear and helpful. This may be related to a previous issue noted on this forum (#84936).

import numpy as np
import torch
mps_device = torch.device("mps")

## Create a numpy matrix with many zeros
np.random.seed(0)
Numpy_Test = np.random.random(200000000)
indices = np.random.choice(np.arange(Numpy_Test.size), replace=False, size=int(Numpy_Test.size * 0.6))
Numpy_Test[indices] = 0
Numpy_Matrix = Numpy_Test.reshape((20000,10000))

## Get the indices of non-zero values in the matrix, and convert these indices into a numpy array
indices = np.where(Numpy_Matrix != 0)
indices = np.asarray(indices)

## Use numpy, torch, or a torch.mps object to find where indices[1] == 8000
# Using np.where
np.where(indices[1] == 8000)[0]
array([   19165,    27061,    39165, ..., 79979029, 79987021, 79995171])

# Using torch.where
torch.where(torch.from_numpy(indices)[1] == 8000)[0]
tensor([   19165,    27061,    39165,  ..., 79979029, 79987021, 79995171])

# Using torch.where with a tensor moved to the MPS device
torch.where(torch.from_numpy(indices)[1].to(mps_device) == 8000)[0]
tensor([   19165,    27061,    39165,  ..., 79979032, 79987024, 79995168], device='mps:0')

Notice how the first two examples (np.where and CPU torch.where) give the same results, but the tensor moved to MPS gives different ones.

If I've not made an obvious mistake, this is a clear example of how MPS can completely ruin a calculation: the indices change, and all downstream results become meaningless.
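
A programmatic check makes the divergence explicit (a minimal sketch; it reuses the indices and mps_device variables from the repro above and assumes an affected build):

# Compare the CPU and MPS results directly instead of eyeballing them
cpu_out = torch.where(torch.from_numpy(indices)[1] == 8000)[0]
mps_out = torch.where(torch.from_numpy(indices)[1].to(mps_device) == 8000)[0]
print(torch.equal(cpu_out, mps_out.cpu()))   # expected True; prints False on an affected build
print(cpu_out[-3:], mps_out.cpu()[-3:])      # the rounded values show up at the large indices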

Versions

torch version v0.2.1 and v0.2.0

cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr

@malfet malfet added the module: mps, module: correctness (silent), module: 64-bit and triaged labels Mar 28, 2024
@alpoge
alpoge commented Mar 30, 2024

hey, solved this :D! I'll put in a pull request with the fix very soon (sorry I'm a bit slow, I'm a pure mathematician and have never actually PR'ed before...)

Basically, the issue is that in the MPSGraph scatter operation, the tensor being written from (which for this torch.where call, really a torch.nonzero call, is a list of int32 coordinates) is silently cast to float32 behind the scenes. You'll notice the outputs always have the top 24 bits correct, and indeed casting an int to float32 starts rounding past 2^24!
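
To see that rounding concretely (a quick illustration of the float32 precision limit, not the actual MPSGraph code path; the three values are the tail indices from the report above):

import numpy as np
# float32 has a 24-bit significand, so integers above 2**24 are rounded to the
# nearest representable value; around 8e7 that spacing is 8.
for i in (79979029, 79987021, 79995171):
    print(i, int(np.float32(i)))
# prints 79979032, 79987024, 79995168 -- exactly the values the MPS tensor returned above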

So all that is required is to split the coordinates tensor into two in the MPSGraph calls: one taken modulo 2^23, say, and one integer-divided by 2^23, scatter those separately, and then add them back up (a rough sketch of the idea follows below).
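
A rough sketch of that split-and-recombine idea in plain NumPy (hypothetical helper name, not the actual MPSGraph change; the real fix lives in the Metal scatter path):

import numpy as np

def split_scatter_roundtrip(coords):
    # Split each coordinate into a low part (< 2**23) and a high part so each
    # piece stays exactly representable even through a float32 round-trip,
    # then recombine after the scatter.
    lo = coords % (1 << 23)
    hi = coords // (1 << 23)
    # ... the two pieces would be scattered separately here ...
    return hi * (1 << 23) + lo

coords = np.array([79979029, 79987021, 79995171], dtype=np.int64)
print(np.array_equal(split_scatter_roundtrip(coords), coords))  # True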

I should have the fix for this in a pull request very soon! All credit to @Jckwind for spreading the word about this (and for getting me up to speed)!

Hopefully a number of the other MPS arithmetic issues are related; we shall see...

@kulinseth
Collaborator

Thanks @alpoge for the fix. We are looking into whether there is a more efficient way to do this that can use the full int32 index range.

@kjhenner

Thank you! I was beginning to doubt my sanity.

pytorchmergebot pushed a commit that referenced this issue Aug 16, 2024
Fixes Issue #122916

Resolves correctness issue seen with large inputs to the mps nonzero op by using a different scatter mode. Native nonzero op is still used with smaller inputs for better performance.
Pull Request resolved: #126188
Approved by: https://github.com/kulinseth, https://github.com/malfet
@skotapati
Collaborator

This should be resolved by the PR linked above; please reopen or file a new issue if further issues are observed.
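
For anyone checking their own build, a quick sanity test in the spirit of the original repro (a sketch; the nonzero positions extend well past 2**24, which is where the pre-fix kernel returned rounded indices):

import torch
x = torch.zeros(30_000_000, dtype=torch.int64)
x[::3] = 1  # matches at positions 0, 3, 6, ... up to ~3e7
cpu_idx = torch.nonzero(x).flatten()
mps_idx = torch.nonzero(x.to("mps")).flatten().cpu()
print(torch.equal(cpu_idx, mps_idx))  # expect True on a fixed build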

malfet pushed a commit to aditew01/pytorch that referenced this issue Sep 13, 2024
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this issue Sep 26, 2024
@wiggin15

Hi @skotapati.
It looks like the PR above was never merged. Can we reopen this, please?

@hvaara
Contributor
hvaara commented Feb 23, 2025

@wiggin15 Are you currently seeing this issue? As mentioned above, #126188 should have fixed it.

@wiggin15
wiggin15 commented Mar 1, 2025

@hvaara sorry, I thought #126188 was not merged because GitHub shows it as "Closed" rather than "Merged".
I see the issue is fixed with the latest version of PyTorch. Thanks!
