[feature request] symmetric matrix square root #25481
I actually need a batched version of this, cc @vishwakftw for ideas
One naive idea would be to compute the symmetric eigendecomposition, compute the square root of the eigenvalues, and then reconstruct the matrix. Edit: I didn't see the code, sorry. I'll look into this.
There is also a Newton step-based matrix square root: #9983 (comment)
@vadimkantorov these are the same guys that did the PyTorch op for matrix square root, so I wonder why they chose to use scipy for the forward pass. They mention that the Newton step convergence radius is limited. I've seen this problem when playing with the Newton method for matrix inversion (Schulz iteration): the initial guess had to be pretty close to the answer to get reasonable performance.
It seems that https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py does not use any scipy... (I couldn't find any "scipy" substring). I must be missing something.
You are right, I was thinking of this implementation, which is from different authors: https://github.com/steveli/pytorch-sqrtm
@yaroslavvb Ah, I see. Though they seem to use scipy both for forward and backward (using a Sylvester equation solver in scipy) :)
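For readers skimming the thread, a minimal sketch of that scheme (scipy forward pass, Sylvester-equation backward) could look like the snippet below. This is only an illustration of the idea, not the linked repo's actual code, and the class name `ScipySqrtm` is made up here.

```python
import scipy.linalg
import torch


class ScipySqrtm(torch.autograd.Function):
    """Sketch of the scipy-forward / Sylvester-backward idea (illustrative only)."""

    @staticmethod
    def forward(ctx, a):
        # forward pass on CPU via scipy
        s = scipy.linalg.sqrtm(a.detach().cpu().numpy()).real
        s = torch.tensor(s, dtype=a.dtype, device=a.device)
        ctx.save_for_backward(s)
        return s

    @staticmethod
    def backward(ctx, grad_output):
        # Differentiating A = S S gives dA = dS S + S dS, so the gradient w.r.t. A
        # solves the Sylvester equation S^T X + X S^T = grad_output.
        (s,) = ctx.saved_tensors
        s_np = s.cpu().numpy()
        g_np = grad_output.detach().cpu().numpy()
        x = scipy.linalg.solve_sylvester(s_np.T, s_np.T, g_np)
        return torch.tensor(x, dtype=grad_output.dtype, device=grad_output.device)


sqrtm_scipy = ScipySqrtm.apply
```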
I have a few points to make here since I worked on the same problem before:

```python
import torch


def symsqrt(matrix):
    """Compute the square root of a positive definite matrix."""
    # perform the decomposition
    # s, v = matrix.symeig(eigenvectors=True)
    _, s, v = matrix.svd()  # passes torch.autograd.gradcheck()
    # truncate small components
    above_cutoff = s > s.max() * s.size(-1) * torch.finfo(s.dtype).eps
    s = s[..., above_cutoff]
    v = v[..., above_cutoff]
    # compose the square root matrix
    return (v * s.sqrt().unsqueeze(-2)) @ v.transpose(-2, -1)


def special_sylvester(a, b):
    """Solves the equation `A @ X + X @ A = B` for a positive definite `A`."""
    # https://math.stackexchange.com/a/820313
    s, v = a.symeig(eigenvectors=True)
    d = s.unsqueeze(-1)
    d = d + d.transpose(-2, -1)
    vt = v.transpose(-2, -1)
    c = vt @ b @ v
    return v @ (c / d) @ vt
```
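An illustrative sanity check for `special_sylvester`, assuming the definitions above (note that `symeig` is deprecated in recent PyTorch in favour of `torch.linalg.eigh`):

```python
# check that special_sylvester solves A @ X + X @ A = B
a = torch.randn(4, 4).double()
a = a @ a.t() + 4 * torch.eye(4).double()   # positive definite
b = torch.randn(4, 4).double()
x = special_sylvester(a, b)
print(torch.allclose(a @ x + x @ a, b))
```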
@ModarTensai about eps, is
You are absolutely right. Last time I checked, they didn't port
@ModarTensai nice tips!
@ModarTensai The code above fails for batched input:

```
<ipython-input-2-730fdb238ed4> in symsqrt(matrix)
      7     above_cutoff = s > s.max() * s.size(-1) * torch.finfo(s.dtype).eps
      8     s = s[..., above_cutoff]
----> 9     v = v[..., above_cutoff]
     10     # compose the square root matrix
     11     return (v * s.sqrt().unsqueeze(-2)) @ v.transpose(-2, -1)

IndexError: The shape of the mask [32, 100] at index 0 does not match the shape of the indexed tensor [32, 100, 100] at index 1
```
@rharish101, thank you for pointing this out. We cannot truncate with a boolean mask like that when the input is batched. Here is a quick fix for you:

```python
def symsqrt(matrix):
    """Compute the square root of a positive definite matrix."""
    _, s, v = matrix.svd()
    zero = torch.zeros((), device=s.device, dtype=s.dtype)
    threshold = s.max(-1).values * s.size(-1) * torch.finfo(s.dtype).eps
    s = s.where(s > threshold.unsqueeze(-1), zero)  # zero out small components
    return (v * s.sqrt().unsqueeze(-2)) @ v.transpose(-2, -1)
```

More performant code would truncate the columns that are common across the batch and zero out the rest.
(@ModarTensai could also probably save
@vadimkantorov, that is true, but the line became too long for my taste :)

```python
def symsqrt(matrix):
    """Compute the square root of a positive definite matrix."""
    _, s, v = matrix.svd()
    good = s > s.max(-1, True).values * s.size(-1) * torch.finfo(s.dtype).eps
    components = good.sum(-1)
    common = components.max()
    unbalanced = common != components.min()
    if common < s.size(-1):
        s = s[..., :common]
        v = v[..., :common]
        if unbalanced:
            good = good[..., :common]
    if unbalanced:
        s = s.where(good, torch.zeros((), device=s.device, dtype=s.dtype))
    return (v * s.sqrt().unsqueeze(-2)) @ v.transpose(-2, -1)
```

You can test it out like this:

```python
x = torch.randn(5, 10, 10).double()
x = x @ x.transpose(-1, -2)
y = symsqrt(x)
print(torch.allclose(x, y @ y.transpose(-1, -2)))

x.requires_grad = True
torch.autograd.gradcheck(symsqrt, [x])
torch.autograd.gradgradcheck(symsqrt, [x])
```
Thank you all for your prompt replies!
Hi everyone, for these reasons I implemented a custom torch.autograd.Function following this implementation: https://github.com/msubhransu/matrix-sqrt

Edit: well, I figured it out, the device is stored in input.device.

```python
import torch
from torch.autograd import Function


class MatrixSquareRoot(Function):
    """Square root of a positive definite matrix.

    NOTE: matrix square root is not differentiable for matrices with
    zero eigenvalues.
    """

    @staticmethod
    def forward(ctx, input):
        dim = input.shape[0]
        norm = torch.norm(input.double())
        Y = input.double() / norm
        I = torch.eye(dim, dim, device=input.device).double()
        Z = torch.eye(dim, dim, device=input.device).double()
        for i in range(20):
            T = 0.5 * (3.0 * I - Z.mm(Y))
            Y = Y.mm(T)
            Z = T.mm(Z)
        sqrtm = Y * torch.sqrt(norm)
        # ctx.mark_dirty(Y, I, Z)  # not needed: no inputs are modified in place
        ctx.save_for_backward(sqrtm)
        return sqrtm

    @staticmethod
    def backward(ctx, grad_output):
        sqrtm, = ctx.saved_tensors
        dim = sqrtm.shape[0]
        norm = torch.norm(sqrtm)
        A = sqrtm / norm
        I = torch.eye(dim, dim, device=sqrtm.device).double()
        Q = grad_output.double() / norm
        for i in range(20):
            Q = 0.5 * (Q.mm(3.0 * I - A.mm(A)) - A.t().mm(A.t().mm(Q) - Q.mm(A)))
            A = 0.5 * A.mm(3.0 * I - A.mm(A))
        grad_input = 0.5 * Q
        return grad_input


sqrtm = MatrixSquareRoot.apply
```
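A hypothetical quick check of the class above (assuming the definitions in that snippet):

```python
a = torch.randn(8, 8)
a = a @ a.t() + torch.eye(8)          # symmetric positive definite
y = sqrtm(a)                          # MatrixSquareRoot.apply
print(torch.allclose(y @ y, a.double(), atol=1e-5))
```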
@JonathanVacher torch.symeig should run on GPU if the input is on GPU. What are the SVD failures? There's a rare SVD failure on singular matrices due to a limitation of the gesdd algorithm; it fails in both the magma and scipy implementations -- #25978 (comment). Using Newton iteration seems to require double precision to provide a similar level of accuracy, but this makes it 10x slower than decomposition-based approaches. I've benchmarked the three solutions in the colab here on a 4000x4000 matrix and a T4 GPU, modifying your code to use the same precision as the input -- Pytorch symmetric square root

Switching to single precision makes it run at a similar speed as the svd/symeig versions, but at lower accuracy. So it seems the Newton-based implementation is inferior to the decomposition-based ones. However, it could be useful in a stochastic setting -- when the target is a noisy sample, a single Newton step may give the desired precision, in which case this approach would be an order of magnitude faster.
Ok, so if it works for you, I may have a problem.
Here are my env infos:
Any idea? My CMake version?
@JonathanVacher You could rule out a difference in environment by adding your code to my colab notebook and running it there, or by copy-pasting code from that notebook and running it in your environment. Wall-clock time is a more useful measure than CPU utilization. Your SVD error says
@yaroslavvb Magma uses MKL for some intermediate computation under the hood. I've had some MKL errors with GPU eig() computations before: #9384 (when it's an incorrect-parameter error, I assume it's a bug in either Magma or MKL).
Ok, I don't understand at all how it is possible to observe such a difference in timing! All implementations work now. I will need to check the rest of my code. I think everything is in GPU memory though... We'll see.
@JonathanVacher just reran this code on a V100 and see the same timing as above. It seems switching from T4 to V100 doesn't affect svd/symeig speed, while the matmul-based Newton method becomes 4x faster in single precision (I expected 3x faster) and 30x faster in double precision. So for V100 cards, the Newton method seems better in all respects. It seems
Hi all, I was wondering if there have been any updates on this? It has been about 8 months since the last comment above. I have a related question which I have posted on Stack Overflow here: https://stackoverflow.com/questions/64462253/pytorch-square-root-of-a-positive-semi-definite-matrix. I was wondering if anyone could assist me with this question. Would really appreciate it!
Many implementations are available in this issue, and you can use any of them. There is no direct implementation in pytorch, as the issue is still open.
I am just about to start working on the general matrix square root FR. My plan is to use the IN iteration (Higham, Functions of Matrices, page 142, formula 6.20) with scaling and some compensation scheme for better precision. The good thing is that it does handle rank-deficient cases, as the iteration does start with a matrix

Looks like the methods based on the Schur decomposition are best for CPU, so for symmetric matrices it is probably better to just resort to
@ModarTensai, thank you for the Sylvester equation code! It will be of use for the analytic backward!
@nikitaved, you are most welcome. Please let me know if you need anything else. Currently I am busy with CVPR, but I will gladly join this effort, if needed, after the deadline.
Hi @JonathanVacher, many thanks for your reply. However, I am looking for an implementation which works for positive semi-definite matrices. All the implementations available in this issue seem to be applicable only to positive definite matrices?
Should work for singular matrices, that's what
@leockl, if it is the backward you are concerned about (reciprocals of pairwise sums of eigenvalues), note that the matrix square root is a function of a matrix, which allows one to get the gradient by running the function of a matrix on a 4x larger input.
It is exactly how the backward for the matrix exponential is implemented in PyTorch. Maybe you could try this approach? Only forward methods, no need for explicit backward, unless you can do it more efficiently (so that you could use
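For reference, a sketch of the identity behind this trick: for an analytic matrix function f, applying f to the 2n x 2n block matrix [[A, E], [0, A]] puts the Frechet derivative L_f(A, E) in the upper-right block, and roughly speaking the backward applies the same identity at the (conjugate) transpose of A. The snippet below only illustrates the identity numerically with scipy; it is not PyTorch's actual implementation.

```python
import numpy as np
import scipy.linalg

n = 5
a = np.random.randn(n, n)
a = a @ a.T + n * np.eye(n)               # a well-conditioned SPD matrix
e = np.random.randn(n, n)                 # an arbitrary perturbation direction

# sqrtm of [[A, E], [0, A]] carries the Frechet derivative L_sqrt(A, E)
# in its upper-right n x n block
block = np.block([[a, e], [np.zeros((n, n)), a]])
deriv = np.real(scipy.linalg.sqrtm(block))[:n, n:]

# compare with a finite-difference approximation of the directional derivative
h = 1e-6
fd = (np.real(scipy.linalg.sqrtm(a + h * e)) - np.real(scipy.linalg.sqrtm(a))) / h
print(np.allclose(deriv, fd, atol=1e-4))
```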
On the note regarding
Are there any updates on this issue?
Posted by Rohan Anil on Twitter: https://twitter.com/_arohan_/status/1436378163219558402
https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/matrix_functions.py#L79 - matrix inverse powers (using Newton)
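Not the algorithm in the Shampoo code linked above, but worth noting: the coupled Newton-Schulz iteration that appears earlier in this thread also yields the inverse square root essentially for free, since its Z iterate converges to the inverse root of the normalized matrix. A rough sketch under the assumption of a well-conditioned SPD input; the name `inv_sqrt_newton_schulz` is made up here.

```python
import torch


def inv_sqrt_newton_schulz(a, num_iters=30):
    # coupled Newton-Schulz: Y -> (A/|A|)^{1/2}, Z -> (A/|A|)^{-1/2}
    norm = torch.linalg.matrix_norm(a)
    y = a / norm
    i_mat = torch.eye(a.shape[-1], dtype=a.dtype, device=a.device)
    z = torch.eye(a.shape[-1], dtype=a.dtype, device=a.device)
    for _ in range(num_iters):
        t = 0.5 * (3.0 * i_mat - z @ y)
        y = y @ t
        z = t @ z
    return z / norm.sqrt()   # rescale: A^{-1/2} = (A/|A|)^{-1/2} / sqrt(|A|)


a0 = torch.randn(8, 8, dtype=torch.float64)
a = a0 @ a0.T + 8 * torch.eye(8, dtype=torch.float64)
x = inv_sqrt_newton_schulz(a)
print(torch.allclose(x @ a @ x, torch.eye(8, dtype=torch.float64), atol=1e-8))
```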
Thanks for providing such a function to do the square root. When I checked some math books, I found that being symmetric isn't equivalent to being positive semi-definite. So this function only works well when the matrix is positive semi-definite; if not, the eigenvalues of the matrix could be negative.
Just for reference, the SVD of an SPD matrix is just its eigenvalue decomposition. As such, it's more efficient to use torch.linalg.eigh. Changing that in an implementation given above, an implementation of this function that works for both symmetric and Hermitian matrices may be given as per:

```python
import torch
from torch import linalg


def sqrtmh(A):
    """Compute the square root of a Symmetric or Hermitian positive definite matrix or batch of matrices."""
    L, Q = linalg.eigh(A)
    zero = torch.zeros((), device=L.device, dtype=L.dtype)
    threshold = L.max(-1).values * L.size(-1) * torch.finfo(L.dtype).eps
    L = L.where(L > threshold.unsqueeze(-1), zero)  # zero out small components
    return (Q * L.sqrt().unsqueeze(-2)) @ Q.mH
```

Note that this implementation also gives correct gradients for almost all matrices, but its gradients are incorrect (they actually diverge when they shouldn't) for matrices with repeated eigenvalues. Now, this is not that big of a deal, because the algorithms for

pytorch/torch/csrc/autograd/FunctionsManual.cpp, lines 3165 to 3191 (at 5e6f296)

Having all these things in consideration, and seeing the performance benefits of having a Newton rule, we could consider adding this function to
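A quick illustrative check of sqrtmh on a Hermitian batch (assuming the definition above and a PyTorch recent enough to have `Tensor.mH`):

```python
a = torch.randn(3, 4, 4, dtype=torch.complex128)
a = a @ a.mH + 4 * torch.eye(4)      # batch of Hermitian positive definite matrices
y = sqrtmh(a)
print(torch.allclose(y @ y, a))
```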
Hi, this code doesn't work for me.
You should update to a newer version of PyTorch.
Got it, thanks. And I found this function seems to produce different output compared with scipy.linalg.sqrtm. And btw, the function that @xmodar provided also has the same output as yours. Does the function have a bigger error compared with scipy.linalg.sqrtm?
Or does

Now, I don't think this discussion should go here but rather in the forum or on Stack Overflow. Please feel free to post it there, cross-link it by editing your post if you want, and we can discuss all these points there.
@DRJYYDS It seems that torch does not use pairwise summation for the trace, which can lead to large error accumulation.

pytorch/aten/src/ATen/native/ReduceOps.cpp, lines 1245 to 1275 (at 0276d56)

They should be using a variant of
Also interesting that trace does not seem to support batches :(
If you need batching, a workaround (for square matrices) is
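Presumably the workaround being referred to is something along the lines of taking the batched diagonal and summing it; a small illustrative sketch:

```python
x = torch.randn(7, 5, 5)
batched_trace = torch.diagonal(x, dim1=-2, dim2=-1).sum(-1)   # shape (7,)
print(torch.allclose(batched_trace, torch.stack([m.trace() for m in x])))
```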
One can show that for the matrix square root, the Vector-Jacobian Product is given by
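The formula itself got cut off above; for reference, a common way to derive such a rule starts from the differential relation: if Y = A^{1/2}, then dA = dY Y + Y dY, so the directional derivative in a symmetric direction solves a Sylvester equation. A rough numerical check of that relation using the helpers defined earlier in this thread (`symsqrt` and `special_sylvester`; note that `special_sylvester` uses the deprecated `symeig`):

```python
a0 = torch.randn(6, 6).double()
a = a0 @ a0.t() + 6 * torch.eye(6).double()   # SPD matrix
e = torch.randn(6, 6).double()
e = (e + e.t()) / 2                           # symmetric perturbation direction

y = symsqrt(a)
dy = special_sylvester(y, e)                  # solves Y dY + dY Y = E

h = 1e-6
fd = (symsqrt(a + h * e) - y) / h             # finite-difference directional derivative
print(torch.allclose(dy, fd, atol=1e-4))
```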
Anyone else still looking: I adapted the sqrt_newton_schulz from matrix-sqrt to include the early stopping for divergence used in Tensorflow's implementation. On my test set of 1000 batches, it outperformed

I'm sure this would vary for different test sets, since it is a variable-length loop and it probably has cases I missed. Hope it helps anyone out who needs it!

```python
import torch


def sqrt_newton_schulz(A, numIters=200):
    """Newton-Schulz iterations method to get matrix square root.

    Page 231, Eq 2.6b
    http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.8799&rep=rep1&type=pdf

    Args:
        A: the symmetric PSD matrix whose matrix square root is to be computed
        numIters: Maximum number of iterations.

    Returns:
        A^0.5

    Tensorflow Source:
        https://github.com/tensorflow/tensorflow/blob/df3a3375941b9e920667acfe72fb4c33a8f45503/tensorflow/contrib/opt/python/training/matrix_functions.py#L26C1-L73C42
    Torch Source:
        https://github.com/msubhransu/matrix-sqrt/blob/cc2289a3ed7042b8dbacd53ce8a34da1f814ed2f/matrix_sqrt.py#L74
    """
    normA = torch.linalg.matrix_norm(A, keepdim=True)
    err = normA + 1.0
    I = torch.eye(*A.shape[-2:], dtype=A.dtype, device=A.device)
    Z = torch.eye(*A.shape[-2:], dtype=A.dtype, device=A.device).expand_as(A)
    Y = A / normA
    for i in range(numIters):
        T = 0.5 * (3.0 * I - Z.bmm(Y))
        Y_new = Y.bmm(T)
        Z_new = T.bmm(Z)
        # This method requires that we check for divergence every step.
        # Compute the error in approximation.
        mat_a_approx = torch.bmm(Y_new, Y_new) * normA
        residual = A - mat_a_approx
        current_err = torch.linalg.matrix_norm(residual, keepdim=True) / normA
        if torch.all(current_err > err):
            break
        err = current_err
        Y = Y_new
        Z = Z_new
    sA = Y * torch.sqrt(normA)
    return sA
```
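A quick illustrative check on a small SPD batch (assuming the definition above; the function expects a 3-D batched input):

```python
a = torch.randn(4, 16, 16).double()
a = a @ a.transpose(-1, -2) + torch.eye(16).double()   # batch of SPD matrices
y = sqrt_newton_schulz(a)
print(torch.allclose(y @ y.transpose(-1, -2), a, atol=1e-6))
```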
Still no sqrtm in 2024?
We don't have anyone currently implementing new features in
Over at SciPy, we are getting ready to switch to a C rewrite of sqrtm. It is straightforward C code (to my eyes, that is, and I'm no C expert), so maybe it is also interesting for you folks to do your own magic spin on it.
This is needed when incorporating curvature information into optimization.

There's a PyTorch implementation of a symmetric matrix square root op here, but they use PyTorch for the backward pass only and use scipy for the forward pass.

In the meantime, see #25481 (comment) for an implementation dispatching to eigh.

cc @vincentqb @vishwakftw @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @ssnl