PyTorch implementation of generalized polyharmonic spline interpolation (also known as the Thin Plate Spline, TPS, in 2D). It learns a smooth elastic mapping between two Euclidean spaces, with support for:
- Arbitrary input and output dimensions
- Arbitrary spline order k
- Optional regularization
- CPU and GPU parallelization
Useful for interpolation, deformation fields, and smooth non-linear regression.
For a NumPy implementation, see tps.
This implementation is much faster than the NumPy one, thanks to CPU parallelization. Using the GPU does not seem to speed up fitting much (a linear system solve), but it makes transforming much faster (a plain matrix multiplication).
```bash
pip install torch-tps
```
```python
import torch
from torch_tps import ThinPlateSpline

# Control points
X_train = torch.randn(800, 3)  # 800 points in R^3
Y_train = torch.randn(800, 2)  # Values for each point (800 values in R^2)

# New source points to interpolate
X_test = torch.randn(3000, 3)

# Initialize spline model (regularization is controlled by the alpha parameter)
tps = ThinPlateSpline(alpha=0.5)  # Use device="cuda" to switch to GPU

# Fit spline from control points
tps.fit(X_train, Y_train)

# Interpolate new points
Y_test = tps.transform(X_test)
```
See the example/ folder for scripts showing:
- Interpolation in 1D, 2D, 3D
- Arbitrary input and output dimensions
- Image warping with elastic deformation
Example: enlarging, shrinking, and randomly deforming a dog's face using sparse control points.
Code: example/image_warping.py
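The model solves the regularized interpolation problem:

$$ \min_f \sum_{i=1}^{n} \|y_i - f(x_i)\|_2^2 + \alpha \int_{\mathbb{R}^d} \|\nabla^{\text{order}} f(x)\|_2^2 \, dx $$

With solution:

$$ f(x) = P(x) + \sum_{i=1}^{n} w_i \, G(\|x - x_i\|_2), \qquad w_i \in \mathbb{R}^v $$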
Where:
- G(r): radial basis function (RBF), which depends on order and on the input dimension d
- P(x): a polynomial of degree order - 1
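In the standard formulation of this kind of spline, fitting n control points reduces to a single linear system. A sketch of that system, assuming alpha is simply added to the diagonal of the kernel matrix (the exact scaling used by the library is not stated in this README):

$$ \begin{pmatrix} K + \alpha I_n & P \\ P^\top & 0 \end{pmatrix} \begin{pmatrix} W \\ C \end{pmatrix} = \begin{pmatrix} Y \\ 0 \end{pmatrix}, \qquad K_{ij} = G(\|x_i - x_j\|_2) $$

Here W is the n × v matrix whose rows are the w_i, P is the n × m matrix of the m monomials of degree at most order - 1 evaluated at each x_i, and C is the m × v matrix of polynomial coefficients. The side condition P^T W = 0 is what makes the split between the RBF part and the polynomial part unique.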
Default kernel (TPS):

$$ G(r) = r^2 \log(r) $$
General kernel:

$$ \begin{aligned} G(r) &= r^{2\,\text{order} - d} && \text{if } d \text{ is odd} \\ G(r) &= r^{2\,\text{order} - d} \log(r) && \text{otherwise} \end{aligned} $$
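To make the case split concrete, here is a small helper that evaluates G(r) from the formulas above. The name `polyharmonic_kernel` is hypothetical (it is not a `torch_tps` function), and setting G(0) = 0 is an assumption: it is the r → 0 limit whenever 2 * order > d, but how the library itself handles r = 0 is not documented here.

```python
import math


def polyharmonic_kernel(r: float, order: int = 2, d: int = 2,
                        enforce_tps_kernel: bool = False) -> float:
    """Hypothetical helper evaluating the polyharmonic RBF G(r).

    Odd d:  r^(2*order - d)
    Even d: r^(2*order - d) * log(r)
    enforce_tps_kernel=True always returns the TPS kernel r^2 * log(r).
    """
    if r == 0.0:
        # Assumed convention: the r -> 0 limit, valid when 2 * order > d.
        return 0.0
    if enforce_tps_kernel:
        return r ** 2 * math.log(r)
    exponent = 2 * order - d
    value = r ** exponent
    # The log factor only appears for even input dimensions.
    return value * math.log(r) if d % 2 == 0 else value
```

With the default order=2, d=2 gives r^2 log(r) (TPS) and d=1 gives r^3, the natural cubic spline kernel mentioned for the constructor.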
ThinPlateSpline creates a general polyharmonic spline interpolator (defaults to TPS in 2D and to natural cubic splines in 1D).
- alpha (float): Regularization strength (default 0.0)
- order (int): Spline order (default 2, i.e. TPS)
- enforce_tps_kernel (bool): Force the TPS kernel r^2 log(r), even when it is mathematically suboptimal
- device (torch.device): Use "cuda" to enable GPU computation (default "cpu")
fit(X, Y) fits the model to control point pairs.
- X: (n, d) input coordinates
- Y: (n, v) target coordinates

Returns: self
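Fitting amounts to one linear solve. Below is a minimal pure-Python sketch of that solve so it runs without PyTorch. It assumes order = 2 (a degree-1 polynomial part), input dimension d ≤ 3, and alpha added to the diagonal of the kernel matrix. `tps_kernel`, `solve`, and `fit_polyharmonic` are hypothetical names, not part of `torch_tps`, which works on tensors instead. The sketch only mirrors the math.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def tps_kernel(r: float, d: int) -> float:
    """Hypothetical order-2 kernel: r^(4 - d), times log(r) if d is even.

    Assumes d <= 3 so that G(0) = 0 is the correct limit.
    """
    if r == 0.0:
        return 0.0
    g = r ** (4 - d)
    return g * math.log(r) if d % 2 == 0 else g


def solve(A: Matrix, B: Matrix) -> Matrix:
    """Solve A @ X = B with Gaussian elimination and partial pivoting."""
    n, m = len(A), len(B[0])
    M = [A[i][:] + B[i][:] for i in range(n)]  # augmented copy [A | B]
    for col in range(n):
        # Pivoting is required: the system has a zero block on its diagonal.
        pivot = max(range(col, n), key=lambda row: abs(M[row][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for row in range(col + 1, n):
            factor = M[row][col] / M[col][col]
            for k in range(col, n + m):
                M[row][k] -= factor * M[col][k]
    X = [[0.0] * m for _ in range(n)]
    for i in range(n - 1, -1, -1):  # back substitution
        for j in range(m):
            acc = M[i][n + j] - sum(M[i][k] * X[k][j] for k in range(i + 1, n))
            X[i][j] = acc / M[i][i]
    return X


def fit_polyharmonic(X: Matrix, Y: Matrix, alpha: float = 0.0) -> Tuple[Matrix, Matrix]:
    """Hypothetical order-2 fit.

    Returns RBF weights W (n, v) and polynomial coefficients C (d + 1, v)
    for the polynomial basis 1, x_1, ..., x_d.
    """
    n, d, v = len(X), len(X[0]), len(Y[0])
    p = d + 1
    A = [[0.0] * (n + p) for _ in range(n + p)]
    for i in range(n):
        for j in range(n):
            A[i][j] = tps_kernel(math.dist(X[i], X[j]), d)
        A[i][i] += alpha  # assumption: alpha acts as a ridge term on K
        basis = [1.0] + list(X[i])
        for k in range(p):
            A[i][n + k] = basis[k]  # P block
            A[n + k][i] = basis[k]  # P^T block: enforces P^T W = 0
    B = [list(row) for row in Y] + [[0.0] * v for _ in range(p)]
    sol = solve(A, B)
    return sol[:n], sol[n:]
```

A quick sanity check for such a sketch: with alpha = 0, fitting values that come from an affine function should give W ≈ 0 and recover the affine coefficients exactly, since the spline reproduces affine maps with zero bending.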
transform(X) applies the learned mapping to new input points.
- X: (n', d) points

Returns: (n', v) interpolated values
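Once the coefficients are known, transforming is a single matrix product: each new point becomes a feature row [G(‖x - x_1‖), …, G(‖x - x_n‖), 1, x_1, …, x_d] that is multiplied by the stacked coefficients [W; C]. This is consistent with the note above that the GPU mainly speeds up transform. A minimal pure-Python sketch, again assuming order = 2 and d ≤ 3. `order2_kernel` and `transform_polyharmonic` are hypothetical names, not the library's API.

```python
import math
from typing import List

Matrix = List[List[float]]


def order2_kernel(r: float, d: int) -> float:
    """Hypothetical order-2 polyharmonic kernel; assumes d <= 3 so that G(0) = 0."""
    if r == 0.0:
        return 0.0
    g = r ** (4 - d)
    return g * math.log(r) if d % 2 == 0 else g


def transform_polyharmonic(X_new: Matrix, X_train: Matrix, W: Matrix, C: Matrix) -> Matrix:
    """Hypothetical transform: rows of [K(X_new, X_train) | 1 | X_new] times [W; C]."""
    d, v = len(X_train[0]), len(W[0])
    coeffs = W + C  # stacked (n + d + 1, v) coefficient matrix
    out = []
    for x in X_new:
        # One row of the (n', n + d + 1) feature matrix.
        features = [order2_kernel(math.dist(x, xi), d) for xi in X_train] + [1.0] + list(x)
        out.append([sum(f * coeffs[k][j] for k, f in enumerate(features)) for j in range(v)])
    return out
```

For example, with X_train = [[0.0], [1.0]], W = [[1.0], [-1.0]] and C = [[0.5], [2.0]] (hand-picked, not fitted), the 1D kernel is r^3 and the point x = 2.0 maps to 8 · 1 + 1 · (-1) + 0.5 + 2 · 2 = 11.5.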
```bash
git clone https://github.com/raphaelreme/tps.git
cd tps
pip install -e .
```
MIT License