Zero copy tensor conversion between xla:gpu and torch.cuda #4692
Based on my understanding, what we need to do is make that GPU buffer somehow recognized by the PJRT runtime (as a `PjRtBuffer`). I am curious what your current approach is.
The DLPack implementations in xla and torch would be good references: https://github.com/openxla/xla/blob/3c22aa8d716edfc4821b085b920534b4b01e9438/xla/python/dlpack.cc#L283
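For context, here is a minimal sketch of the DLPack exchange protocol those implementations build on, using only torch's public DLPack utilities (not the PJRT path itself); the idea is that the same kind of capsule could be consumed by a PJRT client to wrap the device memory as a `PjRtBuffer` without copying.

```python
import torch
from torch.utils.dlpack import to_dlpack, from_dlpack

src = torch.arange(8, device="cuda", dtype=torch.float32)

# Export: the capsule carries a raw device pointer plus shape/dtype/stride
# metadata; no data is moved.
capsule = to_dlpack(src)

# Import: any DLPack-aware consumer (here torch again, but it could be a
# PJRT runtime) aliases the same device memory.
dst = from_dlpack(capsule)

assert dst.data_ptr() == src.data_ptr()  # same underlying allocation
```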
From the Toy Example, it seems the same issue has been addressed in the eager/dynamo mixed scenario with the inductor backend; it seemed to handle it when I changed the backend from [...]. I'm not reporting an issue, just saying that [...].
@kevint324 In fact, as I understand it, the dynamo inductor backend does not offer much relevant experience for zero copy between torch_xla and eager. Even with inductor, the device property of a tensor is not affected, and a tensor at the boundary between dynamo and eager can naturally be used in both scenarios. If torch_xla were still using XRT, this would be a tricky problem, but XRT is now confirmed to be deprecated, so I think it's time to rethink the interaction between XLA and CUDA tensors.
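A small sketch of that point (assuming a CUDA-enabled PyTorch build): with the inductor backend, tensors crossing the compiled/eager boundary keep `device="cuda"`, so no device conversion is ever needed there.

```python
import torch

@torch.compile(backend="inductor")
def compiled_part(x):
    return x * 2 + 1

x = torch.randn(4, device="cuda")
y = compiled_part(x)   # produced inside the dynamo/inductor graph
z = y.sin()            # consumed by plain eager code

# The tensor at the boundary is an ordinary CUDA tensor in both worlds.
assert y.device.type == "cuda" and z.device.type == "cuda"
```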
PyTorch/XLA DLPack support was introduced in #7025.
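A hedged sketch of what that DLPack path enables; the exact module path (`torch_xla.utils.dlpack`) and the accepted argument types are assumptions here, so check the PR for the final API.

```python
import torch
from torch.utils.dlpack import to_dlpack as torch_to_dlpack
from torch.utils.dlpack import from_dlpack as torch_from_dlpack
# Assumed module path from #7025; verify against the merged PR.
from torch_xla.utils.dlpack import to_dlpack as xla_to_dlpack
from torch_xla.utils.dlpack import from_dlpack as xla_from_dlpack

cuda_t = torch.arange(4, device="cuda", dtype=torch.float32)

# CUDA -> XLA:GPU: the resulting XLA tensor should alias the same device
# memory, so no host round trip is involved.
xla_t = xla_from_dlpack(torch_to_dlpack(cuda_t))

# XLA:GPU -> CUDA, again without copying through the CPU.
cuda_view = torch_from_dlpack(xla_to_dlpack(xla_t))
```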
Currently, switching between lazy and eager can incur huge overhead even when using the same physical device. This is mainly due to IR graph execution and the conversion of tensor device types. The latter is not actually necessary; I think it exists for historical reasons (XRT), which can be seen from the interface names `TransferToServer`/`TransferFromServer`. Even a transfer from a GPU to the same GPU must be routed through the CPU.

I'm implementing a PoC so that `xla_tensor.to('cuda')` and `cuda_tensor.to('xla')` are actually zero copy. So far it can run an eager/lazy mixed MNIST. But there may be problems: I used the `_to_copy` op even though no copy actually happens, and I wonder whether this will cause issues in the backward pass during training. I am currently considering how to implement zero copy while ensuring correctness, and would like to know if the community has any relevant experience.
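To make the goal concrete, this is roughly what the PoC aims for at the user level; the zero-copy behavior is hypothetical (today these `.to()` calls still stage through the host), and both tensors are assumed to live on the same physical GPU so the conversions could become aliasing views.

```python
import torch
import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()

cuda_t = torch.randn(128, 128, device="cuda")

# cuda -> xla: ideally just re-wraps the existing device buffer as a PJRT
# buffer instead of a TransferToServer-style round trip via the CPU.
xla_t = cuda_t.to(xla_device)
xla_out = (xla_t @ xla_t).tanh()
xm.mark_step()  # materialize the lazy graph

# xla -> cuda: ideally hands the PJRT buffer back as CUDA storage instead
# of a TransferFromServer-style round trip via the CPU.
cuda_out = xla_out.to("cuda")
loss = cuda_out.sum()  # continue with eager CUDA ops
```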