MPS torch.where() is giving objectively incorrect results, leading to critical calculation errors · Issue #122916 · pytorch/pytorch · GitHub
Closed
@aradley

Description

🐛 Describe the bug

I think I have an example of MPS producing completely different results from CPU. Hopefully the simplicity of this example makes the problem clear. This may be related to a previously reported issue (#84936).

```python
import numpy as np
import torch
mps_device = torch.device("mps")

## Create a numpy matrix with many zeros (60% of entries set to 0)
np.random.seed(0)
Numpy_Test = np.random.random(200000000)
indices = np.random.choice(np.arange(Numpy_Test.size), replace=False, size=int(Numpy_Test.size * 0.6))
Numpy_Test[indices] = 0
Numpy_Matrix = Numpy_Test.reshape((20000, 10000))

## Get the indices of non-zero values in the matrix, and convert these indices into a numpy array
## (row 0 holds the row indices, row 1 the column indices)
indices = np.where(Numpy_Matrix != 0)
indices = np.asarray(indices)

## Use numpy, torch, or a torch MPS tensor to find where indices[1] == 8000
# Using np.where
np.where(indices[1] == 8000)[0]
# array([   19165,    27061,    39165, ..., 79979029, 79987021, 79995171])

# Using torch.where on CPU
torch.where(torch.from_numpy(indices)[1] == 8000)[0]
# tensor([   19165,    27061,    39165,  ..., 79979029, 79987021, 79995171])

# Using torch.where with an MPS tensor
torch.where(torch.from_numpy(indices)[1].to(mps_device) == 8000)[0]
# tensor([   19165,    27061,    39165,  ..., 79979032, 79987024, 79995168], device='mps:0')
```

Notice that np.where and torch.where on CPU give the same results, but the tensor moved to MPS returns different indices (for example, 79979032 instead of 79979029).
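The mismatched values do not look random. One hypothesis, unconfirmed here but consistent with the `module: 64-bit` label on this issue, is that the index values pass through float32 somewhere in the MPS path. float32 has a 24-bit significand, so integers above 2**24 are not all exactly representable; around 8e7 the spacing between representable values is 8. The sketch below rounds the six CPU indices printed above to float32 and compares them with the MPS output. It only uses numbers from the report, so it does not need a Mac.

```python
import numpy as np

# First three and last three indices printed by the CPU (np.where / torch.where) calls.
cpu_indices = np.array([19165, 27061, 39165, 79979029, 79987021, 79995171], dtype=np.int64)
# The corresponding indices printed by the MPS call.
mps_indices = np.array([19165, 27061, 39165, 79979032, 79987024, 79995168], dtype=np.int64)

# 2**24 + 1 is the smallest positive integer float32 cannot represent exactly.
print(int(np.float32(2**24 + 1)))  # 16777216

# Round-trip the CPU indices through float32. Values in [2**26, 2**27) are
# rounded to the nearest multiple of 8; small values survive unchanged.
roundtrip = cpu_indices.astype(np.float32).astype(np.int64)

print(roundtrip)
print(np.array_equal(roundtrip, mps_indices))  # True
```

The small indices (below 2**24) match exactly while the large ones shift by a few positions, which fits this pattern. This is only an observation about the numbers; the actual cause would have to be confirmed in the MPS implementation of `torch.where`/`torch.nonzero`.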

Unless I've made an obvious mistake, this is a clear example of MPS silently breaking a calculation: the returned indices change, so every downstream calculation that uses them becomes meaningless.
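To check whether a given build still shows this behaviour, a small helper can run the same `torch.where` query on CPU and on another device and compare the results. `where_agrees_with_cpu` is a hypothetical name for illustration, not a PyTorch API. The sketch uses a synthetic 1-D tensor with matches placed just past 2**24 (where float32 can no longer represent odd integers) instead of the 1.6 GB matrix above; whether this smaller size triggers the reported bug is an assumption, not something confirmed in this issue. When MPS is not available it falls back to CPU, in which case the comparison trivially agrees.

```python
import torch

def where_agrees_with_cpu(values: torch.Tensor, target: int, device: str) -> bool:
    """Hypothetical helper: compare torch.where(values == target) on CPU vs `device`."""
    cpu_idx = torch.where(values == target)[0]
    dev_idx = torch.where(values.to(device) == target)[0].cpu()
    # torch.equal returns False if the shapes or any elements differ.
    return torch.equal(cpu_idx, dev_idx)

# Use MPS when present; otherwise fall back to CPU so the script still runs.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Place matches at odd positions just above 2**24, which float32 cannot represent exactly.
n = 2**24 + 64
values = torch.zeros(n, dtype=torch.int32)
values[2**24 + 5] = 1
values[2**24 + 13] = 1

print(device, where_agrees_with_cpu(values, 1, device))
```

Comparing against the CPU result, rather than against hard-coded positions, keeps the check valid for any input the reporter wants to try, including the original matrix.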

Versions

torch version v0.2.1 and v0.2.0

cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr

Metadata

Assignees

No one assigned

    Labels

    module: 64-bit (Problems related to incorrectly using 32-bit integers when 64-bit is needed, e.g., 8G tensors)
    module: correctness (silent) (Issue that returns an incorrect result silently)
    module: mps (Related to Apple Metal Performance Shaders framework)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
