[feature request] Add matrix functions #9983
Comments
Matrix power can be easily implemented in PyTorch by following the |
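In the meantime, a minimal sketch of what such an implementation could look like, using nothing but repeated torch.matmul (naive, n - 1 multiplications; the function name is just for illustration):

```python
import torch

def naive_matrix_power(A, n):
    # Naive integer matrix power: multiply A into the identity n times.
    # Assumes A is a square 2D tensor and n is a non-negative integer.
    result = torch.eye(A.size(0), dtype=A.dtype, device=A.device)
    for _ in range(n):
        result = result @ A
    return result

A = torch.rand(3, 3, dtype=torch.float64)
assert torch.allclose(naive_matrix_power(A, 3), A @ A @ A)
```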
Are these worth porting into ATen? If so, I can take it up. |
Yes, please! |
A generalization based on SVD (https://en.wikipedia.org/wiki/Matrix_function) may also be useful (for simple graph signal processing): diagonalize the matrix somehow (svd by default, maybe some optimized svd for symmetric matrices), apply the user-supplied function to eigenvalues, re-multiply the factors back. For diagonalizable matrices, this way of doing matrix power may also be faster (for large powers). This would also enable matrix square root, though for matrix square root there seem to be specialized approx algorithms based on newton iteration. |
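For symmetric (Hermitian) matrices, the idea above is only a few lines with the current torch.linalg API; a hedged sketch (the function name and the choice of torch.linalg.eigh are mine, and gradients through eigh degrade when eigenvalues are nearly repeated):

```python
import torch

def spectral_matrix_function(A, f):
    # Apply a scalar function f to a symmetric/Hermitian matrix via its
    # eigendecomposition: f(A) = V diag(f(lambda)) V^H. Batches are supported.
    lam, V = torch.linalg.eigh(A)
    return V @ torch.diag_embed(f(lam).to(V.dtype)) @ V.mH

# Usage sketch: exponential, logarithm and square root of an SPD matrix.
B = torch.rand(4, 4, dtype=torch.float64)
A = B @ B.mT + torch.eye(4, dtype=torch.float64)   # symmetric positive definite
expA = spectral_matrix_function(A, torch.exp)
logA = spectral_matrix_function(A, torch.log)
sqrtA = spectral_matrix_function(A, torch.sqrt)
assert torch.allclose(sqrtA @ sqrtA, A, atol=1e-10)
```

This also covers non-integer powers, e.g. `spectral_matrix_function(A, lambda x: x ** 0.5)`, as long as the matrix stays symmetric positive definite.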
@zou3519 @fmassa I went through the SciPy implementation for
|
Upvote for this, matrix exponential |
The approach by @vadimkantorov seems to make the most sense IMO because there's already a (differentiable) PyTorch function for SVD, and after that the exponentiation is just an elementwise exp on the diagonal.
|
Already implemented |
OK great to hear that @ssnl! Can you point out the function (also for reference to anyone else that happens to stumble on this thread)? Can't find it in the docs... |
This issue is partially resolved: |
@harpone here it is for reference https://pytorch.org/docs/master/torch.html?highlight=matrix_power#torch.matrix_power |
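A quick usage sketch of that function for anyone skimming (float64 just to keep the comparison tight):

```python
import torch

A = torch.rand(3, 3, dtype=torch.float64)
# Integer matrix power: A @ A @ A computed for you.
assert torch.allclose(torch.matrix_power(A, 3), A @ A @ A)
```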
Sorry about that. I updated the issue title and desc to reflect the current status :) |
@ssnl It would be nice if matrix_power also supported non-integer powers like 0.5 and other cases. It might require a full SVD approach, though, or some custom Newton-step approximation (I saw one in a paper on fast matrix square root estimation that avoids doing an SVD). |
@vadimkantorov Yep that would be nice. Do you still have a link to the paper? |
@ssnl For Newton-step for matrix square root: https://arxiv.org/abs/1707.06772 |
Since this involves a paper, maybe this can go into |
I think a larger feature to decide is about supporting non-integer powers in matrix_power. (The paper is just about a particular faster matrix square root approximation scheme - I agree, it can go to contrib.) |
Could the developers please tell us when we should expect implementations of the expm and logm functions in PyTorch? |
This might be helpful and CC @SkafteNicki |
Would this allow back-prop with functions involving matrix_expo? |
That guy seems to have implemented this as well. I tried to build his module but got stuck with linkage issues. The pure pytorch example was working correctly. |
@ferrine I never got the cuda implementation to work, because I got stuck with some linkage to cuSOLVER. The pure pytorch example works fine and can back-prop gradients. The cuda implementation was just to get a faster implementation, since I use matrix exponentials quite a bit myself. |
I've managed to implement (with a simple test) matrix exponential using torch script here: |
Just in case you need a taste for the limitations of various matrix exponential algorithms: https://www.cs.cornell.edu/cv/ResearchPDF/19ways+.pdf |
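For illustration only, here is a minimal sketch of the simplest family of methods from that survey, a truncated Taylor series with scaling and squaring; it has no error control, so treat it as a toy, not as the algorithm PyTorch would ship:

```python
import torch

def expm_taylor(A, taylor_order=18):
    # 1) Scale A by a power of two so the truncated series converges quickly.
    norm = A.abs().sum(dim=-1).max()  # crude bound on the infinity norm
    s = int(torch.clamp(torch.ceil(torch.log2(norm)), min=0).item()) if norm > 0 else 0
    A_scaled = A / (2 ** s)
    # 2) Truncated Taylor series: exp(X) ~ sum_{k=0..order} X^k / k!
    eye = torch.eye(A.size(-1), dtype=A.dtype, device=A.device)
    result, term = eye.clone(), eye.clone()
    for k in range(1, taylor_order + 1):
        term = term @ A_scaled / k
        result = result + term
    # 3) Undo the scaling: exp(A) = exp(A / 2^s) squared s times.
    for _ in range(s):
        result = result @ result
    return result

# Sanity check against the eigendecomposition of a symmetric matrix.
B = torch.rand(4, 4, dtype=torch.float64)
A = (B + B.T) / 2
lam, V = torch.linalg.eigh(A)
assert torch.allclose(expm_taylor(A), V @ torch.diag(lam.exp()) @ V.T, atol=1e-10)
```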
Are there any updates on this? I'm looking forward to the matrix log! |
Sorry, no update on matrix log or matrix sqrt. We are planning to review better SciPy linear algebra compatibility in the near future, however. |
@nikitaved has done work on matrix sqrt, which should be unblocked once his |
also looking forward to matrix sqrt! thanks for sharing an update on this |
The issue I see with the matrix sqrt is its implementation for CUDA: it is not so straightforward to translate SciPy's implementation to the GPU (no Schur decomposition, data dependencies in loops). It is possible, however, to implement an iterative algorithm either only for the GPU, or for both the CPU and CUDA. If we agree on iterative algorithms, then I could start working on it. |
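To make the iterative option concrete, here is a sketch in the spirit of the coupled Newton-Schulz iteration from the paper linked earlier; it only needs matmuls, so it is GPU- and batch-friendly, but it converges only for matrices whose (rescaled) spectrum lies in the convergence region, e.g. symmetric positive definite ones. The function name and iteration count are arbitrary choices for the sketch:

```python
import torch

def sqrtm_newton_schulz(A, num_iters=20):
    # Coupled Newton-Schulz iteration: Y -> sqrt(A/||A||_F), Z -> its inverse.
    n = A.size(-1)
    norm = torch.linalg.matrix_norm(A, ord='fro', keepdim=True)
    Y = A / norm
    I = torch.eye(n, dtype=A.dtype, device=A.device)
    Z = I.expand_as(A).clone()
    for _ in range(num_iters):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y, Z = Y @ T, T @ Z
    return Y * norm.sqrt()  # undo the initial rescaling

# Usage sketch on a batch of symmetric positive definite matrices.
B = torch.rand(8, 5, 5, dtype=torch.float64)
A = B @ B.mT + 5 * torch.eye(5, dtype=torch.float64)
S = sqrtm_newton_schulz(A)
assert torch.allclose(S @ S, A, atol=1e-8)
```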
Any idea when matrix log will be available? |
As @nikitaved mentioned above, it is particularly tricky to implement on the GPU. In the meantime, here is a differentiable wrapper around scipy.linalg.logm:

```python
import scipy.linalg
import torch


def adjoint(A, E, f):
    # Gradient of a matrix function f via the block-matrix trick: the upper-right
    # block of f([[A^H, E], [0, A^H]]) is the Fréchet derivative of f at A^H
    # applied to E, which is exactly what the backward pass needs.
    A_H = A.mH.to(E.dtype)
    n = A.size(0)
    M = torch.zeros(2 * n, 2 * n, dtype=E.dtype, device=E.device)
    M[:n, :n] = A_H
    M[n:, n:] = A_H
    M[:n, n:] = E
    return f(M)[:n, n:].to(A.dtype)


def logm_scipy(A):
    # Round-trip through SciPy on the CPU, then move back to A's device.
    return torch.from_numpy(scipy.linalg.logm(A.cpu(), disp=False)[0]).to(A.device)


class Logm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        assert A.ndim == 2 and A.size(0) == A.size(1)  # Square matrix
        assert A.dtype in (torch.float32, torch.float64, torch.complex64, torch.complex128)
        ctx.save_for_backward(A)
        return logm_scipy(A)

    @staticmethod
    def backward(ctx, G):
        A, = ctx.saved_tensors
        return adjoint(A, G, logm_scipy)


logm = Logm.apply

A = torch.rand(3, 3, dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(logm, A)
A = torch.rand(3, 3, dtype=torch.complex128, requires_grad=True)
torch.autograd.gradcheck(logm, A)
```

It does not support batches, as scipy.linalg.logm only works on one square matrix at a time. If you replace the call to scipy.linalg.logm with another SciPy matrix function, the same wrapper gives you that function with the correct gradients. |
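If batching is needed despite the SciPy round trip, a simple (and admittedly slow) workaround is to loop over the batch dimension in Python; a sketch assuming the `logm` defined above, with matrices kept close to the identity so that a real logarithm exists:

```python
import torch

def batched_logm(A):
    # Flatten the batch dimensions, apply the single-matrix logm to each slice,
    # and stack the results back. Gradients flow through each slice.
    flat = A.reshape(-1, A.size(-2), A.size(-1))
    return torch.stack([logm(M) for M in flat]).reshape(A.shape)

A = (torch.eye(3, dtype=torch.float64)
     + 0.1 * torch.rand(4, 3, 3, dtype=torch.float64)).requires_grad_()
torch.autograd.gradcheck(batched_logm, A)
```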
Any updates on the matrix square root? |
Not much yet, for the same reasons as for the matrix logarithm. In the meantime, the same SciPy-based trick works:

```python
import scipy.linalg
import torch


def adjoint(A, E, f):
    # Same block-matrix trick as for the matrix logarithm above.
    A_H = A.T.conj().to(E.dtype)
    n = A.size(0)
    M = torch.zeros(2 * n, 2 * n, dtype=E.dtype, device=E.device)
    M[:n, :n] = A_H
    M[n:, n:] = A_H
    M[:n, n:] = E
    return f(M)[:n, n:].to(A.dtype)


def sqrtm_scipy(A):
    return torch.from_numpy(scipy.linalg.sqrtm(A.cpu(), disp=False)[0]).to(A.device)


class Sqrtm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        assert A.ndim == 2 and A.size(0) == A.size(1)  # Square matrix
        assert A.dtype in (torch.float32, torch.float64, torch.complex64, torch.complex128)
        ctx.save_for_backward(A)
        return sqrtm_scipy(A)

    @staticmethod
    def backward(ctx, G):
        A, = ctx.saved_tensors
        return adjoint(A, G, sqrtm_scipy)


sqrtm = Sqrtm.apply

A = torch.rand(3, 3, dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(sqrtm, A)
A = torch.rand(3, 3, dtype=torch.complex128, requires_grad=True)
torch.autograd.gradcheck(sqrtm, A)
```
|
@nikitaved, @lezcano any progress with implementing |
Those issues are not fixed in TF. TF doesn't provide any better functionality or substantial performance improvements than
|
Here's a simple implementation of the matrix logarithm that does not use SciPy. Importantly, it can run on devices other than the CPU and can be backpropagated through, but it assumes that the matrix is diagonalizable. Hopefully this helps:

```python
import numpy as np
import torch


def torch_logm(A):
    # Diagonalize A, take the elementwise log of the eigenvalues, multiply back.
    lam, V = torch.linalg.eig(A)
    V_inv = torch.inverse(V).to(torch.complex128)
    V = V.to(torch.complex128)
    if not torch.allclose(A.to(torch.complex128), V @ torch.diag(lam).to(torch.complex128) @ V_inv):
        raise ValueError("Matrix is not diagonalizable, cannot compute matrix logarithm!")
    log_A_prime = torch.diag(lam.log())
    return V @ log_A_prime @ V_inv
```

Quick example usage:

```python
>>> A = torch.tensor(np.random.rand(3, 3), requires_grad=True).to('cuda')
>>> A.retain_grad()
>>> torch_logm(A).sum().backward()
>>> A.grad
# tensor([[0.4875, 1.0091, 0.7602],
#         [0.1207, 0.2930, 1.2993],
#         [0.6632, 0.6987, 0.3534]], device='cuda:0', dtype=torch.float64)
```
|
The issue with this implementation is that the computation is very badly conditioned, as there is no stable algorithm to compute the eigenvalues of an arbitrary matrix. This particular implementation also has a number of other issues: for example, it computes the inverse of the eigenvector matrix explicitly, which is also problematic. Even more, it recomputes the eigenvalues a second time. If you want a stable and numerically correct implementation of this function, albeit with some limitations due to the use of SciPy, see #9983 (comment). Feel free to have a peek at the SciPy implementation of this function to see why it is so tricky to implement, and even more, why it is so tricky to implement on a GPU (it has lots of data-dependent parts). |
Thanks for the quick response! I've edited the above to remove the extra computation. Also, why is computing the inverse of the eigenvectors explicitly a problem? How would you go about not doing so? |
These questions related to classical linear algebra may be better suited for the forum: https://discuss.pytorch.org/ Feel free to ping me there and I may be able to help. |
Hi there. I implemented an approximate version of the matrix logarithm based on the discussion above (supports GPU, batches, and complex dtypes). It works well in my coursework (density matrices, eigenvalues between 0 and 1, quantum stuff). You may try it if you need one.
(I am not sure whether it works in general or how accurate it is. Also, its performance may not be good.) Edit: sorry for the interruption, please ignore this; it is not a good implementation (in both accuracy and performance). |
For constant integer exponents, matrix power may also be implemented or lowered as a product of power-of-two-exponent matrix powers (corresponding to the binary representation of the exponent) #84569 (comment)
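A small Python sketch of that lowering, i.e. exponentiation by squaring (about log2(n) multiplications instead of n - 1; this is just the idea, not the ATen code referenced below):

```python
import torch

def matrix_power_by_squaring(A, n):
    # Multiply in A^(2^k) for every set bit of the exponent n.
    assert n >= 0
    result = torch.eye(A.size(-1), dtype=A.dtype, device=A.device)
    base = A
    while n > 0:
        if n & 1:
            result = result @ base
        n >>= 1
        if n > 0:
            base = base @ base
    return result

A = torch.rand(3, 3, dtype=torch.float64)
assert torch.allclose(matrix_power_by_squaring(A, 5), A @ A @ A @ A @ A)
```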
That's currently available in |
Then it's interesting whether it uses https://en.wikipedia.org/wiki/Exponentiation_by_squaring and when it makes sense for lowering (the used memory would be higher)... |
It is implemented this way, yes. Any other way would be just too slow. See pytorch/aten/src/ATen/native/LinearAlgebra.cpp, lines 655 to 667 (at 777ac63).
|
Probably worth mentioning this in the docs explicitly. I also wonder if it's supported for various scripting / onnx export. At least, does it lower x.matrix_power(2) to x @ x? |
Can I BUMP this with another request for torch.matrix_log? |
Add matrix power as implemented by numpy.linalg.matrix_power, matrix exponential as implemented by scipy.linalg.expm, and matrix logarithm as implemented by scipy.linalg.logm.
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk @xwang233 @lezcano @rgommers