---
output: github_document
---
The goal of TabPFN is to ...
You can install the development version of TabPFN like so:
# install.packages("pak")
library(pak)
pak("topepo/TabPFN", ask = FALSE)
The package requires a virtual environment to be created and registered with reticulate. If you don't have one, you can create one. First, load the reticulate package:
library(reticulate)
and this code can be used to create an environment and install the relevant packages:
virtualenv_create(
  "r-tabpfn",
  packages = c("numpy", "tabpfn"),
  python_version = "<3.12"
)
then tell reticulate to use it:
use_virtualenv("~/.virtualenvs/r-tabpfn")
When the TabPFN package is loaded, it checks whether the required Python packages are installed.
library(TabPFN)
To fit a model:
reg_mod <- TabPFN(mtcars[1:25, -1], mtcars$mpg[1:25])
reg_mod
#> TabPFN Regression Model
#> Training set
#> ℹ 25 data points
#> ℹ 10 predictors
In addition to the x/y interface shown above, there are also formula and recipes interfaces.
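As a rough sketch of those other interfaces (the exact signatures here are assumptions based on common tidymodels conventions, not verified against the package; `step_normalize()` is just an illustrative preprocessing step):

```r
library(recipes)

# Formula interface: outcome ~ predictors plus a data frame
reg_f <- TabPFN(mpg ~ ., data = mtcars[1:25, ])

# Recipes interface: pass a preprocessing recipe instead of a formula
rec <- recipe(mpg ~ ., data = mtcars[1:25, ]) |>
  step_normalize(all_numeric_predictors())
reg_r <- TabPFN(rec, data = mtcars[1:25, ])
```

Both fitted objects can then be used with `predict()` in the same way as the x/y interface shown above.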
Prediction follows the usual S3 `predict()` method:
predict(reg_mod, mtcars[26:32, -1])
#> # A tibble: 7 × 1
#> .pred
#> <dbl>
#> 1 31.3
#> 2 23.7
#> 3 25.4
#> 4 14.6
#> 5 19.3
#> 6 13.9
#> 7 22.6
While TabPFN isn’t a tidymodels package, it follows their prediction convention: a data frame is always returned with a standard set of column names.
For a classification model, the outcome should always be a factor vector. For example, using these data from the modeldata package:
require(modeldata)
require(ggplot2)
#> Loading required package: ggplot2
two_cls_train <- parabolic[1:400, ]
two_cls_val <- parabolic[401:500, ]

grid <- expand.grid(
  X1 = seq(-5.1, 5.0, length.out = 25),
  X2 = seq(-5.5, 4.0, length.out = 25)
)
cls_mod <- TabPFN(class ~ ., data = two_cls_train)
grid_pred <- predict(cls_mod, grid)
grid_pred
#> # A tibble: 625 × 3
#> .pred_Class1 .pred_Class2 .pred_class
#> <dbl> <dbl> <chr>
#> 1 0.986 0.0137 Class1
#> 2 0.990 0.0100 Class1
#> 3 0.992 0.00816 Class1
#> 4 0.994 0.00646 Class1
#> 5 0.993 0.00654 Class1
#> 6 0.990 0.0101 Class1
#> 7 0.979 0.0213 Class1
#> 8 0.943 0.0572 Class1
#> 9 0.862 0.138 Class1
#> 10 0.696 0.304 Class1
#> # ℹ 615 more rows
The fit looks fairly good when shown with out-of-sample data:
cbind(grid, grid_pred) |>
ggplot(aes(X1, X2)) +
  geom_point(data = two_cls_val, aes(col = class, pch = class),
             alpha = 3 / 4, cex = 3) +
  geom_contour(aes(z = .pred_Class1), breaks = 1 / 2, col = "black", linewidth = 1) +
  coord_equal(ratio = 1)
AutoTabPFN (called "TabPFN (PHE)" in the original TabPFN paper) is an ensemble version of TabPFN that automatically runs a hyperparameter search and builds an ensemble of TabPFN models. It is slower to train and use but, on average, yields better predictions.
Using AutoTabPFN requires that you install the `tabpfn-community` package with its `post_hoc_ensembles` extra dependencies. At the time of writing, this is only available from GitHub and can be installed from a terminal using:
cd ~/.virtualenvs/r-tabpfn
source bin/activate
git clone https://github.com/PriorLabs/tabpfn-community
pip install -e tabpfn-community[post_hoc_ensembles]
To train an AutoTabPFN ensemble, simply use `AutoTabPFN()` analogously to how `TabPFN()` is used. You can then use `predict()` for predictions.
reg_mod <- AutoTabPFN(mpg ~ ., data = mtcars[1:25, ])
predict(reg_mod, mtcars[26:32, -1])
Please note that the TabPFN project is released with a Contributor Code of Conduct. By contributing to this project, you agree to abide by its terms.