GitHub - t-kalinowski/TabPFN: Foundation Model for Tabular Data via reticulate

t-kalinowski/TabPFN

 
 


TabPFN

Lifecycle: experimental CRAN status

The goal of TabPFN is to ...

Installation

You can install the development version of TabPFN like so:

# install.packages("pak")  # if pak is not already installed
pak::pak("topepo/TabPFN", ask = FALSE)

Example

The package requires a Python virtual environment that has been created and registered with reticulate. If you don't have one, you can create it. First, load the reticulate package:

library(reticulate)

and this code can be used to create an environment and install the relevant packages:

virtualenv_create(
	"r-tabpfn",
	packages = c("numpy", "tabpfn"),
	python_version = "<3.12"
)

then tell reticulate to use it:

use_virtualenv("~/.virtualenvs/r-tabpfn")

When TabPFN is loaded, it checks whether the required Python packages are installed:

library(TabPFN)
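TabPFN runs this check itself when it is loaded. To run an equivalent check by hand, for instance while debugging which Python environment reticulate picked up, reticulate's `py_config()` and `py_module_available()` can be used. A minimal sketch, assuming the `r-tabpfn` environment created above is the active one:

```r
library(reticulate)

# Report which Python interpreter and environment reticulate is bound to.
py_config()

# py_module_available() returns TRUE or FALSE instead of raising an error,
# which makes it convenient for a quick pre-flight check.
required  <- c("numpy", "tabpfn")
available <- vapply(required, py_module_available, logical(1))
available

if (!all(available)) {
  # Install whatever is missing into the same virtual environment.
  py_install(required[!available], envname = "r-tabpfn")
}
```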

To fit a model:

reg_mod <- TabPFN(mtcars[1:25, -1], mtcars$mpg[1:25])
reg_mod
#> TabPFN Regression Model
#> Training set
#> ℹ 25 data points
#> ℹ 10 predictors

In addition to the x/y interface shown above, there are also formula and recipes interfaces.
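The same regression model can be specified with a formula, the form used in the classification example below. The recipe call is an assumption: it follows the hardhat convention of passing an untrained `recipes::recipe()` together with a `data` argument, which may not match this package's exact signature. A sketch:

```r
library(TabPFN)
library(recipes)

train <- mtcars[1:25, ]

# Formula interface: mpg is the outcome, all other columns are predictors.
reg_mod_formula <- TabPFN(mpg ~ ., data = train)

# Recipe interface (assumed signature: recipe first, then `data`).
# The normalization step is only an illustrative preprocessing step.
rec <- recipe(mpg ~ ., data = train) |>
  step_normalize(all_numeric_predictors())
reg_mod_recipe <- TabPFN(rec, data = train)

predict(reg_mod_formula, mtcars[26:32, -1])
```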

Prediction follows the usual S3 predict() method:

predict(reg_mod, mtcars[26:32, -1])
#> # A tibble: 7 × 1
#>   .pred
#>   <dbl>
#> 1  31.3
#> 2  23.7
#> 3  25.4
#> 4  14.6
#> 5  19.3
#> 6  13.9
#> 7  22.6

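While TabPFN isn’t a tidymodels package, it follows the tidymodels prediction convention: predict() always returns a tibble with a standard set of column names (`.pred` for regression; `.pred_<level>` probabilities plus `.pred_class` for classification).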
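Because the regression output is a tibble with a single `.pred` column, it can be bound directly to the held-out rows for evaluation. A minimal sketch, continuing from `reg_mod` above, that computes the root mean squared error on the seven held-out cars (with so few points, treat the number as a smoke test rather than a benchmark):

```r
holdout <- mtcars[26:32, ]

# One prediction row per held-out car, in the same order as `holdout`.
reg_pred <- predict(reg_mod, holdout[, -1])
results  <- cbind(holdout["mpg"], reg_pred)

# Root mean squared error of the predictions.
rmse <- sqrt(mean((results$mpg - results$.pred)^2))
rmse
```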

For a classification model, the outcome should always be a factor vector. For example, using the parabolic data from the modeldata package:

require(modeldata)
require(ggplot2)
#> Loading required package: ggplot2

two_cls_train <- parabolic[1:400,  ]
two_cls_val   <- parabolic[401:500,]
grid <- expand.grid(X1 = seq(-5.1, 5.0, length.out = 25), 
                    X2 = seq(-5.5, 4.0, length.out = 25))

cls_mod <- TabPFN(class ~ ., data = two_cls_train)

grid_pred <- predict(cls_mod, grid)
grid_pred
#> # A tibble: 625 × 3
#>    .pred_Class1 .pred_Class2 .pred_class
#>           <dbl>        <dbl> <chr>      
#>  1        0.986      0.0137  Class1     
#>  2        0.990      0.0100  Class1     
#>  3        0.992      0.00816 Class1     
#>  4        0.994      0.00646 Class1     
#>  5        0.993      0.00654 Class1     
#>  6        0.990      0.0101  Class1     
#>  7        0.979      0.0213  Class1     
#>  8        0.943      0.0572  Class1     
#>  9        0.862      0.138   Class1     
#> 10        0.696      0.304   Class1     
#> # ℹ 615 more rows

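The fit looks fairly good when the decision boundary (the black contour where the predicted probability of Class1 is 0.5) is drawn over the out-of-sample validation data: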

cbind(grid, grid_pred) |>
 ggplot(aes(X1, X2)) + 
 geom_point(data = two_cls_val, aes(col = class, pch = class), 
            alpha = 3 / 4, cex = 3) +
 geom_contour(aes(z = .pred_Class1), breaks = 1/ 2, col = "black", linewidth = 1) +
 coord_equal(ratio = 1)
plot of chunk boundaries
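The plot gives a visual impression; a numeric check on the same 100 validation points can be made from the class predictions. A minimal sketch, continuing from `cls_mod` and `two_cls_val` above (`.pred_class` is returned as character, so the factor outcome is compared as character):

```r
# Predict using only the predictor columns of the validation set.
val_pred <- predict(cls_mod, two_cls_val[, c("X1", "X2")])

# Proportion of validation points whose predicted class matches the truth.
accuracy <- mean(val_pred$.pred_class == as.character(two_cls_val$class))
accuracy

# Confusion matrix: observed classes in rows, predicted classes in columns.
table(observed = two_cls_val$class, predicted = val_pred$.pred_class)
```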

AutoTabPFN

AutoTabPFN (called "TabPFN (PHE)" in the original TabPFN paper) is an ensemble version of TabPFN that automatically runs a hyperparameter search and builds an ensemble of TabPFN models. It is slower to train and to predict with but, on average, yields better predictions.

Using AutoTabPFN requires installing the tabpfn-community package with its post_hoc_ensembles extra. At the time of writing, this is only available from GitHub and can be installed from a terminal using:

cd ~/.virtualenvs/r-tabpfn
source bin/activate
git clone https://github.com/PriorLabs/tabpfn-community
pip install -e "tabpfn-community[post_hoc_ensembles]"

To train an AutoTabPFN ensemble, call AutoTabPFN() the same way you would call TabPFN(), then use predict() for predictions:

reg_mod <- AutoTabPFN(mpg ~ ., data = mtcars[1:25, ])
predict(reg_mod, mtcars[26:32, -1])
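To see the speed/accuracy trade-off on your own data, both models can be fit and compared on the same split. A rough sketch, assuming `predict()` on an AutoTabPFN model returns the same `.pred` column as TabPFN (the call is shown above but not its output); with only seven held-out cars, the comparison is illustrative, not conclusive:

```r
library(TabPFN)

train   <- mtcars[1:25, ]
holdout <- mtcars[26:32, ]

rmse <- function(obs, pred) sqrt(mean((obs - pred)^2))

# Time both fits; AutoTabPFN runs a hyperparameter search, so expect it to be slower.
time_single <- system.time(single_mod <- TabPFN(mpg ~ ., data = train))
time_auto   <- system.time(auto_mod <- AutoTabPFN(mpg ~ ., data = train))

comparison <- data.frame(
  model   = c("TabPFN", "AutoTabPFN"),
  seconds = c(time_single[["elapsed"]], time_auto[["elapsed"]]),
  rmse    = c(
    rmse(holdout$mpg, predict(single_mod, holdout[, -1])$.pred),
    rmse(holdout$mpg, predict(auto_mod, holdout[, -1])$.pred)
  )
)
comparison
```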

Code of Conduct

Please note that the TabPFN project is released with a Contributor Code of Conduct. By contributing to this project, you agree to abide by its terms.

About

Foundation Model for Tabular Data via reticulate
