PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU VM with Kaggle!
Take a look at one of our Kaggle notebooks to get started:
PyTorch/XLA is now on PyPI!
To install PyTorch/XLA on a new TPU VM:
pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
To update your existing training loop, make the following changes:
-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp
 def _mp_fn(index):
   ...
+  # Move the model parameters to your XLA device
+  model.to(xm.xla_device())
+
+  # MpDeviceLoader preloads data to the XLA device
+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())
-  for inputs, labels in train_loader:
+  for inputs, labels in xla_train_loader:
     optimizer.zero_grad()
     outputs = model(inputs)
     loss = loss_fn(outputs, labels)
     loss.backward()
-    optimizer.step()
+
+    # `xm.optimizer_step` combines gradients across replicas, then steps the optimizer
+    xm.optimizer_step(optimizer)
 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # xmp.spawn automatically selects the correct world size
+  xmp.spawn(_mp_fn, args=())
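As we understand it, `xm.optimizer_step(optimizer)` first reduces the gradients across all replicas (an all-reduce that averages them over the world size) and then calls `optimizer.step()`. The pure-Python sketch below only illustrates that idea, with plain lists standing in for tensors; `all_reduce_mean`, `sgd_step`, and `simulated_optimizer_step` are hypothetical names, not part of torch_xla.

```python
from typing import List

Vector = List[float]


def all_reduce_mean(replica_grads: List[Vector]) -> Vector:
    """Element-wise mean over replicas, i.e. an all-reduce sum scaled by 1/world_size."""
    world_size = len(replica_grads)
    return [sum(values) / world_size for values in zip(*replica_grads)]


def sgd_step(params: Vector, grads: Vector, lr: float) -> Vector:
    """Plain SGD update: p <- p - lr * g."""
    return [p - lr * g for p, g in zip(params, grads)]


def simulated_optimizer_step(replica_params: List[Vector],
                             replica_grads: List[Vector],
                             lr: float = 0.5) -> List[Vector]:
    """Average gradients across replicas, then let every replica apply the same update."""
    mean_grads = all_reduce_mean(replica_grads)
    return [sgd_step(params, mean_grads, lr) for params in replica_params]


if __name__ == "__main__":
    params = [[1.0, 1.0], [1.0, 1.0]]   # two replicas start with identical weights
    grads = [[0.25, 0.5], [0.75, 1.0]]  # each replica saw a different mini-batch
    print(simulated_optimizer_step(params, grads))
    # -> [[0.75, 0.625], [0.75, 0.625]]
```

Because every replica applies the same averaged gradient, the replicas stay in sync. In the DDP variant, `optimizer.step()` is kept because DDP itself all-reduces the gradients during `loss.backward()`.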
If you're using DistributedDataParallel, make the following changes:
import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.distributed.xla_backend
 def _mp_fn(rank):
   ...
-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='xla://')
+
+  model.to(xm.xla_device())
+  # `gradient_as_bucket_view=True` required for XLA
+  ddp_model = DDP(model, gradient_as_bucket_view=True)
-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])
+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())
-  for inputs, labels in train_loader:
+  for inputs, labels in xla_train_loader:
     optimizer.zero_grad()
     outputs = ddp_model(inputs)
     loss = loss_fn(outputs, labels)
     loss.backward()
     optimizer.step()
 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  xmp.spawn(_mp_fn, args=())
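PyTorch's DDP documentation describes `gradient_as_bucket_view=True` as making each parameter's gradient a view into a flat all-reduce communication bucket, so the reduced values do not have to be copied back into separate gradient buffers. The stdlib-only sketch below mimics that memory layout with an `array` buffer and `memoryview` slices. `make_bucket` and `allreduce_mean_inplace` are hypothetical helpers; real DDP buckets are tensors reduced through the initialized process group, not Python loops.

```python
from array import array
from typing import List, Tuple


def make_bucket(param_sizes: List[int]) -> Tuple[array, List[memoryview]]:
    """Allocate one flat buffer and hand out per-parameter gradient views into it."""
    bucket = array("d", [0.0] * sum(param_sizes))
    flat = memoryview(bucket)
    views, offset = [], 0
    for size in param_sizes:
        views.append(flat[offset:offset + size])  # shares memory with bucket, not a copy
        offset += size
    return bucket, views


def allreduce_mean_inplace(buckets: List[array]) -> None:
    """Average all replicas' buckets element-wise, writing the result into every bucket."""
    world_size = len(buckets)
    for i in range(len(buckets[0])):
        mean = sum(b[i] for b in buckets) / world_size
        for b in buckets:
            b[i] = mean


if __name__ == "__main__":
    # Two replicas, each with a 2-element weight gradient and a 1-element bias gradient.
    b0, (w0, bias0) = make_bucket([2, 1])
    b1, (w1, bias1) = make_bucket([2, 1])
    w0[0], w0[1], bias0[0] = 1.0, 2.0, 3.0  # writes through the views land in b0
    w1[0], w1[1], bias1[0] = 3.0, 4.0, 5.0
    allreduce_mean_inplace([b0, b1])
    # The gradient views see the reduced values without any copy-back step.
    print(w0.tolist(), bias1.tolist())
```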
Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org. See the API Guide for best practices when writing networks that run on XLA devices (TPU, CUDA, CPU and...).
Our comprehensive user guides are available at:
- Documentation for the latest release
- Documentation for the master branch
PyTorch/XLA releases starting with version r2.1 will be available on PyPI. You can now install the main build with `pip install torch_xla`. To also install the Cloud TPU plugin, install the optional `tpu` dependencies:
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
GPU, XRT (legacy runtime), and nightly builds are available in our public GCS bucket.
Version | Cloud TPU/GPU VMs Wheel |
---|---|
2.2 (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl |
2.2 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.2 (CUDA 12.1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl |
2.2 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl |
nightly (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
nightly (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl |
nightly (CUDA 12.1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
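The 2.2 release wheels in the table share one naming pattern: a `tpuvm` or `cuda/12.1` directory, then `torch_xla-<version>-<tag>-<tag>-manylinux_2_28_x86_64.whl`, where `<tag>` is the CPython tag (`cp38` or `cp310`). The hypothetical helper below rebuilds those URLs from that pattern, defaulting the tag to the running interpreter. It assumes the pattern holds only for the 2.2 rows listed; the nightly and older wheels use different suffixes.

```python
import sys
from typing import Optional

RELEASE_BASE = "https://storage.googleapis.com/pytorch-xla-releases/wheels"
LISTED_TAGS = ("cp38", "cp310")  # Python versions with 2.2 wheels in the table


def release_wheel_url(version: str = "2.2.0", cuda: Optional[str] = None,
                      py_tag: Optional[str] = None) -> str:
    """Build a release wheel URL (hypothetical helper mirroring the table's naming)."""
    if py_tag is None:
        py_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    if py_tag not in LISTED_TAGS:
        raise ValueError(f"no wheel listed for {py_tag}; expected one of {LISTED_TAGS}")
    subdir = "tpuvm" if cuda is None else f"cuda/{cuda}"
    return (f"{RELEASE_BASE}/{subdir}/"
            f"torch_xla-{version}-{py_tag}-{py_tag}-manylinux_2_28_x86_64.whl")


if __name__ == "__main__":
    try:
        print(release_wheel_url())  # TPU VM wheel for the current interpreter
    except ValueError as err:
        print(err)
```

The resulting URL can be passed directly to `pip install`.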
Older versions:
Version | Cloud TPU VMs Wheel |
---|---|
2.1 (XRT + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl |
2.1 (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.1.0-cp38-cp38-linux_x86_64.whl |
2.0 (Python 3.8) | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl |
1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.13-cp38-cp38-linux_x86_64.whl |
1.12 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.12-cp38-cp38-linux_x86_64.whl |
1.11 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.11-cp38-cp38-linux_x86_64.whl |
1.10 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.10-cp38-cp38-linux_x86_64.whl |
Note: For TPU Pod customers using XRT (our legacy runtime), we have custom wheels for `torch` and `torch_xla` at https://storage.googleapis.com/tpu-pytorch/wheels/xrt.
Package | Cloud TPU VMs Wheel (XRT on Pod, Legacy Only) |
---|---|
torch_xla | https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl |
torch | https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch-2.1.0%2Bxrt-cp310-cp310-linux_x86_64.whl |
Version | GPU Wheel + Python 3.8 |
---|---|
2.1 + CUDA 11.8 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-2.1.0-cp38-cp38-manylinux_2_28_x86_64.whl |
2.0 + CUDA 11.8 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/118/torch_xla-2.0-cp38-cp38-linux_x86_64.whl |
2.0 + CUDA 11.7 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/117/torch_xla-2.0-cp38-cp38-linux_x86_64.whl |
1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.13-cp38-cp38-linux_x86_64.whl |
nightly + CUDA 12.0 >= 2023/06/27 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.0/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
nightly + CUDA 11.8 <= 2023/04/25 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/118/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
nightly + CUDA 11.8 >= 2023/04/25 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
Version | GPU Wheel + Python 3.7 |
---|---|
1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.13-cp37-cp37m-linux_x86_64.whl |
1.12 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl |
1.11 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.11-cp37-cp37m-linux_x86_64.whl |
Version | Colab TPU Wheel |
---|---|
2.0 | https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl |
You can also add `+yyyymmdd` after `torch_xla-nightly` to get the nightly wheel for a specific date. To get the companion PyTorch and torchvision nightly wheels, replace `torch_xla` with `torch` or `torchvision` in the wheel links above.
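Following the rule in the paragraph above, a dated nightly wheel URL is formed by inserting `+yyyymmdd` after `<package>-nightly`. The sketch below is a hypothetical helper for the TPU VM directory only. It percent-encodes `+` as `%2B`, matching how the XRT links in the tables encode it; we assume the bucket accepts the encoded form, as those links suggest.

```python
from datetime import date
from urllib.parse import quote

TPUVM_BASE = "https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm"
PACKAGES = ("torch_xla", "torch", "torchvision")


def dated_nightly_wheel_url(package: str, day: date, py_tag: str = "cp310") -> str:
    """Build the URL of a nightly wheel pinned to a given date (hypothetical helper)."""
    if package not in PACKAGES:
        raise ValueError(f"unknown package {package!r}; expected one of {PACKAGES}")
    filename = (f"{package}-nightly+{day.strftime('%Y%m%d')}"
                f"-{py_tag}-{py_tag}-linux_x86_64.whl")
    # quote() leaves letters, digits, '_', '.', '-' untouched and turns '+' into '%2B'.
    return f"{TPUVM_BASE}/{quote(filename)}"


if __name__ == "__main__":
    for pkg in PACKAGES:
        print(dated_nightly_wheel_url(pkg, date(2024, 1, 15)))
```

Not every date necessarily has a published nightly, so a generated URL may still return 404.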
Version | Cloud TPU VMs Docker |
---|---|
2.2 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_tpuvm |
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_tpuvm |
2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm |
1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.8_tpuvm |
nightly python | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm |
Version | GPU CUDA 12.1 Docker |
---|---|
2.2 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_cuda_12.1 |
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.1 |
nightly | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1 |
nightly at date | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1_YYYYMMDD |
Version | GPU CUDA 11.8 + Docker |
---|---|
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_11.8 |
2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.8 |
nightly | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8 |
nightly at date | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8_YYYYMMDD |
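The "nightly at date" rows use a `YYYYMMDD` placeholder at the end of the tag. A minimal sketch that fills it in for the GPU images listed above; `dated_gpu_nightly_image` is a hypothetical helper, and it only covers the CUDA 12.1 and 11.8 nightly tags shown in these tables.

```python
from datetime import date

REPO = "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla"
CUDA_VERSIONS = ("12.1", "11.8")  # GPU nightly tags listed in the tables


def dated_gpu_nightly_image(day: date, cuda: str = "12.1") -> str:
    """Return the image reference for the GPU nightly tagged with `day`."""
    if cuda not in CUDA_VERSIONS:
        raise ValueError(f"no nightly tag listed for CUDA {cuda}")
    return f"{REPO}:nightly_3.8_cuda_{cuda}_{day.strftime('%Y%m%d')}"


if __name__ == "__main__":
    print(dated_gpu_nightly_image(date(2024, 1, 15)))
```

The printed reference can be passed to `docker pull`; whether an image was published for a given day is not guaranteed.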
Older versions:
Version | GPU CUDA 11.7 + Docker |
---|---|
2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.7 |
Version | GPU CUDA 11.2 + Python 3.8 Docker |
---|---|
1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.8_cuda_11.2 |
Version | GPU CUDA 11.2 + Python 3.7 Docker |
---|---|
1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.7_cuda_11.2 |
1.12 | gcr.io/tpu-pytorch/xla:r1.12_3.7_cuda_11.2 |
To run on compute instances with GPUs, use the GPU CUDA Docker images above.
If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).
The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue in this GitHub repository. Questions, bug reports, feature requests, build issues, etc. are all welcome!
See the contribution guide.
This repository is jointly operated and maintained by Google, Facebook and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Facebook, please send an email to opensource@fb.com. For questions directed at Google, please send an email to pytorch-xla@googlegroups.com. For all other questions, please open an issue in this repository.
You can find additional useful reading materials in:
- Performance debugging on Cloud TPU VM
- Lazy tensor intro
- Scaling deep learning workloads with PyTorch / XLA and Cloud TPU VM
- Scaling PyTorch models on Cloud TPUs with FSDP
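1. GPU support is incomplete. For example, torch_xla.runtime.global_runtime_device_count always returns 1. This is caused by a bug in the old version that has been fixed on the new master branch, so you need to clone the latest xla code again.
2. torch xla is built on top of openxla; the openxla components are downloaded during the build.
3. compute_35 is not supported. You can remove the compute_35-related entries from .cache and then rebuild torch xla.
4. xla should be built inside the torch source directory; otherwise the aten sources and other files it depends on cannot be found.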
git clone --recursive https://github.com/pytorch/pytorch.git
cd pytorch
USE_CUDA=1 python setup.py install

cd pytorch  # note: xla must be cloned inside the pytorch source directory
git clone https://github.com/pytorch/xla.git
cd xla
XLA_CUDA=1 python setup.py install
5. Set the following environment variables before building torch xla:
echo 'export PATH=$PATH:/usr/local/cuda-12.1/bin' >> ~/.bashrc
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64' >> ~/.bashrc
source ~/.bashrc
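6. We recommend downloading the torch/xla/openxla sources over SSH (this requires adding an SSH key to your account); otherwise you are likely to run into EOF errors (caused by a slow network, insufficient buffer size, or other factors; in short, HTTPS downloads are subject to many limitations).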