tt-xla leverages a PJRT interface to integrate JAX (and, in the future, other frameworks) with tt-mlir and Tenstorrent hardware. It supports ingestion of JAX models via jit compile, providing a StableHLO (SHLO) graph to the tt-mlir compiler.
The tt-xla repository is primarily used to run JAX models on Tenstorrent's AI hardware. It is a backend integration between the JAX ecosystem and Tenstorrent's ML accelerators, built on the PJRT interface.
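As a rough illustration of the jit-to-StableHLO step described above, the sketch below uses stock JAX on its default backend to lower a small function and print the resulting StableHLO module. It does not use or assume any Tenstorrent plugin; the function `mlp_layer` and the input shapes are hypothetical and chosen only for the example.

```python
import jax
import jax.numpy as jnp


def mlp_layer(x, w, b):
    """A tiny dense layer followed by ReLU (illustrative only)."""
    return jnp.maximum(x @ w + b, 0.0)


# Example inputs; the shapes are arbitrary.
x = jnp.ones((4, 8), dtype=jnp.float32)
w = jnp.ones((8, 16), dtype=jnp.float32)
b = jnp.zeros((16,), dtype=jnp.float32)

# Trace and lower the jitted function without executing it.
lowered = jax.jit(mlp_layer).lower(x, w, b)

# Request the StableHLO dialect explicitly and render it as MLIR text.
stablehlo_text = lowered.as_text(dialect="stablehlo")
print(stablehlo_text)
```

The printed module is the kind of graph a PJRT backend receives when a jitted function is compiled; how tt-xla hands it to tt-mlir internally is not shown here.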
- A TVM-based graph compiler designed to optimize and transform computational graphs for deep learning models. Supports ingestion of PyTorch, ONNX, TensorFlow, PaddlePaddle, and similar ML frameworks via TVM (tt-tvm). See the docs pages for an overview and getting started guide.
- An MLIR-native, open-source, PyTorch 2.X and torch-mlir based front-end. It provides StableHLO (SHLO) graphs to tt-mlir. Supports ingestion of PyTorch models via PT2.X compile and ONNX models via torch-mlir (ONNX->SHLO). See the docs pages for an overview and getting started guide.
- Leverages a PJRT interface to integrate JAX (and, in the future, other frameworks) with tt-mlir and Tenstorrent hardware. Supports ingestion of JAX models via jit compile, providing a StableHLO (SHLO) graph to the tt-mlir compiler. See getting_started.md for an overview and getting started guide.
This repo is part of Tenstorrent’s bounty program. If you are interested in helping to improve tt-forge, please read the Tenstorrent Bounty Program Terms and Conditions before heading to the issues tab. Look for issues tagged with both “bounty” and a difficulty level.