[feature request] symmetric matrix square root · Issue #25481 · pytorch/pytorch · GitHub

[feature request] symmetric matrix square root #25481


Open
yaroslavvb opened this issue Aug 30, 2019 · 48 comments
Labels
function request A request for a new function or the addition of new arguments/modes to an existing function. module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply (matmul). triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@yaroslavvb
Contributor
yaroslavvb commented Aug 30, 2019

This is needed when incorporating curvature information into optimization.

There's a PyTorch implementation of a symmetric matrix square root op here, but they use PyTorch for the backward pass only and scipy for the forward pass.

In the meantime, see #25481 (comment) for an implementation dispatching to eigh.

cc @vincentqb @vishwakftw @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @ssnl

@yaroslavvb
Contributor Author

I actually need a batched version of this, cc @vishwakftw for ideas

@vishwakftw
Contributor
vishwakftw commented Aug 30, 2019

One naive idea would be to compute the symmetric eigendecomposition, take the square root of the eigenvalues, and then reconstruct the matrix.

edit: I didn’t see the code, sorry. I’ll look into this.
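A minimal sketch of that idea for a single symmetric PSD matrix (the helper name is just for illustration; uses the torch.symeig API of that time):

def naive_symsqrt(A):
    # A = v diag(s) v^T, so sqrt(A) = v diag(sqrt(s)) v^T
    s, v = A.symeig(eigenvectors=True)
    return v @ s.clamp(min=0).sqrt().diag() @ v.t()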

@vadimkantorov
Contributor

There is also a Newton step-based matrix square root: #9983 (comment)

@vishwakftw vishwakftw added feature A request for a proper, new feature. module: operators module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul labels Aug 30, 2019
@yf225 yf225 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Aug 30, 2019
@yaroslavvb
Contributor Author

@vadimkantorov these are the same guys that did the PyTorch op for matrix square root, so I wonder why they chose to use scipy for the forward pass. They mention that the Newton step's convergence radius is limited. I've seen this problem when playing with the Newton method for matrix inversion (Schulz iteration): the initial guess had to be pretty close to the answer to get reasonable performance.
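For reference, a minimal sketch of that Schulz (Newton–Schulz) inversion iteration, using the current torch.linalg API (the helper name and initial guess are just one common choice; convergence requires the initial guess to satisfy ||I - A X0|| < 1):

import torch

def newton_schulz_inverse(A, num_iters=20):
    # X_{k+1} = X_k (2 I - A X_k): quadratic convergence, but only locally.
    I = torch.eye(A.size(-1), dtype=A.dtype, device=A.device)
    # A common initial guess: X_0 = A^T / (||A||_1 * ||A||_inf)
    X = A.transpose(-2, -1) / (
        torch.linalg.matrix_norm(A, ord=1, keepdim=True)
        * torch.linalg.matrix_norm(A, ord=float('inf'), keepdim=True)
    )
    for _ in range(num_iters):
        X = X @ (2.0 * I - A @ X)
    return X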

@vadimkantorov
Contributor
vadimkantorov commented Aug 31, 2019

It seems that https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py does not use any scipy... (I couldn't find any "scipy" substring). I must be missing something.

@yaroslavvb
Contributor Author

You are right; I was thinking of this implementation, which is from different authors: https://github.com/steveli/pytorch-sqrtm

@vadimkantorov
Contributor

@yaroslavvb Ah, I see. Though they seem to use scipy for both the forward and backward passes (using a Sylvester equation solver from scipy) :)
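For reference, a rough sketch of that scipy-based scheme (forward via scipy.linalg.sqrtm, backward via scipy.linalg.solve_sylvester, which solves Y X + X Y = G); this is an illustration rather than their exact code:

import numpy as np
from scipy.linalg import sqrtm, solve_sylvester

A = np.random.rand(5, 5)
A = A @ A.T + np.eye(5)            # symmetric positive definite
Y = sqrtm(A)                       # forward pass: principal square root
G = np.random.rand(5, 5)           # incoming gradient w.r.t. sqrtm(A)
grad_A = solve_sylvester(Y, Y, G)  # backward pass: solve Y X + X Y = G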

@xmodar
xmodar commented Oct 21, 2019

I have a few points to make here, since I worked on the same problem before:

  • b.t() should be replaced by b.transpose(-2, -1) to support batching
  • b.diag() should be replaced by b.diag_embed() to support batching
  • A @ b.diag_embed() should be replaced by A * b.unsqueeze(-2) for efficiency and batching
  • _, s, v = matrix.svd() is more stable and efficient, for some reason, than s, v = matrix.symeig(eigenvectors=True) at passing torch.autograd.gradcheck
  • scipy.linalg.pinvh (v1.3.1) multiplies the condition value by max(a.shape) instead of 1E3 or 1E6
  • cond can be computed without taking the absolute value of the eigenvalues, since we are taking their square root and assuming positive definite matrices
  • eps should not be hard-coded and can be computed for any given type, or by using @vadimkantorov's suggestion, torch.finfo
  • Putting it all together:
def symsqrt(matrix):
    """Compute the square root of a positive definite matrix."""
    # perform the decomposition
    # s, v = matrix.symeig(eigenvectors=True)
    _, s, v = matrix.svd()  # passes torch.autograd.gradcheck()
    # truncate small components
    above_cutoff = s > s.max() * s.size(-1) * torch.finfo(s.dtype).eps
    s = s[..., above_cutoff]
    v = v[..., above_cutoff]
    # compose the square root matrix
    return (v * s.sqrt().unsqueeze(-2)) @ v.transpose(-2, -1)
def special_sylvester(a, b):
    """Solves the eqation `A @ X + X @ A = B` for a positive definite `A`."""
    # https://math.stackexchange.com/a/820313
    s, v = a.symeig(eigenvectors=True)
    d = s.unsqueeze(-1)
    d = d + d.transpose(-2, -1)
    vt = v.transpose(-2, -1)
    c = vt @ b @ v
    return v @ (c / d) @ vt
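For reference, a quick sanity check of special_sylvester might look like this (a sketch; double precision, with the definitions above in scope):

import torch

a = torch.randn(10, 10, dtype=torch.float64)
a = a @ a.t() + torch.eye(10, dtype=torch.float64)  # make it positive definite
b = torch.randn(10, 10, dtype=torch.float64)
x = special_sylvester(a, b)
print(torch.allclose(a @ x + x @ a, b))  # expected: True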

@vadimkantorov
Contributor
vadimkantorov commented Oct 21, 2019

@ModarTensai about eps: is torch.finfo(s.dtype).eps useful here? Another way would be to allow the user to pass their own eps if they wish.

@xmodar
xmodar commented Oct 21, 2019

You are absolutely right. Last time I checked, they hadn't ported finfo from numpy yet. I will update the original comment.

@yaroslavvb
Contributor Author

@ModarTensai nice tips!

  • I got the 1e3 and 1e6 factors from scipy.linalg.pinv2. However, in master they have been removed, and I couldn't track down where they came from
  • solve_sylvester can be extended to the no-solution/multiple-solutions case by discarding eigenvalues below the cut-off. Empirically this seems to give the same answer as least squares on the Kronecker-expanded equations
  • for some applications symsqrt(A) can be replaced by a "non-symmetric matrix square root" A=BB', which can be faster (i.e., 10-100x faster when A is the Hessian of a cross-entropy loss); see the sketch below
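For that last point, a minimal sketch of one such non-symmetric square root, the Cholesky factor (assumes A is strictly positive definite; uses the current torch.linalg API):

import torch

A = torch.randn(100, 100, dtype=torch.float64)
A = A @ A.t() + 100 * torch.eye(100, dtype=torch.float64)  # strictly positive definite
B = torch.linalg.cholesky(A)  # lower-triangular B with A = B @ B.t()
print(torch.allclose(B @ B.t(), A))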

@rharish101

@ModarTensai The symsqrt function does not work for batched inputs, e.g. of shape [32, 100, 100]. The error encountered is:

<ipython-input-2-730fdb238ed4> in symsqrt(matrix)
      7     above_cutoff = s > s.max() * s.size(-1) * torch.finfo(s.dtype).eps
      8     s = s[..., above_cutoff]
----> 9     v = v[..., above_cutoff]
     10     # compose the square root matrix
     11     return (v * s.sqrt().unsqueeze(-2)) @ v.transpose(-2, -1)

IndexError: The shape of the mask [32, 100] at index 0 does not match the shape of the indexed tensor [32, 100, 100] at index 1

@xmodar
xmodar commented Jan 20, 2020

@rharish101, thank you for pointing this out. We cannot truncate with a boolean mask like that in the batched case. Here is a quick fix for you:

def symsqrt(matrix):
    """Compute the square root of a positive definite matrix."""
    _, s, v = matrix.svd()
    zero = torch.zeros((), device=s.device, dtype=s.dtype)
    threshold = s.max(-1).values * s.size(-1) * torch.finfo(s.dtype).eps
    s = s.where(s > threshold.unsqueeze(-1), zero)  # zero out small components
    return (v * s.sqrt().unsqueeze(-2)) @ v.transpose(-2, -1)

A more performant version would truncate the columns that are common across the batch and zero out the rest.

@vadimkantorov
Contributor

(@ModarTensai you could also probably save the unsqueeze() on threshold if keepdim=True is used, as in s.max(dim=-1, keepdim=True).values)

@xmodar
xmodar commented Jan 21, 2020

@vadimkantorov, that is true but the line became too long for my taste :)
@rharish101, here is the same code again but with truncation as well.

def symsqrt(matrix):
    """Compute the square root of a positive definite matrix."""
    _, s, v = matrix.svd()
    good = s > s.max(-1, True).values * s.size(-1) * torch.finfo(s.dtype).eps
    components = good.sum(-1)
    common = components.max()
    unbalanced = common != components.min()
    if common < s.size(-1):
        s = s[..., :common]
        v = v[..., :common]
        if unbalanced:
            good = good[..., :common]
    if unbalanced:
        s = s.where(good, torch.zeros((), device=s.device, dtype=s.dtype))
    return (v * s.sqrt().unsqueeze(-2)) @ v.transpose(-2, -1)

You can test it out like this:

x = torch.randn(5, 10, 10).double()
x = x @ x.transpose(-1, -2)
y = symsqrt(x)
print(torch.allclose(x, y @ y.transpose(-1, -2)))
x.requires_grad = True
torch.autograd.gradcheck(symsqrt, [x])
torch.autograd.gradgradcheck(symsqrt, [x])

@rharish101

Thank you all for your prompt replies!

@JonathanVacher
JonathanVacher commented Feb 11, 2020

Hi everyone,
I don't know exactly why, but torch.svd() generates errors when the gradient is required, while torch.symeig() is fine with the gradient but the computations are run on the CPU.

For these reasons I implemented a custom torch.autograd.Function following this implementation: https://github.com/msubhransu/matrix-sqrt
I didn't know exactly how to handle the device inside the function, so I added it as an input. Now all the computations run on the GPU.

Edit: well, I figured it out; the device is stored in input.device

import torch
from torch.autograd import Function

class MatrixSquareRoot(Function):
    """Square root of a positive definite matrix.
    NOTE: matrix square root is not differentiable for matrices with
          zero eigenvalues.
    """    
    @staticmethod
    def forward(ctx, input):
        dim = input.shape[0]
        norm = torch.norm(input.double())
        Y = input.double()/norm
        I = torch.eye(dim,dim,device=input.device).double()
        Z = torch.eye(dim,dim,device=input.device).double()
        for i in range(20):
            T = 0.5*(3.0*I - Z.mm(Y))
            Y = Y.mm(T)
            Z = T.mm(Z)
        sqrtm = Y*torch.sqrt(norm)
        ctx.mark_dirty(Y,I,Z)
        ctx.save_for_backward(sqrtm)
        return sqrtm

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        sqrtm, = ctx.saved_tensors
        dim = sqrtm.shape[0]
        norm = torch.norm(sqrtm)
        A = sqrtm/norm
        I = torch.eye(dim, dim, device=sqrtm.device).double()
        Q = grad_output.double()/norm
        for i in range(20):
            Q = 0.5*(Q.mm(3.0*I-A.mm(A))-A.t().mm(A.t().mm(Q)-Q.mm(A)))
            A = 0.5*A.mm(3.0*I-A.mm(A))
        grad_input = 0.5*Q
        return grad_input    
sqrtm = MatrixSquareRoot.apply

@yaroslavvb
Contributor Author

@JonathanVacher torch.symeig should run on the GPU if the input is on the GPU. What are the SVD failures? There's a rare SVD failure on singular matrices due to a limitation of the gesdd algorithm, which fails in both the magma and scipy implementations -- #25978 (comment). Using Newton iteration seems to require double precision to provide a similar level of accuracy, but this makes it 10x slower than the decomposition-based approaches.

I've benchmarked the three solutions in the colab here on a 4000x4000 matrix and a T4 GPU, modifying your code to use the same precision as the input -- PyTorch symmetric square root

symeig: 1.4 seconds
svd: 4.4 seconds
Newton (30 iterations, double precision): 33 seconds

Switching to single precision makes it run at a similar speed to the svd/symeig versions, but at lower accuracy. So it seems the Newton-based implementation is inferior to the decomposition-based ones. However, it could be useful in a stochastic setting -- when the target is a noisy sample, a single Newton step may give the desired precision, in which case this approach would be an order of magnitude faster.

@JonathanVacher

OK, so if it works for you, I may have a problem.
1/ with my implementation, my algorithm runs with 1 CPU core and 1 GPU
2/ with symeig, my algorithm runs but 8 CPU cores are used in addition to 1 GPU
3/ with svd, it crashes at the 1st iteration when computing an SVD.
I tried to heavily regularize the covariances, so I can conclude that it's not because of matrix singularity (and anyway it works with symeig).

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.
Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.
Traceback (most recent call last):
[...]
 _, s, v = matrix.svd()  # passes torch.autograd.gradcheck()
RuntimeError: svd_cuda: the updating process of SBDSDC did not converge (error: 14)

Here are my env infos:

PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.1.105
GPU models and configuration: 
GPU 0: TITAN V
GPU 1: TITAN V
GPU 2: TITAN V
GPU 3: TITAN V

Nvidia driver version: 430.50
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.0

Versions of relevant libraries:
[pip3] numpy==1.18.1
[pip3] torch==1.4.0
[pip3] torchfile==0.1.0
[pip3] torchvision==0.4.2
[conda] Could not collect

Any idea? My CMake version?

@yaroslavvb
Contributor Author
yaroslavvb commented Feb 11, 2020

@JonathanVacher You could rule out differences in environment by adding your code to my colab notebook and running it there, or by copy-pasting code from that notebook and running it in your environment.

Wall-clock time is a more useful measure than CPU utilization; running taskset 0x1 python test_svd.py drops my CPU usage from 32 cores to 1 core without affecting wall-clock time much.

Your SVD error says Intel MKL, so I'm guessing you are feeding CPU tensors into both your symeig and svd experiments. Google says this error is caused by nan or inf values in the input tensor, so try assert np.isfinite(tensor.cpu().numpy()).all(). As a side note, I've seen crashes in MKL which were fixed (at the expense of speed) by setting the OMP_NUM_THREADS=1 environment variable.

@vadimkantorov
Contributor
vadimkantorov commented Feb 11, 2020

@yaroslavvb Magma uses MKL for some intermediate computation under the hood; I've had MKL errors with GPU eig() computations before: #9384 (when it's an incorrect-parameter error, I assume it's a bug in either Magma or MKL)

@JonathanVacher
JonathanVacher commented Feb 11, 2020

OK, I don't understand at all how it is possible to observe such a difference in timing!
Edit: I updated the image; I forgot to synchronize with the correct GPU. It's still faster for me with 20 Newton iterations.

All implementations work now. I will need to check the rest of my code. I think everything is in GPU memory though... We'll see.

[image: matrix-sqrt timing screenshot]

@yaroslavvb
Contributor Author
yaroslavvb commented Feb 12, 2020

@JonathanVacher just reran this code on a V100 and see the same timing as above. It seems switching from a T4 to a V100 doesn't affect svd/symeig speed, while the matmul-based Newton method becomes 4x faster in single precision (I expected 3x) and 30x faster in double precision. So for V100 cards, the Newton method seems better in all respects.

It seems torch.symeig and torch.svd have a hard time saturating the V100 compute units. I'm seeing a power draw of about 100 W when running those routines in the benchmark above, whereas the Newton method uses close to 300 W. Another advantage is that it works in half precision: I'm getting about 2% relative error with 20 iterations in half precision on 8k-by-8k matrices, about 6x faster than the symeig approach.

@leockl
leockl commented Oct 21, 2020

Hi All,

I was wondering if there have been any updates on this? It has been about 8 months since the last comment above.

I have a related question to this which I have posted in Stackoverflow here: https://stackoverflow.com/questions/64462253/pytorch-square-root-of-a-positive-semi-definite-matrix. I was wondering if anyone could assist me with this question. Would really appreciate it!

@JonathanVacher

Many implementations are available in this issue and you can use any of them. There is no direct implementation in PyTorch, as the issue is still open.

@nikitaved
Collaborator
nikitaved commented Oct 21, 2020

I am just about to start working on the general matrix square root FR. My plan is to use the IN iteration (Higham, Functions of Matrices, page 142, formula 6.20) with scaling and some compensation scheme for better precision. The good thing is that it handles rank-deficient cases, as the iteration starts with the matrix X + I.
It is also possible to use the basic Newton method, which has quadratic (although local) convergence for inputs with small eigenvalues; however, I do not immediately see how to extend it to arbitrary matrices. It is possible to employ a proper rescaling only for inputs of full rank (and after the rescaling we can lose the rank). The rank-deficient case is not that obvious to me yet...

Looks like the methods based on the Schur decomposition are best for CPU, so for symmetric matrices it is probably better to just resort to symeig.

@nikitaved
Collaborator

@ModarTensai, thank you for the Sylvester equation code! It will be of use for the analytic backward!

@xmodar
xmodar commented Oct 21, 2020

@nikitaved, you are most welcome. Please, let me know if you need anything else. Currently, I am busy with CVPR but I will gladly join this effort, if needed, after the deadline.

@leockl
leockl commented Oct 23, 2020

Hi @JonathanVacher, many thanks for your reply.

However, I am looking for an implementation which works for positive semi-definite matrices. All the implementations that are available in this issue are only applicable for positive definite matrices?

@yaroslavvb
Contributor Author

Hi @JonathanVacher, many thanks for your reply.

However, I am looking for an implementation which works for positive semi-definite matrices. All the implementations that are available in this issue are only applicable for positive definite matrices?

It should work for singular matrices; that's what the above_cutoff logic is for.

@nikitaved
Collaborator
nikitaved commented Oct 24, 2020

@leockl, if it is the backward you are concerned about (reciprocals of pairwise sums of eigenvalues), note that the matrix square root is a function of a matrix, which allows one to get the gradient by running the matrix function on a 4x larger input.
Check out:

Mathias, Roy.
A Chain Rule for Matrix Functions and Applications.
SIAM J. Matrix Anal. Appl. 17 (1996): 610-620.

This is exactly how the backward for the matrix exponential is implemented in PyTorch.

Maybe you could try this approach? Only forward methods are needed, with no explicit backward unless you can do it more efficiently (so that you could use symeig in the forward, for example; symeig.backward has some limitations).
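For reference, a rough numpy/scipy sketch of that block-matrix trick (the helper name is just for illustration; the adjoint/backward variant applies the same construction to the conjugate transpose of A):

import numpy as np
from scipy.linalg import sqrtm

def sqrtm_frechet(A, E):
    # Mathias (1996): sqrtm([[A, E], [0, A]]) = [[sqrtm(A), L(A, E)], [0, sqrtm(A)]],
    # where L(A, E) is the Frechet derivative of the square root at A in direction E,
    # so the derivative comes out of a single forward evaluation on a 2n x 2n input.
    n = A.shape[-1]
    M = np.zeros((2 * n, 2 * n), dtype=np.result_type(A, E))
    M[:n, :n] = A
    M[n:, n:] = A
    M[:n, n:] = E
    return sqrtm(M)[:n, n:]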

@nikitaved
Collaborator
nikitaved commented Oct 24, 2020

A note regarding symeig.backward stability: it is only stable for inputs of full rank (minus 1) with distinct eigenvalues separated by some gap, as one of the intermediate steps of the algorithm computes the reciprocals of pairwise differences between the eigenvalues.

@vfdev-5 vfdev-5 mentioned this issue Jun 4, 2021
@RylanSchaeffer

Are there any updates on this issue?

@vadimkantorov
Contributor

@xmlyqing00

(quoting @xmodar's batched symsqrt implementation and test snippet from above)

Thanks for providing such a function to compute the square root. When I checked some math books, I found that a symmetric matrix isn't necessarily positive semi-definite. So this function only works well when the matrix is positive semi-definite; if not, the eigenvalues of the matrix could be negative.

@lezcano
Collaborator
lezcano commented Feb 8, 2022

Just for reference, the SVD of an SPD matrix is just its eigenvalue decomposition. As such, it's more efficient to use linalg.eigh than linalg.svd. Note that the comment in #25481 (comment) does not apply anymore, as we now have a nice implementation of the gradients of eigh that works in both the real and complex case (!!)

Making that change in one of the implementations given above, an implementation of this function that works for both symmetric and Hermitian matrices may be given as follows:

import torch
from torch import linalg

def sqrtmh(A):
    """Compute the square root of a Symmetric or Hermitian positive definite matrix or batch of matrices"""
    L, Q = linalg.eigh(A)
    zero = torch.zeros((), device=L.device, dtype=L.dtype)
    threshold = L.max(-1).values * L.size(-1) * torch.finfo(L.dtype).eps
    L = L.where(L > threshold.unsqueeze(-1), zero)  # zero out small components
    return (Q * L.sqrt().unsqueeze(-2)) @ Q.mH

Note that this implementation also gives correct gradients for almost all matrices, but its gradients are incorrect (they actually diverge when they shouldn't) for matrices with repeated eigenvalues. Now, this is not that big of a deal, because the algorithms for eigh and svd themselves are also problematic when used on matrices with repeated eigenvalues, and people don't really complain :)
Now, if one wants to be extra careful, it's possible to implement the backward of sqrtmh via the usual trick explained in https://epubs.siam.org/doi/pdf/10.1137/S0895479895283409. As an example of how to do this, we have an implementation in C++:

// Based on:
//
// Mathias, Roy.
// A Chain Rule for Matrix Functions and Applications.
// SIAM J. Matrix Anal. Appl. 17 (1996): 610-620.
template <typename func_t>
Tensor differential_analytic_matrix_function(
    const Tensor& self, const Tensor& grad,
    const func_t& matrix_function,
    const bool adjoint  // Choose between forward (adjoint=false) or backward AD (adjoint=true)
) {
  // Given an analytic matrix function, this computes the differential (forward AD)
  // or the adjoint of the differential (backward AD)
  auto A = adjoint ? self.transpose(-2, -1).conj() : self;
  auto meta_grad_sizes = A.sizes().vec();
  meta_grad_sizes[A.dim() - 2] *= 2;
  meta_grad_sizes[A.dim() - 1] *= 2;
  auto n = A.size(-1);
  auto meta_grad = at::zeros(meta_grad_sizes, grad.options());
  meta_grad.narrow(-2, 0, n).narrow(-1, 0, n).copy_(A);
  meta_grad.narrow(-2, n, n).narrow(-1, n, n).copy_(A);
  meta_grad.narrow(-2, 0, n).narrow(-1, n, n).copy_(grad);
  return matrix_function(meta_grad).narrow(-2, 0, n).narrow(-1, n, n);
}

Taking all these things into consideration, and seeing the performance benefits of having a Newton rule, we could consider adding this function to torch.linalg with two different backends (spectral / iterative) and explaining the trade-off between them in the docs. cc @nikitaved

@DRJYYDS
DRJYYDS commented Apr 26, 2022

(quoting @lezcano's sqrtmh implementation and C++ snippet from above)

Hi, this code doesn't work for me.
The error is: 'Tensor' object has no attribute 'mH'.
Could you help me fix this?

@lezcano
Collaborator
lezcano commented Apr 26, 2022

You should update to a newer version of PyTorch.

@DRJYYDS
DRJYYDS commented Apr 26, 2022

You should update to a newer version of PyTorch.

Got it, Thanks.

And I found that this function seems to produce different output compared with scipy.linalg.sqrtm:

[image: screenshot comparing the outputs]

And btw the function that @xmodar provided also has the same output as yours.

Does this function have a bigger error compared with scipy.linalg.sqrtm?

@lezcano
Collaborator
lezcano commented Apr 26, 2022

Or does linalg.sqrtm have a larger error than this one? (I'd doubt that, tbh.) You should measure that.

Now, I don't think this discussion should go here but rather on the forum or StackOverflow. Please feel free to post it there, cross-link it by editing your post if you want, and we can discuss all these points there.

@randolf-scholz
Contributor
randolf-scholz commented Nov 2, 2023

@DRJYYDS It seems that torch does not use pairwise summation for the trace, which can lead to large error accumulation.

// NOTE: this could be implemented via diag and sum, but this has perf problems,
// see https://github.com/pytorch/pytorch/pull/47305,
Tensor trace_cpu(const Tensor& self) {
  Tensor result;
  // Returns the ScalarType of the self tensor if the tensor is non integral type
  // In the case, self is an integer type tensor, at::kLong is return since promote_integers
  // is set to true
  ScalarType dtype = get_dtype_from_self(self, c10::nullopt, true);
  result = at::empty({}, self.options().dtype(dtype));
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(self.scalar_type(), "trace", [&] {
    using accscalar_t = at::acc_type<scalar_t, false>;
    accscalar_t sum = 0;
    const auto* t_data = self.data_ptr<scalar_t>();
    int64_t t_stride_0, t_stride_1, t_diag_size;
    TORCH_CHECK(self.dim() == 2, "trace: expected a matrix, but got tensor with dim ", self.dim());
    t_stride_0 = self.stride(0);
    t_stride_1 = self.stride(1);
    t_diag_size = std::min(self.size(0), self.size(1));
    for (const auto i : c10::irange(t_diag_size)) {
      sum += t_data[i * (t_stride_0 + t_stride_1)];
    }
    set_result<scalar_t>(result, sum);
  });
  return result;
}

They should be using a variant of .diag().sum(), but this apparently had some performance issues: #47305

@vadimkantorov
Contributor

Also interesting that trace does not seem to support batches :(

@randolf-scholz
Contributor

If you need batching, a workaround (for square matrices) is torch.einsum("...ii -> ...", batched_tensor)

@randolf-scholz
Contributor

One can show that, for the matrix square root, the vector-Jacobian product is given by $G \mapsto \mathrm{solve}(\sqrt{A}\,X + X\sqrt{A} = G)$ (see https://math.stackexchange.com/q/540361), which is a continuous-time Lyapunov equation. This might be a preferable approach for the backward pass, if good solvers are available.
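For reference, a rough sketch of that approach, using torch.linalg.eigh both for the forward pass and for solving the Lyapunov equation in the backward pass (assumes a symmetric positive definite input; the class name is just for illustration):

import torch

class SqrtmWithLyapunovBackward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, A):
        # forward: sqrt(A) via the eigendecomposition of A
        L, Q = torch.linalg.eigh(A)
        sqrt_A = (Q * L.clamp(min=0).sqrt().unsqueeze(-2)) @ Q.transpose(-2, -1)
        ctx.save_for_backward(sqrt_A)
        return sqrt_A

    @staticmethod
    def backward(ctx, G):
        # backward: solve sqrt(A) X + X sqrt(A) = G via the eigendecomposition of sqrt(A);
        # ill-conditioned when eigenvalues are (near) zero, as noted earlier in the thread
        (sqrt_A,) = ctx.saved_tensors
        L, Q = torch.linalg.eigh(sqrt_A)
        Qt = Q.transpose(-2, -1)
        denom = L.unsqueeze(-1) + L.unsqueeze(-2)  # pairwise sums of eigenvalues
        return Q @ ((Qt @ G @ Q) / denom) @ Qt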

@SpontaneousDuck

For anyone else still looking: I adapted the sqrt_newton_schulz from matrix-sqrt to include the early stopping for divergence used in Tensorflow's implementation. On my test set of 1000 batches, it outperformed sqrtmh with lower error (3.1764e-16 vs 1.1543e-15) but can take significantly longer (100ms vs 9ms) depending on the use case. It ended up being faster on the matrices I was working on; the best and closest comparison from this list was sqrtmh.

I'm sure this would vary for different test sets, since it is a variable-length loop, and it probably has cases I missed. Hope it helps anyone out who needs it!

def sqrt_newton_schulz(A, numIters=200):
    """ Newton-Schulz iterations method to get matrix square root.
    Page 231, Eq 2.6b
    http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.8799&rep=rep1&type=pdf

    Args:
        A: the symmetric PSD matrix whose matrix square root be computed
        numIters: Maximum number of iterations.

    Returns:
        A^0.5

    Tensorflow Source:
        https://github.com/tensorflow/tensorflow/blob/df3a3375941b9e920667acfe72fb4c33a8f45503/tensorflow/contrib/opt/python/training/matrix_functions.py#L26C1-L73C42
    Torch Source:
        https://github.com/msubhransu/matrix-sqrt/blob/cc2289a3ed7042b8dbacd53ce8a34da1f814ed2f/matrix_sqrt.py#L74
    """

    normA = torch.linalg.matrix_norm(A, keepdim=True)
    err = normA + 1.0
    I = torch.eye(*A.shape[-2:], dtype=A.dtype, device=A.device)
    Z = torch.eye(*A.shape[-2:], dtype=A.dtype, device=A.device).expand_as(A)
    Y = A / normA
    for i in range(numIters):
        T = 0.5*(3.0*I - Z.bmm(Y))
        Y_new = Y.bmm(T)
        Z_new = T.bmm(Z)

        # This method require that we check for divergence every step.
        # Compute the error in approximation.
        mat_a_approx = torch.bmm(Y_new, Y_new) * normA
        residual = A - mat_a_approx
        current_err = torch.linalg.matrix_norm(residual, keepdim=True) / normA
        if torch.all(current_err > err):
            break

        err = current_err
        Y = Y_new
        Z = Z_new

    sA = Y*torch.sqrt(normA)
    
    return sA
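For reference, a quick way to sanity-check the residual on a small random SPD batch (a sketch; assumes the function above is in scope):

import torch

A = torch.randn(4, 64, 64, dtype=torch.float64)
A = A @ A.transpose(-2, -1) + 1e-3 * torch.eye(64, dtype=torch.float64)
Y = sqrt_newton_schulz(A)
print(((Y @ Y - A).norm() / A.norm()).item())  # relative Frobenius residual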

@bgalerne

Still no sqrtm in 2024?

@lezcano
Collaborator
lezcano commented Jun 26, 2024

We don't have anyone currently implementing new features in torch.linalg. We would accept a contribution that implements this function dispatching it to BLAS and implementing a fallback for CUDA.

@ilayn
ilayn commented Mar 28, 2025

Over at SciPy we are getting ready to switch to a C rewrite of sqrtm to address many issues that were historically not handled, including batched inputs: scipy/scipy#22406. It uses the now-standard blocked recursion and tries its best to stay real-valued for real inputs. I would highly appreciate it if interested folks could stress-test it.

It is straightforward C code (to my eyes, that is, and I'm no C expert), so maybe it is also interesting for you folks to put your own magic spin on it.
