For long-term maintenance simplicity, consider replacing the custom CUDA kernels redrock.zscan.batch_dot_product_3d3d and batch_dot_product_3d2d with einsum magic, as suggested by @dmargala:
For example, batched A.T.dot(A) and A.T.dot(b) would be:
cp.einsum("...ji,...jk", A, A)
cp.einsum("...ji,...j", A, b)
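A quick way to convince yourself that these subscripts compute A.T.dot(A) and A.T.dot(b) per batch element is to compare them with explicit batched matrix products. The sketch below uses NumPy as a CPU stand-in, assuming cupy.einsum accepts the same subscript strings (CuPy's einsum is designed to mirror NumPy's); the array sizes are arbitrary illustration values.

```python
import numpy as np  # CPU stand-in; swap for cupy to run the same code on the GPU

rng = np.random.default_rng(0)
nbatch, nrows, ncols = 4, 50, 3
A = rng.standard_normal((nbatch, nrows, ncols))  # batch of (nrows, ncols) matrices
b = rng.standard_normal((nbatch, nrows))         # batch of length-nrows vectors

# Implicit-mode einsum: the repeated index j is summed over, and the remaining
# indices are emitted in alphabetical order after the batch ("...") axes.
AtA = np.einsum("...ji,...jk", A, A)  # shape (nbatch, ncols, ncols)
Atb = np.einsum("...ji,...j", A, b)   # shape (nbatch, ncols)

# The same results written as explicit batched matrix products.
At = np.swapaxes(A, -1, -2)              # per-batch transpose
AtA_ref = At @ A
Atb_ref = (At @ b[..., None])[..., 0]    # treat b as a column, then drop that axis
```

Spelling out the output with "->" (e.g. "...ji,...jk->...ik") would make the intended layout explicit, which may be easier to read in the long run.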
These aren't drop-in replacements for the call signature of batch_dot_product_3d3d, but I think A.T.dot(A) is what we are using it for. Profile them against the current implementation and also check for correctness.
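One possible shape for that comparison is to check agreement first and only then time both paths. In this sketch, reference_ata is a hypothetical CPU stand-in for the existing kernel (a plain per-matrix loop), not the actual CUDA implementation, and the sizes are made up for illustration.

```python
import time

import numpy as np


def reference_ata(A):
    """Hypothetical stand-in for batch_dot_product_3d3d used as A.T.dot(A):
    loops over the batch one matrix at a time."""
    return np.stack([a.T.dot(a) for a in A])


def einsum_ata(A):
    """Batched A.T.dot(A) via einsum."""
    return np.einsum("...ji,...jk", A, A)


def best_time(func, *args, n_repeat=5):
    """Return the fastest wall-clock time in seconds over n_repeat calls."""
    best = float("inf")
    for _ in range(n_repeat):
        t0 = time.perf_counter()
        func(*args)
        best = min(best, time.perf_counter() - t0)
    return best


rng = np.random.default_rng(1)
A = rng.standard_normal((200, 100, 10))

# Correctness first: both paths must agree to floating-point tolerance.
assert np.allclose(einsum_ata(A), reference_ata(A))

timings = {name: best_time(f, A)
           for name, f in [("reference", reference_ata), ("einsum", einsum_ata)]}
```

On the GPU, CuPy launches kernels asynchronously, so plain wall-clock timing would need a stream synchronize after each call; cupyx.profiler.benchmark (available in recent CuPy releases) handles that and reports CPU and GPU times separately.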
Also consider moving functions like this out of zscan and into redrock.utils or a separate redrock.linalg (or similarly named) module.
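If the helpers move to their own module, one option is to make them backend-agnostic so the same call works on NumPy and CuPy arrays. This is a hypothetical sketch: the module layout and the names batch_ata, batch_atb and get_array_module are illustrative assumptions, not existing redrock code; cupy.get_array_module is a real CuPy function.

```python
"""Hypothetical sketch of a redrock/linalg.py module (illustrative only)."""
import numpy as np

try:
    import cupy as cp
except ImportError:  # CPU-only environments
    cp = None


def get_array_module(*arrays):
    """Return cupy if any input is a CuPy array, otherwise numpy."""
    if cp is not None:
        return cp.get_array_module(*arrays)
    return np


def batch_ata(A):
    """Batched A.T.dot(A): shape (..., n, m) -> (..., m, m)."""
    xp = get_array_module(A)
    return xp.einsum("...ji,...jk->...ik", A, A)


def batch_atb(A, b):
    """Batched A.T.dot(b): shapes (..., n, m) and (..., n) -> (..., m)."""
    xp = get_array_module(A, b)
    return xp.einsum("...ji,...j->...i", A, b)
```

Dispatching on the input array type keeps a single code path for CPU and GPU, which is the maintenance win the custom kernels lack.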