🐛 Describe the bug
I think I have an example of how MPS can produce completely different results from the CPU. Hopefully the simplicity of this example makes it clear and helpful. This may be related to a previous issue noted on this forum (#84936).
import numpy as np
import torch
mps_device = torch.device("mps")
## Create a numpy matrix with many zeros
np.random.seed(0)
Numpy_Test = np.random.random(200000000)
indices = np.random.choice(np.arange(Numpy_Test.size), replace=False, size=int(Numpy_Test.size * 0.6))
Numpy_Test[indices] = 0
Numpy_Matrix = Numpy_Test.reshape((20000,10000))
## Get the indices of non-zero values in the matrix, and convert these indices into a numpy array
indices = np.where(Numpy_Matrix != 0)
indices = np.asarray(indices)
## Use numpy, torch on the CPU, or torch on the MPS device to find where indices[1] == 8000
# Using np.where
np.where(indices[1] == 8000)[0]
array([ 19165, 27061, 39165, ..., 79979029, 79987021, 79995171])
# Using torch.where
torch.where(torch.from_numpy(indices)[1] == 8000)[0]
tensor([ 19165, 27061, 39165, ..., 79979029, 79987021, 79995171])
# Using torch.where with a tensor on the MPS device
torch.where(torch.from_numpy(indices)[1].to(mps_device) == 8000)[0]
tensor([ 19165, 27061, 39165, ..., 79979032, 79987024, 79995168], device='mps:0')
Notice how the first two examples, np.where and torch.where on the CPU, give the same results, but once the tensor is converted to MPS we get different results.
If I've not made an obvious mistake, this is a clear example of how MPS can completely break a calculation: in this case the indices change, so all downstream calculations become meaningless.
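To make the divergence easier to quantify, here is a minimal sketch of a follow-up check, assuming an Apple-silicon machine with the MPS backend available; the smaller 30-million-element array and the variable names are my own choices for illustration, not part of the original report. It runs the same torch.where comparison on the CPU and on MPS and counts how many of the returned index values disagree.
import numpy as np
import torch

assert torch.backends.mps.is_available()
mps_device = torch.device("mps")

## Smaller stand-in for the column-index array above: ~30M integer values in [0, 10000)
np.random.seed(0)
col = np.random.randint(0, 10000, size=30_000_000)
col_cpu = torch.from_numpy(col)
col_mps = col_cpu.to(mps_device)

## Run the same comparison on both devices and bring the MPS result back to the CPU
cpu_hits = torch.where(col_cpu == 8000)[0]
mps_hits = torch.where(col_mps == 8000)[0].to("cpu")

## The number of matches should agree; the question is whether the index values do
print("CPU matches:", cpu_hits.numel(), "MPS matches:", mps_hits.numel())
if cpu_hits.numel() == mps_hits.numel():
    mismatch = cpu_hits != mps_hits
    print("positions that disagree:", mismatch.sum().item())
    if mismatch.any():
        print("first disagreeing CPU index:", cpu_hits[mismatch][0].item())
The size guard is there because an element-wise comparison only makes sense if both devices return the same number of matches.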
Versions
torch v0.2.1 and v0.2.0