diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 26c2b28..199cc38 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,25 +18,17 @@ jobs: python-version: [3.8, '3.11'] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - - name: Install flake8 - run: | - pip install flake8 - - name: Lint with flake8 - run: | - flake8 . --count --show-source --statistics - name: Install dependencies env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | python -m pip install --upgrade pip pip install wheel - pip install jax jaxlib - pip install torch pip install . - name: Install pytest run: | diff --git a/IDEAS b/IDEAS new file mode 100644 index 0000000..319c896 --- /dev/null +++ b/IDEAS @@ -0,0 +1,10 @@ + +add conjugate class + +add tensor product class ? + + + +implementer: + - symmetric_tensor_power + - anti_symmetric_tensor_power \ No newline at end of file diff --git a/lie_nn/__init__.py b/lie_nn/__init__.py index e6de6fc..6f5aec0 100644 --- a/lie_nn/__init__.py +++ b/lie_nn/__init__.py @@ -1,53 +1,69 @@ __version__ = "0.0.0" -from ._src.rep import Rep, GenericRep, check_representation_triplet -from ._src.irrep import TabulatedIrrep -from ._src.reduced_rep import MulIrrep, ReducedRep +from ._src.rep import ( + Rep, + GenericRep, + Irrep, + TabulatedIrrep, + MulRep, + SumRep, + QRep, + ConjRep, + ReducedRep, +) + +from lie_nn import utils as utils + from ._src.change_basis import change_basis -from ._src.reduce import reduce from ._src.change_algebra import change_algebra -from ._src.tensor_product import tensor_product, tensor_power +from ._src.multiply import multiply from ._src.direct_sum import direct_sum -from ._src.reduced_tensor_product import ( - reduced_tensor_product_basis, # TODO: find a better API - reduced_symmetric_tensor_product_basis, -) -from ._src.clebsch_gordan import clebsch_gordan -from ._src.infer_change_of_basis import infer_change_of_basis from ._src.conjugate import conjugate + +from ._src.infer_change_of_basis import infer_change_of_basis +from ._src.properties import is_unitary, is_irreducible, are_isomorphic + +from ._src.reduce import reduce +from ._src.tensor_product import tensor_product, tensor_power +from ._src.clebsch_gordan import clebsch_gordan from ._src.real import make_explicitly_real, is_real -from ._src.properties import is_unitary from ._src.group_product import group_product -from ._src.is_irreducible import is_irreducible +from ._src.symmetric_tensor_power import symmetric_tensor_power from lie_nn import irreps as irreps -from lie_nn import util as util from lie_nn import finite as finite +from lie_nn import test as test + __all__ = [ "Rep", "GenericRep", - "check_representation_triplet", + "Irrep", "TabulatedIrrep", - "MulIrrep", + "MulRep", + "SumRep", + "QRep", + "ConjRep", "ReducedRep", + "utils", "change_basis", - "reduce", "change_algebra", + "multiply", + "direct_sum", + "conjugate", + "infer_change_of_basis", + "is_unitary", + "is_irreducible", + "are_isomorphic", + "reduce", "tensor_product", "tensor_power", - "direct_sum", - "reduced_tensor_product_basis", - "reduced_symmetric_tensor_product_basis", "clebsch_gordan", - "infer_change_of_basis", - "conjugate", "make_explicitly_real", "is_real", - "is_unitary", "group_product", - "is_irreducible", + "symmetric_tensor_power", "irreps", - "util", "finite", + "test", ] diff --git a/lie_nn/_src/change_algebra.py b/lie_nn/_src/change_algebra.py index 14eb341..c876176 100644 --- a/lie_nn/_src/change_algebra.py +++ b/lie_nn/_src/change_algebra.py @@ -1,9 +1,11 @@ import numpy as np +from multimethod import multimethod -from .rep import GenericRep, Rep +import lie_nn as lie -def change_algebra(rep: Rep, Q: np.ndarray) -> GenericRep: +@multimethod +def change_algebra(rep: lie.Rep, Q: np.ndarray) -> lie.GenericRep: """Apply change of basis to algebra. .. math:: @@ -16,8 +18,23 @@ def change_algebra(rep: Rep, Q: np.ndarray) -> GenericRep: """ iQ = np.linalg.pinv(Q) - return GenericRep( + return lie.GenericRep( A=np.einsum("ia,jb,abc,ck->ijk", Q, Q, rep.A, iQ), X=np.einsum("ia,auv->iuv", Q, rep.X), H=rep.H, ) + + +@multimethod +def change_algebra(rep: lie.QRep, Q: np.ndarray) -> lie.Rep: # noqa: F811 + return lie.change_basis(rep.Q, change_algebra(rep.rep, Q)) + + +@multimethod +def change_algebra(rep: lie.SumRep, Q: np.ndarray) -> lie.Rep: # noqa: F811 + return lie.direct_sum(*[change_algebra(subrep, Q) for subrep in rep.reps]) + + +@multimethod +def change_algebra(rep: lie.MulRep, Q: np.ndarray) -> lie.Rep: # noqa: F811 + return lie.multiply(rep.mul, change_algebra(rep.rep, Q)) diff --git a/lie_nn/_src/change_basis.py b/lie_nn/_src/change_basis.py index d8bb15b..62938db 100644 --- a/lie_nn/_src/change_basis.py +++ b/lie_nn/_src/change_basis.py @@ -1,13 +1,11 @@ import numpy as np -from multipledispatch import dispatch +from multimethod import multimethod -from .irrep import TabulatedIrrep -from .reduced_rep import MulIrrep, ReducedRep -from .rep import GenericRep, Rep +import lie_nn as lie -@dispatch(Rep, object) -def change_basis(rep: Rep, Q: np.ndarray) -> GenericRep: +@multimethod +def change_basis(Q: np.ndarray, rep: lie.Rep) -> lie.QRep: """Apply change of basis to generators. .. math:: @@ -19,29 +17,17 @@ def change_basis(rep: Rep, Q: np.ndarray) -> GenericRep: v' = Q v """ - iQ = np.linalg.pinv(Q) - return GenericRep( - A=rep.algebra(), - X=Q @ rep.continuous_generators() @ iQ, - H=Q @ rep.discrete_generators() @ iQ, - ) + assert Q.shape == (rep.dim, rep.dim), (Q.shape, rep.dim) + if np.allclose(Q.imag, 0.0, atol=1e-10): + Q = Q.real -@dispatch(ReducedRep, object) -def change_basis(rep: ReducedRep, Q: np.ndarray) -> ReducedRep: # noqa: F811 - Q = Q if rep.Q is None else Q @ rep.Q - return ReducedRep(A=rep.A, irreps=rep.irreps, Q=Q) + if np.allclose(Q, np.eye(rep.dim), atol=1e-10): + return rep + return lie.QRep(Q, rep, force=True) -@dispatch(MulIrrep, object) -def change_basis(rep: MulIrrep, Q: np.ndarray) -> ReducedRep: # noqa: F811 - return ReducedRep( - A=rep.algebra(), - irreps=(rep,), - Q=Q, - ) - -@dispatch(TabulatedIrrep, object) -def change_basis(rep: TabulatedIrrep, Q: np.ndarray) -> ReducedRep: # noqa: F811 - return change_basis(MulIrrep(mul=1, rep=rep), Q) +@multimethod +def change_basis(Q: np.ndarray, rep: lie.QRep) -> lie.Rep: # noqa: F811 + return change_basis(Q @ rep.Q, rep.rep) diff --git a/lie_nn/_src/clebsch_gordan.py b/lie_nn/_src/clebsch_gordan.py index 4879b0b..6b8c95f 100644 --- a/lie_nn/_src/clebsch_gordan.py +++ b/lie_nn/_src/clebsch_gordan.py @@ -1,14 +1,14 @@ import numpy as np -from multipledispatch import dispatch +from multimethod import multimethod -from .infer_change_of_basis import infer_change_of_basis -from .irrep import TabulatedIrrep -from .rep import Rep -from .tensor_product import tensor_product +import lie_nn as lie -@dispatch(Rep, Rep, Rep) -def clebsch_gordan(rep1: Rep, rep2: Rep, rep3: Rep, *, round_fn=lambda x: x) -> np.ndarray: + +@multimethod +def clebsch_gordan( + rep1: lie.Rep, rep2: lie.Rep, rep3: lie.Rep, *, round_fn=lambda x: x +) -> np.ndarray: r"""Computes the Clebsch-Gordan coefficient of the triplet (rep1, rep2, rep3). Args: @@ -20,15 +20,19 @@ def clebsch_gordan(rep1: Rep, rep2: Rep, rep3: Rep, *, round_fn=lambda x: x) -> The Clebsch-Gordan coefficient of the triplet (rep1, rep2, rep3). It is an array of shape ``(number_of_paths, rep1.dim, rep2.dim, rep3.dim)``. """ - tp = tensor_product(rep1, rep2) - cg = infer_change_of_basis(tp, rep3, round_fn=round_fn) + tp = lie.tensor_product(rep1, rep2) + cg = lie.infer_change_of_basis(tp, rep3, round_fn=round_fn) cg = cg.reshape((-1, rep3.dim, rep1.dim, rep2.dim)) cg = np.moveaxis(cg, 1, 3) return cg -@dispatch(TabulatedIrrep, TabulatedIrrep, TabulatedIrrep) +@multimethod def clebsch_gordan( # noqa: F811 - rep1: TabulatedIrrep, rep2: TabulatedIrrep, rep3: TabulatedIrrep, *, round_fn=lambda x: x + rep1: lie.TabulatedIrrep, + rep2: lie.TabulatedIrrep, + rep3: lie.TabulatedIrrep, + *, + round_fn=lambda x: x, ) -> np.ndarray: return rep1.clebsch_gordan(rep1, rep2, rep3) diff --git a/lie_nn/_src/clebsch_gordan_test.py b/lie_nn/_src/clebsch_gordan_test.py deleted file mode 100644 index 4be0742..0000000 --- a/lie_nn/_src/clebsch_gordan_test.py +++ /dev/null @@ -1,29 +0,0 @@ -import numpy as np - -import lie_nn as lie - - -def test_cg_irrep(): - Q = np.array([[0, 1, 1], [2, 2, 1], [-3, 1, 0]]) - rep1 = lie.change_basis(lie.irreps.SU2(2), Q) - - Q = np.array([[1, 2, 0], [2, 1, 0], [0, -1, 1.0]]) - rep2 = lie.change_basis(lie.irreps.SU2(2), Q) - - Q = np.array([[1, 0, 0], [1, 1, -1], [0, 0, 1.0]]) - rep3 = lie.change_basis(lie.irreps.SU2(2), Q) - - lie.check_representation_triplet(rep1, rep2, rep3) - - -def test_cg_generic(): - Q = np.array([[0, 1, 1], [2, 2, 1], [-3, 1, 0]]) - rep1 = lie.change_basis(lie.irreps.SO3(1), Q) - - Q = np.random.randn(5, 5) - rep2 = lie.change_basis(lie.irreps.SO3(2), Q) - - Q = np.array([[1, 0, 0], [1, 1, -1], [0, 0, 1.0]]) - rep3 = lie.change_basis(lie.irreps.SO3(1), Q) - - lie.check_representation_triplet(rep1, rep2, rep3) diff --git a/lie_nn/_src/conjugate.py b/lie_nn/_src/conjugate.py index 317be7b..968c2a3 100644 --- a/lie_nn/_src/conjugate.py +++ b/lie_nn/_src/conjugate.py @@ -1,15 +1,24 @@ import numpy as np -from multipledispatch import dispatch +from multimethod import multimethod -from .rep import GenericRep, Rep +import lie_nn as lie -# TODO(mario): Implement conjugate for Irreps +@multimethod +def conjugate(rep: lie.Rep) -> lie.GenericRep: + return lie.ConjRep(rep, force=True) -@dispatch(Rep) -def conjugate(rep: Rep) -> GenericRep: - return GenericRep( - A=rep.A, - X=np.conjugate(rep.X), - H=np.conjugate(rep.H), - ) + +@multimethod +def conjugate(rep: lie.QRep) -> lie.Rep: # noqa: F811 + return lie.change_basis(np.conjugate(rep.Q), conjugate(rep.rep)) + + +@multimethod +def conjugate(rep: lie.SumRep) -> lie.Rep: # noqa: F811 + return lie.direct_sum(*[conjugate(subrep) for subrep in rep.reps]) + + +@multimethod +def conjugate(rep: lie.MulRep) -> lie.Rep: # noqa: F811 + return lie.multiply(rep.mul, conjugate(rep.rep)) diff --git a/lie_nn/_src/direct_sum.py b/lie_nn/_src/direct_sum.py index 02c3b34..1429842 100644 --- a/lie_nn/_src/direct_sum.py +++ b/lie_nn/_src/direct_sum.py @@ -1,16 +1,72 @@ import numpy as np -from multipledispatch import dispatch +from multimethod import multimethod -from .rep import GenericRep, Rep -from .util import direct_sum as _direct_sum +import lie_nn as lie -@dispatch(Rep, Rep) -def direct_sum(rep1: Rep, rep2: Rep) -> GenericRep: - assert np.allclose(rep1.A, rep2.A) # same lie algebra - assert rep1.H.shape[0] == rep2.H.shape[0] # same discrete dimension - return GenericRep( - A=rep1.A, - X=_direct_sum(rep1.X, rep2.X), - H=_direct_sum(rep1.H, rep2.H), +def direct_sum(*reps) -> lie.Rep: + assert len(reps) > 0 + if len(reps) == 1: + return reps[0] + return _direct_sum(reps[0], direct_sum(*reps[1:])) + + +def _chk(r1, r2): + assert np.allclose(r1.A, r2.A, atol=1e-10) + assert len(r1.H) == len(r2.H) + + +@multimethod +def _direct_sum(rep1: lie.Rep, rep2: lie.Rep) -> lie.SumRep: + _chk(rep1, rep2) + return lie.SumRep((rep1, rep2), force=True) + + +@multimethod +def _direct_sum(sumrep: lie.SumRep, rep: lie.Rep) -> lie.SumRep: # noqa: F811 + _chk(sumrep, rep) + return lie.SumRep(sumrep.reps + (rep,), force=True) + + +@multimethod +def _direct_sum(rep: lie.Rep, sumrep: lie.SumRep) -> lie.SumRep: # noqa: F811 + _chk(rep, sumrep) + return lie.SumRep((rep,) + sumrep.reps, force=True) + + +@multimethod +def _direct_sum(sumrep1: lie.SumRep, sumrep2: lie.SumRep) -> lie.SumRep: # noqa: F811 + _chk(sumrep1, sumrep2) + return lie.SumRep(sumrep1.reps + sumrep2.reps, force=True) + + +@multimethod +def _direct_sum(qrep: lie.QRep, rep: lie.Rep) -> lie.Rep: # noqa: F811 + _chk(qrep, rep) + return lie.change_basis( + lie.utils.direct_sum(qrep.Q, np.eye(rep.dim)), direct_sum(qrep.rep, rep) + ) + + +@multimethod +def _direct_sum(rep: lie.Rep, qrep: lie.QRep) -> lie.Rep: # noqa: F811 + _chk(rep, qrep) + return lie.change_basis( + lie.utils.direct_sum(np.eye(rep.dim), qrep.Q), direct_sum(rep, qrep.rep) + ) + + +@multimethod +def _direct_sum(qrep: lie.QRep, rep: lie.SumRep) -> lie.Rep: # noqa: F811 + _chk(qrep, rep) + return lie.change_basis( + lie.utils.direct_sum(qrep.Q, np.eye(rep.dim)), direct_sum(qrep.rep, rep) + ) + + +@multimethod +def _direct_sum(rep: lie.SumRep, qrep: lie.QRep) -> lie.Rep: # noqa: F811 + _chk(rep, qrep) + return lie.change_basis( + lie.utils.direct_sum(np.eye(rep.dim), qrep.Q), direct_sum(rep, qrep.rep) ) diff --git a/lie_nn/_src/discrete_groups/__init__.py b/lie_nn/_src/finite/__init__.py similarity index 100% rename from lie_nn/_src/discrete_groups/__init__.py rename to lie_nn/_src/finite/__init__.py diff --git a/lie_nn/_src/finite.py b/lie_nn/_src/finite/perm.py similarity index 96% rename from lie_nn/_src/finite.py rename to lie_nn/_src/finite/perm.py index a186847..465d96a 100644 --- a/lie_nn/_src/finite.py +++ b/lie_nn/_src/finite/perm.py @@ -1,12 +1,8 @@ -import numpy as np - -# import math -import lie_nn as lie from typing import Tuple +import numpy as np -def _num_transpositions(n: int): - return n * (n - 1) // 2 +import lie_nn as lie def _permutation_matrix(p: Tuple[int, ...]) -> np.ndarray: @@ -63,7 +59,7 @@ def create_trivial(self) -> lie.GenericRep: return lie.GenericRep(A=self.A, X=np.zeros((0, 1, 1)), H=np.ones((len(self.H), 1, 1))) -class Sn_standard(lie.Rep): +class Sn_standard(lie.Irrep): """Standard representation of S(n) Basis for S(m+1): @@ -118,7 +114,7 @@ def create_trivial(self) -> lie.GenericRep: return lie.GenericRep(A=self.A, X=np.zeros((0, 1, 1)), H=np.ones((len(self.H), 1, 1))) -class Sn_trivial(lie.Rep): +class Sn_trivial(lie.Irrep): def __init__(self, n) -> None: super().__init__() self.n = n diff --git a/lie_nn/_src/discrete_groups/perm.py b/lie_nn/_src/finite/sn.py similarity index 100% rename from lie_nn/_src/discrete_groups/perm.py rename to lie_nn/_src/finite/sn.py diff --git a/lie_nn/_src/group_product.py b/lie_nn/_src/group_product.py index e5ad37e..a9a8a54 100644 --- a/lie_nn/_src/group_product.py +++ b/lie_nn/_src/group_product.py @@ -1,13 +1,16 @@ -# Group direct product - -from dataclasses import dataclass from typing import Iterator import numpy as np -from multipledispatch import dispatch +from multimethod import multimethod + +import lie_nn as lie + -from .irrep import TabulatedIrrep -from .rep import GenericRep, Rep +def group_product(*reps) -> lie.Rep: + assert len(reps) > 0 + if len(reps) == 1: + return reps[0] + return _group_product(reps[0], group_product(*reps[1:])) def _get_dtype(*args): @@ -17,8 +20,8 @@ def _get_dtype(*args): return x.dtype -@dispatch(Rep, Rep) -def group_product(rep1: Rep, rep2: Rep) -> GenericRep: +@multimethod +def _group_product(rep1: lie.Rep, rep2: lie.Rep) -> lie.GenericRep: A1 = rep1.A A2 = rep2.A lie_dim = rep1.lie_dim + rep2.lie_dim @@ -42,18 +45,17 @@ def group_product(rep1: Rep, rep2: Rep) -> GenericRep: H[: H1.shape[0]] = np.einsum("aij,kl->aikjl", H1, I2).reshape(H1.shape[0], dim, dim) H[H1.shape[0] :] = np.einsum("ij,akl->aikjl", I1, H2).reshape(H2.shape[0], dim, dim) - return GenericRep(A=A, X=X, H=H) + return lie.GenericRep(A=A, X=X, H=H) -@dispatch(Rep, Rep, Rep) -def group_product(rep1: Rep, rep2: Rep, rep3: Rep) -> GenericRep: # noqa: F811 - return group_product(group_product(rep1, rep2), rep3) +class TabulatedIrrepProduct(lie.TabulatedIrrep): + rep1: lie.TabulatedIrrep + rep2: lie.TabulatedIrrep - -@dataclass(frozen=True) -class TabulatedIrrepProduct(TabulatedIrrep): - rep1: TabulatedIrrep - rep2: TabulatedIrrep + def __init__(self, rep1, rep2) -> None: + super().__init__() + self.rep1 = rep1 + self.rep2 = rep2 @classmethod def from_string(cls, s: str) -> "TabulatedIrrepProduct": @@ -126,13 +128,8 @@ def algebra(rep: "TabulatedIrrepProduct") -> np.ndarray: return A -@dispatch(TabulatedIrrep, TabulatedIrrep) -def group_product(rep1: TabulatedIrrep, rep2: TabulatedIrrep) -> TabulatedIrrep: # noqa: F811 +@multimethod +def _group_product( # noqa: F811 + rep1: lie.TabulatedIrrep, rep2: lie.TabulatedIrrep +) -> lie.TabulatedIrrep: return TabulatedIrrepProduct(rep1, rep2) - - -@dispatch(TabulatedIrrep, TabulatedIrrep, TabulatedIrrep) -def group_product( # noqa: F811 - rep1: TabulatedIrrep, rep2: TabulatedIrrep, rep3: TabulatedIrrep -) -> TabulatedIrrep: - return TabulatedIrrepProduct(TabulatedIrrepProduct(rep1, rep2), rep3) diff --git a/lie_nn/_src/infer_change_of_basis.py b/lie_nn/_src/infer_change_of_basis.py index 21aecab..6d11ec6 100644 --- a/lie_nn/_src/infer_change_of_basis.py +++ b/lie_nn/_src/infer_change_of_basis.py @@ -1,12 +1,12 @@ import numpy as np +from multimethod import multimethod -from .rep import Rep -from .util import infer_change_of_basis as _infer_change_of_basis -# TODO (mario): Can be specialized for ReducedRep +import lie_nn as lie -def infer_change_of_basis(rep1: Rep, rep2: Rep, round_fn=lambda x: x) -> np.ndarray: +@multimethod +def infer_change_of_basis(rep1: lie.Rep, rep2: lie.Rep, *, round_fn=lambda x: x) -> np.ndarray: r"""Infers the change of basis matrix between two representations. .. math:: @@ -26,21 +26,57 @@ def infer_change_of_basis(rep1: Rep, rep2: Rep, round_fn=lambda x: x) -> np.ndar The change of basis matrix ``Q``. """ # Check the group structure - assert np.allclose(rep1.algebra(), rep2.algebra()) + assert np.allclose(rep1.A, rep2.A) - Y1 = np.concatenate([rep1.continuous_generators(), rep1.discrete_generators()]) - Y2 = np.concatenate([rep2.continuous_generators(), rep2.discrete_generators()]) + Y1 = np.concatenate([rep1.X, rep1.H]) + Y2 = np.concatenate([rep2.X, rep2.H]) - A = _infer_change_of_basis(Y2, Y1, round_fn=round_fn) - np.testing.assert_allclose( - np.einsum("aij,bjk->abik", Y2, A), - np.einsum("bij,ajk->abik", A, Y1), - rtol=1e-8, - atol=1e-8, - ) + A = lie.utils.infer_change_of_basis(Y2, Y1, round_fn=round_fn) + # np.testing.assert_allclose( + # np.einsum("aij,bjk->abik", Y2, A), + # np.einsum("bij,ajk->abik", A, Y1), + # rtol=1e-8, + # atol=1e-8, + # ) assert A.dtype in [np.float64, np.complex128] A = A * np.sqrt(rep2.dim) A = round_fn(A) return A + + +@multimethod +def infer_change_of_basis( # noqa: F811 + rep1: lie.QRep, rep2: lie.Rep, *, round_fn=lambda x: x +) -> np.ndarray: + r""" + Q \rho_1 = \rho_2 Q + (Q q) \rho_1 q^{-1} = \rho_2 Q + """ + inv = np.linalg.pinv(rep1.Q) + return infer_change_of_basis(rep1.rep, rep2, round_fn=round_fn) @ inv + + +@multimethod +def infer_change_of_basis( # noqa: F811 + rep1: lie.Rep, rep2: lie.QRep, *, round_fn=lambda x: x +) -> np.ndarray: + r""" + Q \rho_1 = \rho_2 Q + Q \rho_1 = q \rho_2 (q^{-1} Q) + """ + return rep2.Q @ infer_change_of_basis(rep1, rep2.rep, round_fn=round_fn) + + +@multimethod +def infer_change_of_basis( # noqa: F811 + rep1: lie.QRep, rep2: lie.QRep, *, round_fn=lambda x: x +) -> np.ndarray: + r""" + Q \rho_1 = \rho_2 Q + Q q1 \rho_1 q1^{-1} = q2 \rho_2 q2^{-1} Q + (q2^{-1} Q q1) \rho_1 = \rho_2 q2^{-1} Q q1 + """ + inv1 = np.linalg.pinv(rep1.Q) + return rep2.Q @ infer_change_of_basis(rep1.rep, rep2.rep, round_fn=round_fn) @ inv1 diff --git a/lie_nn/_src/irrep.py b/lie_nn/_src/irrep.py deleted file mode 100644 index d8b43c1..0000000 --- a/lie_nn/_src/irrep.py +++ /dev/null @@ -1,43 +0,0 @@ -from dataclasses import dataclass -from typing import Iterator - - -import numpy as np -from .rep import Rep - - -@dataclass(frozen=True) -class TabulatedIrrep(Rep): - @classmethod - def from_string(cls, string: str) -> "TabulatedIrrep": - raise NotImplementedError - - def __mul__(rep1: "TabulatedIrrep", rep2: "TabulatedIrrep") -> Iterator["TabulatedIrrep"]: - # Selection rule - raise NotImplementedError - - @property - def dim(rep: "TabulatedIrrep") -> int: - raise NotImplementedError - - def __lt__(rep1: "TabulatedIrrep", rep2: "TabulatedIrrep") -> bool: - # This is used for sorting the irreps - raise NotImplementedError - - @classmethod - def iterator(cls) -> Iterator["TabulatedIrrep"]: - # Requirements: - # - the first element must be the trivial representation - # - the elements must be sorted by the __lt__ method - raise NotImplementedError - - @classmethod - def create_trivial(cls) -> "TabulatedIrrep": - return cls.iterator().__next__() - - @classmethod - def clebsch_gordan( - cls, rep1: "TabulatedIrrep", rep2: "TabulatedIrrep", rep3: "TabulatedIrrep" - ) -> np.ndarray: - # return an array of shape ``(number_of_paths, rep1.dim, rep2.dim, rep3.dim)`` - raise NotImplementedError diff --git a/lie_nn/_src/irreps/o3_real.py b/lie_nn/_src/irreps/o3_real.py index 8380e57..f9137a6 100644 --- a/lie_nn/_src/irreps/o3_real.py +++ b/lie_nn/_src/irreps/o3_real.py @@ -4,7 +4,7 @@ import numpy as np -from ..irrep import TabulatedIrrep +from ..rep import TabulatedIrrep from .so3_real import SO3 diff --git a/lie_nn/_src/irreps/sl2c.py b/lie_nn/_src/irreps/sl2c.py index 4220a6e..9c6f4b9 100644 --- a/lie_nn/_src/irreps/sl2c.py +++ b/lie_nn/_src/irreps/sl2c.py @@ -5,8 +5,8 @@ import numpy as np -from ..irrep import TabulatedIrrep -from ..util import permutation_sign, vmap +from ..rep import TabulatedIrrep +from ..utils import permutation_sign, vmap from .su2 import SU2, clebsch_gordanSU2mat diff --git a/lie_nn/_src/irreps/so13.py b/lie_nn/_src/irreps/so13.py index 9b32c92..1a350d9 100644 --- a/lie_nn/_src/irreps/so13.py +++ b/lie_nn/_src/irreps/so13.py @@ -7,7 +7,7 @@ import lie_nn as lie from scipy.linalg import sqrtm -from ..irrep import TabulatedIrrep +from ..rep import TabulatedIrrep from .sl2c import SL2C @@ -39,7 +39,7 @@ def __mul__(rep1: "SO13", rep2: "SO13") -> Iterator["SO13"]: def clebsch_gordan(cls, rep1: "SO13", rep2: "SO13", rep3: "SO13") -> np.ndarray: # Call the generic implementation return lie.clebsch_gordan( - lie.GenericRep.from_rep(rep1), rep2, rep3, round_fn=lie.util.round_to_sqrt_rational + lie.GenericRep.from_rep(rep1), rep2, rep3, round_fn=lie.utils.round_to_sqrt_rational ) @property @@ -69,8 +69,8 @@ def continuous_generators(rep: "SO13") -> np.ndarray: X[3:] *= 1j # Make the generators explicitly real, if possible - S = lie.util.infer_change_of_basis( - np.conjugate(X), X, round_fn=lie.util.round_to_sqrt_rational + S = lie.utils.infer_change_of_basis( + np.conjugate(X), X, round_fn=lie.utils.round_to_sqrt_rational ) * np.sqrt(rep.dim) if S.shape[0] == 0: assert rep.l != rep.k @@ -80,7 +80,7 @@ def continuous_generators(rep: "SO13") -> np.ndarray: W = sqrtm(S[0]) iW = np.linalg.inv(W) X = W @ X @ iW - return lie.util.round_to_sqrt_rational(X.real) + return lie.utils.round_to_sqrt_rational(X.real) def algebra(rep=None) -> np.ndarray: # [X_i, X_j] = A_ijk X_k @@ -88,14 +88,14 @@ def algebra(rep=None) -> np.ndarray: # for generators J_0, J_1, J_2, K_0, K_1, K_2 for i, j, k in itertools.permutations((0, 1, 2)): - algebra[i, j, k] = lie.util.permutation_sign((i, j, k)) # [J_i, J_j] = eps_ijk J_k - algebra[3 + i, 3 + j, k] = -lie.util.permutation_sign( + algebra[i, j, k] = lie.utils.permutation_sign((i, j, k)) # [J_i, J_j] = eps_ijk J_k + algebra[3 + i, 3 + j, k] = -lie.utils.permutation_sign( (i, j, k) ) # [K_i, K_j] = -eps_ijk J_k - algebra[i, 3 + j, 3 + k] = lie.util.permutation_sign( + algebra[i, 3 + j, 3 + k] = lie.utils.permutation_sign( (i, j, k) ) # [J_i, K_j] = eps_ijk K_k - algebra[3 + i, j, 3 + k] = lie.util.permutation_sign( + algebra[3 + i, j, 3 + k] = lie.utils.permutation_sign( (i, j, k) ) # [K_i, J_j] = eps_ijk K_k diff --git a/lie_nn/_src/irreps/so13_test.py b/lie_nn/_src/irreps/so13_test.py deleted file mode 100644 index 7fc7bac..0000000 --- a/lie_nn/_src/irreps/so13_test.py +++ /dev/null @@ -1,8 +0,0 @@ -import lie_nn as lie - - -def test_fourvector(): - vec = lie.irreps.SO13.four_vector() - - vec.check_algebra_vs_generators() - lie.check_representation_triplet(vec, vec, vec) diff --git a/lie_nn/_src/irreps/so3_real.py b/lie_nn/_src/irreps/so3_real.py index 1982bca..628b9db 100644 --- a/lie_nn/_src/irreps/so3_real.py +++ b/lie_nn/_src/irreps/so3_real.py @@ -4,7 +4,7 @@ import numpy as np -from ..irrep import TabulatedIrrep +from ..rep import TabulatedIrrep from .su2 import SU2 diff --git a/lie_nn/_src/irreps/su2.py b/lie_nn/_src/irreps/su2.py index 3d156ae..d44c7af 100644 --- a/lie_nn/_src/irreps/su2.py +++ b/lie_nn/_src/irreps/su2.py @@ -5,7 +5,7 @@ import numpy as np -from ..irrep import TabulatedIrrep +from ..rep import TabulatedIrrep @dataclass(frozen=True) diff --git a/lie_nn/_src/irreps/su2_real.py b/lie_nn/_src/irreps/su2_real.py index c1b4f86..b8d0248 100644 --- a/lie_nn/_src/irreps/su2_real.py +++ b/lie_nn/_src/irreps/su2_real.py @@ -3,9 +3,9 @@ from typing import Iterator import numpy as np -from ..util import is_half_integer, is_integer, round_to_sqrt_rational +from ..utils import is_half_integer, is_integer, round_to_sqrt_rational -from ..irrep import TabulatedIrrep +from ..rep import TabulatedIrrep from .su2 import SU2 from lie_nn import clebsch_gordan @@ -27,6 +27,7 @@ def change_basis_real_to_complex(j: float) -> np.ndarray: raise ValueError(f"j={j} is not an integer") +# TODO: move this class somewhere else, and make it not a TabulatedIrrep @dataclass(frozen=True) class SU2Real(TabulatedIrrep): j: float # j is a half-integer diff --git a/lie_nn/_src/irreps/sun.py b/lie_nn/_src/irreps/sun.py index cd0ace2..8fb9afb 100644 --- a/lie_nn/_src/irreps/sun.py +++ b/lie_nn/_src/irreps/sun.py @@ -7,8 +7,8 @@ import numpy as np -from ..irrep import TabulatedIrrep -from ..util import commutator, nullspace, round_to_sqrt_rational +from ..rep import TabulatedIrrep +from ..utils import commutator, nullspace, round_to_sqrt_rational WEIGHT = Tuple[int, ...] GT_PATTERN = Tuple[WEIGHT, ...] diff --git a/lie_nn/_src/irreps/u1.py b/lie_nn/_src/irreps/u1.py new file mode 100644 index 0000000..849e308 --- /dev/null +++ b/lie_nn/_src/irreps/u1.py @@ -0,0 +1,58 @@ +import itertools +from dataclasses import dataclass +from typing import Iterator + +import numpy as np + +from ..rep import TabulatedIrrep + + +@dataclass(frozen=True) +class U1(TabulatedIrrep): + m: int + + def __post_init__(self): + assert isinstance(self.m, int) + assert self.m >= 0 + + @classmethod + def from_string(cls, string: str) -> "U1": + return cls(m=int(string)) + + def __mul__(rep1: "U1", rep2: "U1") -> Iterator["U1"]: + assert isinstance(rep2, U1) + return [U1(m=rep1.m + rep2.m)] + + @classmethod + def clebsch_gordan(cls, rep1: "U1", rep2: "U1", rep3: "U1") -> np.ndarray: + # return an array of shape ``(number_of_paths, rep1.dim, rep2.dim, rep3.dim)`` + if rep3 in rep1 * rep2: + return np.ones((1, 1, 1, 1)) + else: + return np.zeros((0, 1, 1, 1)) + + @property + def dim(rep: "U1") -> int: + return 1 + + def is_scalar(rep: "U1") -> bool: + """Equivalent to ``j == 0``""" + return rep.m == 0 + + def __lt__(rep1: "U1", rep2: "U1") -> bool: + return rep1.m < rep2.m + + @classmethod + def iterator(cls) -> Iterator["U1"]: + for m in itertools.count(0): + yield U1(m=m) + + def discrete_generators(rep: "U1") -> np.ndarray: + return np.zeros((0, rep.dim, rep.dim)) + + def continuous_generators(rep: "U1") -> np.ndarray: + return np.array([[[1j * rep.m]]]) + + def algebra(rep=None) -> np.ndarray: + # [X_i, X_j] = A_ijk X_k + return np.zeros((1, 1, 1)) diff --git a/lie_nn/_src/irreps/z2.py b/lie_nn/_src/irreps/z2.py index e3c2268..154c2e6 100644 --- a/lie_nn/_src/irreps/z2.py +++ b/lie_nn/_src/irreps/z2.py @@ -3,7 +3,7 @@ import numpy as np -from ..irrep import TabulatedIrrep +from ..rep import TabulatedIrrep @dataclass(frozen=True) diff --git a/lie_nn/_src/is_irreducible.py b/lie_nn/_src/is_irreducible.py deleted file mode 100644 index 43662df..0000000 --- a/lie_nn/_src/is_irreducible.py +++ /dev/null @@ -1,9 +0,0 @@ -import numpy as np - -from .rep import Rep -from .util import is_irreducible as _is_irreducible - - -def is_irreducible(rep: Rep, *, epsilon: float = 1e-10) -> bool: - """Returns True if the representation is irreducible.""" - return _is_irreducible(np.concatenate([rep.X, rep.H], axis=0), epsilon=epsilon) diff --git a/lie_nn/_src/jax_utils.py b/lie_nn/_src/jax_utils.py deleted file mode 100644 index 4e8fee8..0000000 --- a/lie_nn/_src/jax_utils.py +++ /dev/null @@ -1,51 +0,0 @@ -from dataclasses import dataclass -import jax -import jax.numpy as jnp -from .irrep import TabulatedIrrep - - -def static_jax_pytree(cls): - cls = dataclass(frozen=True)(cls) - jax.tree_util.register_pytree_node(cls, lambda x: ((), x), lambda x, _: x) - return cls - - -@jax.jit -def matrix_power(F, n): - upper_limit = 32 - init_carry = n, F, jnp.eye(F.shape[0]) - - def body(carry, _): - # One step of the iteration - n, z, result = carry - new_n, bit = jnp.divmod(n, 2) - - new_result = jax.lax.cond(bit, lambda x: z @ x, lambda x: x, result) - - # No more computation necessary if n = 0 - # Is there a better way to early break rather than just returning something empty? - new_z = jax.lax.cond(new_n, lambda z: z @ z, lambda _: jnp.empty(z.shape), z) - - return (new_n, new_z, new_result), None - - result = jax.lax.cond( - n == 1, - lambda _: F, - lambda _: jax.lax.scan(body, init_carry, None, length=upper_limit)[0][2], - None, - ) - - return result - - -def exp_map( - rep: "TabulatedIrrep", continuous_params: jnp.ndarray, discrete_params: jnp.ndarray -) -> jnp.ndarray: - # return a matrix of shape ``(rep.dim, rep.dim)`` - discrete = jax.vmap(matrix_power)(rep.discrete_generators(), discrete_params) - output = jax.scipy.linalg.expm( - jnp.einsum("a,aij->ij", continuous_params, rep.continuous_generators()) - ) - for x in reversed(discrete): - output = x @ output - return output diff --git a/lie_nn/_src/multiply.py b/lie_nn/_src/multiply.py new file mode 100644 index 0000000..c271074 --- /dev/null +++ b/lie_nn/_src/multiply.py @@ -0,0 +1,37 @@ +from multimethod import multimethod + +import lie_nn as lie + +import numpy as np + + +@multimethod +def multiply(mul: int, rep: lie.Rep) -> lie.Rep: + if mul == 1: + return rep + + return lie.MulRep(mul, rep, force=True) + + +@multimethod +def multiply(mul: int, mulrep: lie.MulRep) -> lie.Rep: # noqa: F811 + return multiply(mul * mulrep.mul, mulrep.rep) + + +@multimethod +def multiply(mul: int, qrep: lie.QRep) -> lie.Rep: # noqa: F811 + return lie.change_basis(lie.utils.direct_sum(*(qrep.Q,) * mul), multiply(mul, qrep.rep)) + + +@multimethod +def multiply(mul: int, sumrep: lie.SumRep) -> lie.Rep: # noqa: F811 + Q = np.zeros((mul, sumrep.dim, mul * sumrep.dim)) + k = 0 + j = 0 + for subrep in sumrep.reps: + for u in range(mul): + Q[u, k : k + subrep.dim, j : j + subrep.dim] = np.eye(subrep.dim) + j += subrep.dim + k += subrep.dim + Q = Q.reshape(mul * sumrep.dim, mul * sumrep.dim) + return lie.change_basis(Q, lie.direct_sum(*[multiply(mul, subrep) for subrep in sumrep.reps])) diff --git a/lie_nn/_src/properties.py b/lie_nn/_src/properties.py index 9e22991..f496d27 100644 --- a/lie_nn/_src/properties.py +++ b/lie_nn/_src/properties.py @@ -1,13 +1,94 @@ import numpy as np +from multimethod import multimethod + import lie_nn as lie +# is_irreducible: + + +@multimethod +def is_irreducible(rep: lie.Rep, *, epsilon: float = 1e-10) -> bool: + """Returns True if the representation is irreducible.""" + return lie.utils.is_irreducible(np.concatenate([rep.X, rep.H], axis=0), epsilon=epsilon) + + +@multimethod +def is_irreducible(rep: lie.ConjRep, *, epsilon: float = 1e-10) -> bool: # noqa: F811 + return is_irreducible(rep.rep, epsilon=epsilon) + + +@multimethod +def is_irreducible(rep: lie.MulRep, *, epsilon: float = 1e-10) -> bool: # noqa: F811 + return is_irreducible(rep.rep, epsilon=epsilon) + + +@multimethod +def is_irreducible(rep: lie.SumRep, *, epsilon: float = 1e-10) -> bool: # noqa: F811 + return all(is_irreducible(subrep, epsilon=epsilon) for subrep in rep.reps) + + +@multimethod +def is_irreducible(rep: lie.QRep, *, epsilon: float = 1e-10) -> bool: # noqa: F811 + return is_irreducible(rep.rep, epsilon=epsilon) + + +@multimethod +def is_irreducible(rep: lie.Irrep, *, epsilon: float = 1e-10) -> bool: # noqa: F811 + return True + + +# is_unitary: + -def is_unitary(rep: lie.Rep) -> bool: +@multimethod +def is_unitary(rep: lie.Rep, *, epsilon: float = 1e-10) -> bool: X = rep.continuous_generators() H = rep.discrete_generators() - H_unit = np.allclose(H @ np.conj(np.transpose(H, (0, 2, 1))), np.eye(rep.dim), atol=1e-13) # exp(X) @ exp(X^H) = 1 # X + X^H = 0 - X_unit = np.allclose(X + np.conj(np.transpose(X, (0, 2, 1))), 0, atol=1e-13) + H_unit = np.allclose(H @ np.conj(np.transpose(H, (0, 2, 1))), np.eye(rep.dim), atol=epsilon) + X_unit = np.allclose(X + np.conj(np.transpose(X, (0, 2, 1))), 0, atol=epsilon) return H_unit and X_unit + + +@multimethod +def is_unitary(rep: lie.SumRep) -> bool: # noqa: F811 + return all(is_unitary(subrep) for subrep in rep.reps) + + +@multimethod +def is_unitary(rep: lie.MulRep) -> bool: # noqa: F811 + return is_unitary(rep.rep) + + +# are_isomorphic: + + +@multimethod +def are_isomorphic(rep1: lie.Rep, rep2: lie.Rep, *, epsilon: float = 1e-10) -> bool: + if isinstance(rep1, lie.MulRep) and isinstance(rep2, lie.MulRep) and rep1.mul == rep2.mul: + return are_isomorphic(rep1.rep, rep2.rep, epsilon=epsilon) + + return lie.utils.are_isomorphic( + np.concatenate([rep1.X, rep1.H], axis=0), + np.concatenate([rep2.X, rep2.H], axis=0), + epsilon=epsilon, + ) + + +@multimethod +def are_isomorphic( # noqa: F811 + rep1: lie.ConjRep, rep2: lie.ConjRep, *, epsilon: float = 1e-10 +) -> bool: + return are_isomorphic(rep1.rep, rep2.rep, epsilon=epsilon) + + +@multimethod +def are_isomorphic(rep1: lie.QRep, rep2: lie.Rep, *, epsilon: float = 1e-10) -> bool: # noqa: F811 + return are_isomorphic(rep1.rep, rep2, epsilon=epsilon) + + +@multimethod +def are_isomorphic(rep1: lie.Rep, rep2: lie.QRep, *, epsilon: float = 1e-10) -> bool: # noqa: F811 + return are_isomorphic(rep1, rep2.rep, epsilon=epsilon) diff --git a/lie_nn/_src/reduce.py b/lie_nn/_src/reduce.py index 09d93e7..7cf1d00 100644 --- a/lie_nn/_src/reduce.py +++ b/lie_nn/_src/reduce.py @@ -1,44 +1,113 @@ import numpy as np -from multipledispatch import dispatch +from multimethod import multimethod -from .infer_change_of_basis import infer_change_of_basis -from .reduced_rep import MulIrrep, ReducedRep -from .rep import GenericRep, Rep -from .util import decompose_rep_into_irreps +import lie_nn as lie -@dispatch(MulIrrep) -def reduce(rep: MulIrrep) -> ReducedRep: - return ReducedRep( - A=rep.algebra(), - irreps=(rep,), - Q=None, - ) +@multimethod +def reduce(rep: lie.ConjRep) -> lie.ReducedRep: # noqa: F811 + red = reduce(rep.rep) + reps = tuple((mul, lie.conjugate(ir)) for mul, ir in red.reps) + return lie.ReducedRep(A=rep.A, num_H=len(rep.H), Q=red.Q.conj(), reps=reps, force=True) + + +@multimethod +def reduce(rep: lie.MulRep) -> lie.ReducedRep: # noqa: F811 + red = reduce(rep.rep) + reps = tuple((rep.mul * mul, ir) for mul, ir in red.reps) + Q = np.concatenate([np.repeat(q, rep.mul, axis=1) for q in red.split_Q()], axis=1) + return lie.ReducedRep(A=rep.A, num_H=len(rep.H), Q=Q, reps=reps, force=True) + + +@multimethod +def reduce(rep: lie.QRep) -> lie.ReducedRep: # noqa: F811 + red = reduce(rep.rep) + return lie.ReducedRep(A=rep.A, num_H=len(rep.H), Q=rep.Q @ red.Q, reps=red.reps, force=True) -@dispatch(ReducedRep) -def reduce(rep: ReducedRep) -> ReducedRep: # noqa: F811 - # TODO if we change ReducedRep into SumRep, then reduce its constituents - return rep +@multimethod +def reduce(rep: lie.SumRep) -> lie.ReducedRep: # noqa: F811 + if rep.dim == 0: + return lie.ReducedRep( + A=rep.A, num_H=len(rep.H), Q=np.eye(rep.dim), reps=((1, rep),), force=True + ) + reds = [reduce(subrep) for subrep in rep.reps] + Q = lie.utils.direct_sum(*[red.Q for red in reds]) + mulirs = sum([red.reps for red in reds], ()) + blocks = [] + i = 0 + for mul, ir in mulirs: + Qi = Q[:, i : i + mul * ir.dim] + blocks.append((mul, ir, Qi)) + i += mul * ir.dim + blocks.sort(key=lambda x: x[1].dim) + merged_blocks = [] + + while len(blocks) > 0: + mul, ir, Qi = blocks.pop(0) + j = len(blocks) - 1 + while j >= 0: + mul2, ir2, Qi2 = blocks[j] + if lie.are_isomorphic(ir, ir2): + mul += mul2 + q = lie.infer_change_of_basis(ir, ir2) # q ir = ir2 q + assert len(q) == 1 + Qi = np.concatenate((Qi, Qi2 @ q[0]), axis=1) + blocks.pop(j) + j -= 1 + merged_blocks.append((mul, ir, Qi)) + + Q = np.concatenate([Qi for _, _, Qi in merged_blocks], axis=1) + mulirs = tuple((mul, ir) for mul, ir, _ in merged_blocks) + return lie.ReducedRep(A=rep.A, num_H=len(rep.H), Q=Q, reps=mulirs, force=True) + + +@multimethod +def reduce(rep: lie.Irrep) -> lie.ReducedRep: # noqa: F811 + return lie.ReducedRep( + A=rep.A, num_H=len(rep.H), Q=np.eye(rep.dim), reps=((1, rep),), force=True + ) -@dispatch(Rep) -def reduce(rep: Rep) -> ReducedRep: # noqa: F811 + +@multimethod +def reduce(rep: lie.Rep) -> lie.ReducedRep: # noqa: F811 r"""Reduce an unknown representation to a reduced form. This operation is slow and should be avoided if possible. """ - Ys = decompose_rep_into_irreps(np.concatenate([rep.X, rep.H])) - Ys = sorted(Ys, key=lambda x: x.shape[1]) - d = rep.lie_dim - Qs = [] - irs = [] - for Y in Ys: - ir = GenericRep(rep.A, Y[:d], Y[d:]) - Q = infer_change_of_basis(ir, rep) - Q = np.einsum("mij->imj", Q).reshape((rep.dim, ir.dim)) - Qs.append(Q) - irs.append(ir) - - Q = np.concatenate(Qs, axis=1) - rep = ReducedRep(rep.A, tuple(MulIrrep(1, ir) for ir in irs), Q) - return rep + + def try_reduce(): + Ys = lie.utils.decompose_rep_into_irreps(np.concatenate([rep.X, rep.H])) + d = rep.lie_dim + Qs = [] + irs = [] + for mul, Y in Ys: + ir = lie.GenericRep(rep.A, Y[:d], Y[d:]) + Q = lie.infer_change_of_basis(ir, rep) + if len(Q) != mul: + return None + + Q = np.einsum("mij->imj", Q).reshape((rep.dim, mul * ir.dim)) + Qs.append(Q) + mul_ir = (mul, ir) + irs.append(mul_ir) + + Q = np.concatenate(Qs, axis=1) + if np.allclose(Q, np.eye(rep.dim), atol=1e-10): + Q = np.eye(rep.dim) + + return lie.ReducedRep(A=rep.A, num_H=len(rep.H), Q=Q, reps=tuple(irs), force=True) + + import time + + t = time.time() + for n in range(100): + red = try_reduce() + if red is not None: + return red + if time.time() - t > 1: + break + + raise ValueError( + f"Could not reduce representation after {time.time() - t} seconds and {n} tries." + ) diff --git a/lie_nn/_src/reduce_test.py b/lie_nn/_src/reduce_test.py deleted file mode 100644 index 8a3ba26..0000000 --- a/lie_nn/_src/reduce_test.py +++ /dev/null @@ -1,12 +0,0 @@ -import numpy as np -import lie_nn as lie - - -def test_reduce(): - rep1 = lie.change_basis(lie.irreps.SU2(2), np.random.randn(3, 3)) - rep2 = lie.tensor_product(rep1, rep1) - rep2 = lie.GenericRep.from_rep(rep2) - - rep3 = lie.reduce(rep2) - - np.testing.assert_allclose(rep2.X, rep3.X, atol=1e-5) diff --git a/lie_nn/_src/reduced_rep.py b/lie_nn/_src/reduced_rep.py deleted file mode 100644 index 98eb03c..0000000 --- a/lie_nn/_src/reduced_rep.py +++ /dev/null @@ -1,140 +0,0 @@ -import dataclasses -from typing import Optional, Tuple, Union, Type - -import numpy as np - -from .rep import Rep -from .irrep import TabulatedIrrep -from .util import direct_sum - - -@dataclasses.dataclass -class MulIrrep(Rep): - mul: int - rep: Rep - - @classmethod - def from_string(cls, string: str, cls_irrep: Type[Rep]) -> "MulIrrep": - if "x" in string: - mul, rep = string.split("x") - else: - mul, rep = 1, string - return cls(mul=int(mul), rep=cls_irrep.from_string(rep)) - - @property - def dim(self) -> int: - return self.mul * self.rep.dim - - def algebra(self) -> np.ndarray: - return self.rep.algebra() - - def continuous_generators(self) -> np.ndarray: - X = self.rep.continuous_generators() - if X.shape[0] == 0: - return np.empty((0, self.dim, self.dim)) - return np.stack([direct_sum(*[x for _ in range(self.mul)]) for x in X], axis=0) - - def discrete_generators(self) -> np.ndarray: - H = self.rep.discrete_generators() - if H.shape[0] == 0: - return np.empty((0, self.dim, self.dim)) - return np.stack([direct_sum(*[x for _ in range(self.mul)]) for x in H], axis=0) - - def create_trivial(self) -> Rep: - return self.rep.create_trivial() - - def __repr__(self) -> str: - return f"{self.mul}x{self.rep}" - - -class ReducedRep(Rep): - r"""Representation of the form - - .. math:: - Q (\osum_i m_i \rho_i ) Q^{-1} - """ - _A: np.ndarray - irreps: Tuple[MulIrrep, ...] - Q: Optional[np.ndarray] # change of basis matrix - - def __init__(self, A: np.ndarray, irreps: Tuple[MulIrrep, ...], Q: Optional[np.ndarray] = None): - self._A = A - self.irreps = irreps - self.Q = Q - - @classmethod - def from_string( - cls, string: str, cls_irrep: Type[TabulatedIrrep], Q: Optional[np.ndarray] = None - ) -> "ReducedRep": - return cls.from_irreps( - [MulIrrep.from_string(term, cls_irrep) for term in string.split("+")], Q - ) - - @classmethod - def from_irreps( - cls, - mul_irreps: Tuple[Union[Rep, Tuple[int, Rep], MulIrrep], ...], - Q: Optional[np.ndarray] = None, - ) -> "ReducedRep": - A = None - irreps = [] - - for mul_ir in mul_irreps: - if isinstance(mul_ir, tuple): - mul_ir = MulIrrep(mul=mul_ir[0], rep=mul_ir[1]) - elif isinstance(mul_ir, MulIrrep): - pass - elif isinstance(mul_ir, Rep): - mul_ir = MulIrrep(mul=1, rep=mul_ir) - - assert mul_ir.mul >= 0 - assert isinstance(mul_ir.rep, Rep) - - irreps += [mul_ir] - - if A is None: - A = mul_ir.algebra() - else: - assert np.allclose(A, mul_ir.algebra()) - - dim = sum(mul_ir.dim for mul_ir in irreps) - assert Q is None or Q.shape == (dim, dim) - - return cls(A=A, irreps=irreps, Q=Q) - - @property - def dim(self) -> int: - return sum(irrep.dim for irrep in self.irreps) - - def algebra(self) -> np.ndarray: - return self._A - - def continuous_generators(self) -> np.ndarray: - Xs = [] - for i in range(self.lie_dim): - X = direct_sum(*[mul_ir.continuous_generators()[i] for mul_ir in self.irreps]) - if self.Q is not None: - X = self.Q @ X @ np.linalg.inv(self.Q) - Xs += [X] - return np.stack(Xs) - - def discrete_generators(self) -> np.ndarray: - n = self.irreps[0].discrete_generators().shape[0] # TODO: support empty irreps - if n == 0: - return np.empty((0, self.dim, self.dim)) - Xs = [] - for i in range(n): - X = direct_sum(*[mul_ir.discrete_generators()[i] for mul_ir in self.irreps]) - if self.Q is not None: - X = self.Q @ X @ np.linalg.inv(self.Q) - Xs += [X] - return np.stack(Xs) - - def create_trivial(self) -> Rep: - return self.irreps[0].create_trivial() - - def __repr__(self) -> str: - r = " + ".join(repr(mul_ir) for mul_ir in self.irreps) - if self.Q is not None: - r = f"Q ({r}) Q^{-1}" - return r diff --git a/lie_nn/_src/reduced_tensor_product.py b/lie_nn/_src/reduced_tensor_product.py deleted file mode 100644 index 8d113f5..0000000 --- a/lie_nn/_src/reduced_tensor_product.py +++ /dev/null @@ -1,453 +0,0 @@ -""" -History of the different versions of the code: -- Initially developed by Mario Geiger in `e3nn` -- Ported in julia by Song Kim https://github.com/songk42/ReducedTensorProduct.jl -- Ported in `e3nn-jax` by Mario Geiger -- Ported in `lie_nn` by Mario Geiger and Ilyes Batatia -""" -import functools -import itertools -from typing import FrozenSet, List, Optional, Tuple, Union - -from lie_nn import TabulatedIrrep, ReducedRep, MulIrrep, Rep -import numpy as np -import lie_nn._src.discrete_groups.perm as perm -from .util import basis_intersection, round_to_sqrt_rational, prod -from typing import NamedTuple - - -class RepArray(NamedTuple): - rep: Rep - array: np.ndarray - list: List[np.ndarray] - - -class IrrepsArray: - irreps: Tuple[MulIrrep, ...] - list: List[np.ndarray] - - @property - def array(self): - return np.concatenate([np.reshape(x, x.shape[:-2] + (-1,)) for x in self.list], axis=-1) - - def __init__(self, *, irreps, list): - assert len(irreps) == len(list) - shapes = [] - for mul_ir, x in zip(irreps, list): - assert x.shape[-2] == mul_ir.mul - assert x.shape[-1] == mul_ir.rep.dim - shapes.append(x.shape[:-2]) - assert len(set(shapes)) == 1 - self.irreps = tuple(irreps) - self.list = list - - def sorted(self): - indices = list(range(len(self.irreps))) - indices = sorted(indices, key=lambda i: (self.irreps[i].rep, self.irreps[i].mul)) - return IrrepsArray( - list=[self.list[i] for i in indices], - irreps=tuple(self.irreps[i] for i in indices), - ) - - def simplify(self): - muls = [] - irreps = [] - list = [] - - for i, mul_irrep in enumerate(self.irreps): - mul, irrep = mul_irrep.mul, mul_irrep.rep - x = self.list[i] - - if i == 0 or irrep != irreps[-1]: - irreps.append(irrep) - list.append(x) - muls.append(mul) - else: - list[-1] = np.concatenate([list[-1], x], axis=-2) - muls[-1] += mul - return IrrepsArray( - list=list, irreps=tuple(MulIrrep(mul, irrep) for mul, irrep in zip(muls, irreps)) - ) - - def reshape(self, shape): - assert shape[-1] == -1 or shape[-1] == sum(mul_irrep.mul for mul_irrep in self.irreps) - - x_list = [] - for mul_irrep, x in zip(self.irreps, self.list): - mul, irrep = mul_irrep.mul, mul_irrep.rep - x_list.append(x.reshape(shape[:-1] + (mul, irrep.dim))) - - return IrrepsArray(list=x_list, irreps=self.irreps) - - -def _to_reducedrep(irreps) -> ReducedRep: - if isinstance(irreps, TabulatedIrrep): - irreps = MulIrrep(1, irreps) - if isinstance(irreps, MulIrrep): - irreps = ReducedRep.from_irreps([irreps]) - assert isinstance(irreps, ReducedRep) - return irreps - - -def reduced_tensor_product_basis( - formula_or_irreps_list: Union[str, List[ReducedRep]], - *, - epsilon: float = 1e-5, - **irreps_dict, -) -> RepArray: - r"""Reduce a tensor product of multiple irreps subject - to some permutation symmetry given by a formula. - - Args: - formula_or_irreps_list (str or list of Irreps): a formula - of the form ``ijk=jik=ikj`` or ``ijk=-jki``. - The left hand side is the original formula and the right hand side are - the signed permutations. - If no index symmetry is present, a list of irreps can be given instead. - - epsilon (float): the tolerance for the Gram-Schmidt orthogonalization. Default: ``1e-5`` - irreps_dict (dict): the irreps of each index of the formula. For instance ``i="1x1o"``. - - Returns: - RepArray: The change of basis - The shape is ``(d1, ..., dn, irreps_out.dim)`` - where ``di`` is the dimension of the index ``i`` and ``n`` - is the number of indices in the formula. - """ - - if isinstance(formula_or_irreps_list, (tuple, list)): - irreps_list = formula_or_irreps_list - irreps_tuple = tuple(_to_reducedrep(irreps) for irreps in irreps_list) - formulas: FrozenSet[Tuple[int, Tuple[int, ...]]] = frozenset( - {(1, tuple(range(len(irreps_tuple))))} - ) - out = _reduced_tensor_product_basis(irreps_tuple, formulas, epsilon) - return RepArray(ReducedRep.from_irreps(out.irreps), out.array, out.list) - - formula = formula_or_irreps_list - f0, perm_repr = germinate_perm_repr(formula) - - irreps_dict = {i: _to_reducedrep(irs) for i, irs in irreps_dict.items()} - - for i in irreps_dict: - if len(i) != 1: - raise TypeError(f"got an unexpected keyword argument '{i}'") - - for _sign, p in perm_repr: - f = "".join(f0[i] for i in p) - for i, j in zip(f0, f): - if i in irreps_dict and j in irreps_dict and irreps_dict[i] != irreps_dict[j]: - raise RuntimeError(f"irreps of {i} and {j} should be the same") - if i in irreps_dict: - irreps_dict[j] = irreps_dict[i] - if j in irreps_dict: - irreps_dict[i] = irreps_dict[j] - - for i in f0: - if i not in irreps_dict: - raise RuntimeError(f"index {i} has no irreps associated to it") - - for i in irreps_dict: - if i not in f0: - raise RuntimeError(f"index {i} has an irreps but does not appear in the fomula") - - irreps_tuple = tuple(irreps_dict[i] for i in f0) - - out = _reduced_tensor_product_basis(irreps_tuple, perm_repr, epsilon) - return RepArray(ReducedRep.from_irreps(out.irreps), out.array, out.list) - - -def reduced_symmetric_tensor_product_basis( - irreps: ReducedRep, - order: int, - *, - epsilon: float = 1e-5, -) -> RepArray: - r"""Reduce a symmetric tensor product. - - Args: - irreps (Irreps): the irreps of each index. - order (int): the order of the tensor product. i.e. the number of indices. - - Returns: - RepArray: The change of basis - The shape is ``(d, ..., d, irreps_out.dim)`` - where ``d`` is the dimension of ``irreps``. - """ - irreps = _to_reducedrep(irreps) - perm_repr: FrozenSet[Tuple[int, Tuple[int, ...]]] = frozenset( - (1, p) for p in itertools.permutations(range(order)) - ) - out = _reduced_tensor_product_basis(tuple([irreps] * order), perm_repr, epsilon) - return RepArray(ReducedRep.from_irreps(out.irreps), out.array, out.list) - - -# @functools.lru_cache(maxsize=None) -def _reduced_tensor_product_basis( - irreps_tuple: Tuple[ReducedRep, ...], - perm_repr: FrozenSet[Tuple[int, Tuple[int, ...]]], - epsilon: float, -) -> IrrepsArray: - dims = tuple(irps.dim for irps in irreps_tuple) - - def get_initial_basis(reduced_rep: ReducedRep, i: int) -> List[np.ndarray]: - x = np.reshape( - np.eye(reduced_rep.dim) if reduced_rep.Q is None else np.linalg.inv(reduced_rep.Q).T, - (1,) * i + (reduced_rep.dim,) + (1,) * (len(irreps_tuple) - i - 1) + (reduced_rep.dim,), - ) - x_list = [] - cursor = 0 - for mul_ir in reduced_rep.irreps: - mul, ir = mul_ir.mul, mul_ir.rep - x_list.append( - x[..., cursor : cursor + mul * ir.dim].reshape(x.shape[:-1] + (mul, ir.dim)) - ) - cursor += mul * ir.dim - return x_list - - bases = [ - ( - frozenset({i}), - IrrepsArray(list=get_initial_basis(reduced_rep, i), irreps=reduced_rep.irreps), - ) - for i, reduced_rep in enumerate(irreps_tuple) - ] - - while True: - if len(bases) == 1: - f, b = bases[0] - assert f == frozenset(range(len(irreps_tuple))) - return b.sorted().simplify() - - if len(bases) == 2: - (fa, a) = bases[0] - (fb, b) = bases[1] - f = frozenset(fa | fb) - ab = reduce_basis_product(a, b) - if len(subrepr_permutation(f, perm_repr)) == 1: - return ab.sorted().simplify() - p = reduce_subgroup_permutation(f, perm_repr, dims) - ab = constrain_rotation_basis_by_permutation_basis( - ab, p, epsilon=epsilon, round_fn=round_to_sqrt_rational - ) - return ab.sorted().simplify() - - # greedy algorithm - min_p = np.inf - best = None - - for i in range(len(bases)): - for j in range(i + 1, len(bases)): - (fa, _) = bases[i] - (fb, _) = bases[j] - f = frozenset(fa | fb) - p_dim = reduce_subgroup_permutation(f, perm_repr, dims, return_dim=True) - if p_dim < min_p: - min_p = p_dim - best = (i, j, f) - - i, j, f = best - del bases[j] - del bases[i] - sub_irreps = tuple(irreps_tuple[i] for i in f) - sub_perm_repr = subrepr_permutation(f, perm_repr) - ab = _reduced_tensor_product_basis(sub_irreps, sub_perm_repr, epsilon) - ab = ab.reshape(tuple(dims[i] if i in f else 1 for i in range(len(dims))) + (-1,)) - bases = [(f, ab)] + bases - - -@functools.lru_cache(maxsize=None) -def germinate_perm_repr(formula: str) -> Tuple[str, FrozenSet[Tuple[int, Tuple[int, ...]]]]: - """Convert the formula (generators) into a group.""" - formulas = [(-1 if f.startswith("-") else 1, f.replace("-", "")) for f in formula.split("=")] - s0, f0 = formulas[0] - assert s0 == 1 - - for _s, f in formulas: - if len(set(f)) != len(f) or set(f) != set(f0): - raise RuntimeError(f"{f} is not a permutation of {f0}") - if len(f0) != len(f): - raise RuntimeError(f"{f0} and {f} don't have the same number of indices") - - # `perm_repr` is a list of (sign, permutation of indices) - # each formula can be viewed as a permutation of the original formula - perm_repr = { - (s, tuple(f.index(i) for i in f0)) for s, f in formulas - } # set of generators (permutations) - - # they can be composed, for instance if you have ijk=jik=ikj - # you also have ijk=jki - # applying all possible compositions creates an entire group - while True: - n = len(perm_repr) - perm_repr = perm_repr.union([(s, perm.inverse(p)) for s, p in perm_repr]) - perm_repr = perm_repr.union( - [(s1 * s2, perm.compose(p1, p2)) for s1, p1 in perm_repr for s2, p2 in perm_repr] - ) - if len(perm_repr) == n: - break # we break when the set is stable => it is now a group \o/ - - return f0, frozenset(perm_repr) - - -def reduce_basis_product( - basis1: IrrepsArray, - basis2: IrrepsArray, - filter_ir_out: Optional[List[TabulatedIrrep]] = None, -) -> IrrepsArray: - """Reduce the product of two basis.""" - basis1 = basis1.sorted().simplify() - basis2 = basis2.sorted().simplify() - - new_irreps: List[Tuple[int, TabulatedIrrep]] = [] - new_list: List[np.ndarray] = [] - - for mul_ir1, x1 in zip(basis1.irreps, basis1.list): - mul1, ir1 = mul_ir1.mul, mul_ir1.rep - for mul_ir2, x2 in zip(basis2.irreps, basis2.list): - mul2, ir2 = mul_ir2.mul, mul_ir2.rep - for ir in ir1 * ir2: - if filter_ir_out is not None and ir not in filter_ir_out: - continue - - cg = ir.clebsch_gordan(ir1, ir2, ir) - x = np.einsum( - "...ui,...vj,wijk->...wuvk", - x1, - x2, - cg, - ) - x = np.reshape(x, x.shape[:-4] + (cg.shape[0] * mul1 * mul2, ir.dim)) - new_irreps.append((cg.shape[0] * mul1 * mul2, ir)) - new_list.append(x) - - new = IrrepsArray(irreps=tuple(MulIrrep(mul, ir) for mul, ir in new_irreps), list=new_list) - return new.sorted().simplify() - - -def constrain_rotation_basis_by_permutation_basis( - rotation_basis: IrrepsArray, - permutation_basis: np.ndarray, - *, - epsilon=1e-5, - round_fn=lambda x: x, -) -> IrrepsArray: - """Constrain a rotation basis by a permutation basis. - - Args: - rotation_basis (e3nn.IrrepsArray): A rotation basis - permutation_basis (np.ndarray): A permutation basis - - Returns: - e3nn.IrrepsArray: A rotation basis that is constrained by the permutation basis. - """ - assert all(x.shape[:-2] == permutation_basis.shape[1:] for x in rotation_basis.list) - - perm = np.reshape(permutation_basis, (permutation_basis.shape[0], -1)) # (free, dim) - - new_irreps: List[Tuple[int, TabulatedIrrep]] = [] - new_list: List[np.ndarray] = [] - - for rotation_basis_mul_ir, rot_basis in zip(rotation_basis.irreps, rotation_basis.list): - mul, ir = rotation_basis_mul_ir.mul, rotation_basis_mul_ir.rep - R = rot_basis[..., 0] - R = np.reshape(R, (-1, mul)).T # (mul, dim) - - perm_opt = perm[~np.all(perm[:, ~np.all(R == 0, axis=0)] == 0, axis=1)] - P, _ = basis_intersection(R, perm_opt, epsilon=epsilon, round_fn=round_fn) - - if P.shape[0] > 0: - new_irreps.append((P.shape[0], ir)) - new_list.append(np.einsum("vu,...ui->...vi", P, rot_basis)) - - return IrrepsArray(irreps=tuple(MulIrrep(mul, ir) for mul, ir in new_irreps), list=new_list) - - -def subrepr_permutation( - sub_f0: FrozenSet[int], perm_repr: FrozenSet[Tuple[int, Tuple[int, ...]]] -) -> FrozenSet[Tuple[int, Tuple[int, ...]]]: - sor = sorted(sub_f0) - return frozenset( - { - (s, tuple(sor.index(i) for i in p if i in sub_f0)) - for s, p in perm_repr - if all(i in sub_f0 or i == j for j, i in enumerate(p)) - } - ) - - -def reduce_subgroup_permutation( - sub_f0: FrozenSet[int], - perm_repr: FrozenSet[Tuple[int, Tuple[int, ...]]], - dims: Tuple[int, ...], - return_dim: bool = False, -) -> np.ndarray: - sub_perm_repr = subrepr_permutation(sub_f0, perm_repr) - sub_dims = tuple(dims[i] for i in sub_f0) - if len(sub_perm_repr) == 1: - if return_dim: - return prod(sub_dims) - return np.eye(prod(sub_dims)).reshape((prod(sub_dims),) + sub_dims) - base = reduce_permutation_base(sub_perm_repr, sub_dims) - if return_dim: - return len(base) - permutation_basis = reduce_permutation_matrix(base, sub_dims) - return np.reshape( - permutation_basis, (-1,) + tuple(dims[i] if i in sub_f0 else 1 for i in range(len(dims))) - ) - - -@functools.lru_cache(maxsize=None) -def full_base_fn(dims: Tuple[int, ...]) -> List[Tuple[int, ...]]: - return list(itertools.product(*(range(d) for d in dims))) - - -@functools.lru_cache(maxsize=None) -def reduce_permutation_base( - perm_repr: FrozenSet[Tuple[int, Tuple[int, ...]]], dims: Tuple[int, ...] -) -> FrozenSet[FrozenSet[FrozenSet[Tuple[int, Tuple[int, ...]]]]]: - full_base = full_base_fn(dims) # (0, 0, 0), (0, 0, 1), (0, 0, 2), ... (3, 3, 3) - # len(full_base) degrees of freedom in an unconstrained tensor - - # but there is constraints given by the group `formulas` - # For instance if `ij=-ji`, then 00=-00, 01=-01 and so on - base = set() - for x in full_base: - # T[x] is a coefficient of the tensor T and is related to other coefficient T[y] - # if x and y are related by a formula - xs = {(s, tuple(x[i] for i in p)) for s, p in perm_repr} - # s * T[x] are all equal for all (s, x) in xs - # if T[x] = -T[x] it is then equal to 0 and we lose this degree of freedom - if not (-1, x) in xs: - # the sign is arbitrary, put both possibilities - base.add(frozenset({frozenset(xs), frozenset({(-s, x) for s, x in xs})})) - - # len(base) is the number of degrees of freedom in the tensor. - - return frozenset(base) - - -@functools.lru_cache(maxsize=None) -def reduce_permutation_matrix( - base: FrozenSet[FrozenSet[FrozenSet[Tuple[int, Tuple[int, ...]]]]], dims: Tuple[int, ...] -) -> np.ndarray: - base = sorted( - [sorted([sorted(xs) for xs in x]) for x in base] - ) # requested for python 3.7 but not for 3.8 (probably a bug in 3.7) - - # First we compute the change of basis (projection) between full_base and base - d_sym = len(base) - Q = np.zeros((d_sym, prod(dims))) - - for i, x in enumerate(base): - x = max(x, key=lambda xs: sum(s for s, x in xs)) - for s, e in x: - j = 0 - for k, d in zip(e, dims): - j *= d - j += k - Q[i, j] = s / len(x) ** 0.5 - - np.testing.assert_allclose(Q @ Q.T, np.eye(d_sym)) - - return Q.reshape(d_sym, *dims) diff --git a/lie_nn/_src/reduced_tensor_product_test.py b/lie_nn/_src/reduced_tensor_product_test.py deleted file mode 100644 index 80d3318..0000000 --- a/lie_nn/_src/reduced_tensor_product_test.py +++ /dev/null @@ -1,19 +0,0 @@ -from lie_nn import reduced_symmetric_tensor_product_basis -from lie_nn.irreps import SO3 -from lie_nn import ReducedRep -import numpy as np - - -def test_tensor_product_basis_equivariance(): - irreps = ReducedRep.from_string("0+1+2+0", SO3) - Q = reduced_symmetric_tensor_product_basis(irreps, 3) - - params = (0.2, 0.1, 0.13) - - D_out = Q.rep.exp_map(params, ()) - Q1 = np.einsum("ijkx,xy->ijky", Q.array, D_out) - - D_in = irreps.exp_map(params, ()) - Q2 = np.einsum("ijkx,li,mj,nk->lmnx", Q.array, D_in, D_in, D_in) - - np.testing.assert_allclose(Q1, Q2, atol=1e-6, rtol=1e-6) diff --git a/lie_nn/_src/rep.py b/lie_nn/_src/rep.py index 7f79612..b81c065 100644 --- a/lie_nn/_src/rep.py +++ b/lie_nn/_src/rep.py @@ -1,15 +1,41 @@ -import dataclasses -from typing import Optional +from typing import Iterator, Optional, Tuple, Type import numpy as np import scipy.linalg -from .util import infer_algebra_from_generators, check_algebra_vs_generators +import lie_nn as lie class Rep: r"""Abstract Class, Representation of a Lie group.""" + def algebra(self) -> np.ndarray: + """Array of shape [lie_dim, lie_dim, lie_dim] + + Satisfying the Lie algebra commutation relations: + + .. math:: + + [X_i, X_j] = A_{ijk} X_k + + """ + raise NotImplementedError + + def continuous_generators(self) -> np.ndarray: + """Array of shape [lie_dim, dim, dim]""" + raise NotImplementedError + + def discrete_generators(self) -> np.ndarray: + """Array of shape [len(H), dim, dim]""" + raise NotImplementedError + + def create_trivial(self) -> "Rep": + """Create trivial representation (dim=1) + + With the same algebra and len(H) + """ + raise NotImplementedError + @property def lie_dim(self) -> int: A = self.algebra() @@ -28,32 +54,18 @@ def dim(self) -> int: # assert H.shape[1:] == (d, d) return d - def algebra(self) -> np.ndarray: - """``[X_i, X_j] = A_ijk X_k``""" - raise NotImplementedError - @property def A(self) -> np.ndarray: return self.algebra() - def continuous_generators(self) -> np.ndarray: - raise NotImplementedError - @property def X(self) -> np.ndarray: return self.continuous_generators() - def discrete_generators(self) -> np.ndarray: - raise NotImplementedError - @property def H(self) -> np.ndarray: return self.discrete_generators() - def create_trivial(self) -> "Rep": - # Create a trivial representation from the same group as self - raise NotImplementedError - def exp_map(self, continuous_params: np.ndarray, discrete_params: np.ndarray) -> np.ndarray: """Instanciate the representation @@ -64,6 +76,8 @@ def exp_map(self, continuous_params: np.ndarray, discrete_params: np.ndarray) -> Returns: ``(dim, dim)`` array """ + # TODO: now that we integrate the Sn group, that is non abelian. + # there is no more rule to how to parameterize the finite part of the represention. output = scipy.linalg.expm( np.einsum("a,aij->ij", continuous_params, self.continuous_generators()) ) @@ -81,13 +95,7 @@ def is_trivial(self) -> bool: and np.all(self.discrete_generators() == 1.0) ) - def check_algebra_vs_generators(rep: "Rep", rtol=1e-10, atol=1e-10): - check_algebra_vs_generators( - rep.algebra(), rep.continuous_generators(), rtol=rtol, atol=atol, assert_=True - ) - -@dataclasses.dataclass(init=False) class GenericRep(Rep): r"""Unknown representation""" _A: np.ndarray @@ -107,7 +115,7 @@ def from_generators( H: Optional[np.ndarray] = None, round_fn=lambda x: x, ) -> Optional["GenericRep"]: - A = infer_algebra_from_generators(X, round_fn=round_fn) + A = lie.utils.infer_algebra_from_generators(X, round_fn=round_fn) if A is None: return None if H is None: @@ -135,56 +143,276 @@ def __repr__(self) -> str: return f"GenericRep(dim={self.dim}, lie_dim={self.lie_dim}, len(H)={len(self.H)})" -def check_representation_triplet(rep1: Rep, rep2: Rep, rep3: Rep, rtol=1e-10, atol=1e-10): - assert np.allclose(rep1.algebra(), rep2.algebra(), rtol=rtol, atol=atol) - assert np.allclose(rep1.algebra(), rep3.algebra(), rtol=rtol, atol=atol) - - rep1.check_algebra_vs_generators(rtol=rtol, atol=atol) - rep2.check_algebra_vs_generators(rtol=rtol, atol=atol) - rep3.check_algebra_vs_generators(rtol=rtol, atol=atol) - - X1 = rep1.continuous_generators() # (lie_group_dimension, rep1.dim, rep1.dim) - X2 = rep2.continuous_generators() # (lie_group_dimension, rep2.dim, rep2.dim) - X3 = rep3.continuous_generators() # (lie_group_dimension, rep3.dim, rep3.dim) - assert X1.shape[0] == X2.shape[0] == X3.shape[0] - - from .clebsch_gordan import clebsch_gordan - - cg = clebsch_gordan(rep1, rep2, rep3) - assert cg.ndim == 1 + 3, (rep1, rep2, rep3, cg.shape) - assert cg.shape == (cg.shape[0], rep1.dim, rep2.dim, rep3.dim) - - # Orthogonality - # left_side = np.einsum('zijk,wijl->zkwl', cg, np.conj(cg)) - # right_side = np.eye(cg.shape[0] * rep3.dim) - # .reshape((cg.shape[0], rep3.dim, cg.shape[0], rep3.dim)) - # np.testing.assert_allclose(left_side, right_side, rtol=rtol, atol=atol) - - # if rep3 in rep1 * rep2: - # assert cg.shape[0] > 0 - # else: - # assert cg.shape[0] == 0 - - left_side = np.einsum("zijk,dlk->zdijl", cg, X3) - right_side = np.einsum("dil,zijk->zdljk", X1, cg) + np.einsum("djl,zijk->zdilk", X2, cg) - - for solution in range(cg.shape[0]): - for i in range(X1.shape[0]): - if not np.allclose( - left_side[solution][i], right_side[solution][i], rtol=rtol, atol=atol - ): - np.set_printoptions(precision=3, suppress=True) - print(rep1, rep2, rep3) - print('Left side: einsum("zijk,dlk->zdijl", cg, X3)') - print(left_side[solution][i]) - print( - "Right side: " - 'einsum("dil,zijk->zdljk", X1, cg) + einsum("djl,zijk->zdilk", X2, cg)' - ) - print(right_side[solution][i]) - np.set_printoptions(precision=8, suppress=False) - raise AssertionError( - f"Solution {solution}/{cg.shape[0]} for {rep1} * {rep2} = {rep3} " - "is not correct." - f"Clebsch-Gordan coefficient is not correct for Lie algebra generator {i}." - ) +class Irrep(Rep): + pass + + +class TabulatedIrrep(Irrep): + @classmethod + def from_string(cls, string: str) -> "TabulatedIrrep": + raise NotImplementedError + + def __mul__(rep1: "TabulatedIrrep", rep2: "TabulatedIrrep") -> Iterator["TabulatedIrrep"]: + # Selection rule + raise NotImplementedError + + @property + def dim(rep: "TabulatedIrrep") -> int: + raise NotImplementedError + + def __lt__(rep1: "TabulatedIrrep", rep2: "TabulatedIrrep") -> bool: + # This is used for sorting the irreps + raise NotImplementedError + + @classmethod + def iterator(cls) -> Iterator["TabulatedIrrep"]: + # Requirements: + # - the first element must be the trivial representation + # - the elements must be sorted by the __lt__ method + raise NotImplementedError + + @classmethod + def create_trivial(cls) -> "TabulatedIrrep": + return cls.iterator().__next__() + + @classmethod + def clebsch_gordan( + cls, rep1: "TabulatedIrrep", rep2: "TabulatedIrrep", rep3: "TabulatedIrrep" + ) -> np.ndarray: + # return an array of shape ``(number_of_paths, rep1.dim, rep2.dim, rep3.dim)`` + raise NotImplementedError + + +class MulRep(Rep): + mul: int + rep: Rep + + def __init__(self, mul: int, rep: Rep, *, force=False): + if not force: + raise RuntimeError("Use lie_nn.multiply instead") + self.mul = mul + self.rep = rep + + @classmethod + def from_string(cls, string: str, cls_irrep: Type[Rep]) -> "MulRep": + if "x" in string: + mul, rep = string.split("x") + else: + mul, rep = 1, string + return cls(mul=int(mul), rep=cls_irrep.from_string(rep)) + + @property + def dim(self) -> int: + return self.mul * self.rep.dim + + def algebra(self) -> np.ndarray: + return self.rep.algebra() + + def continuous_generators(self) -> np.ndarray: + X = self.rep.X + if X.shape[0] == 0: + return np.empty((0, self.dim, self.dim)) + return np.stack([lie.utils.direct_sum(*[x for _ in range(self.mul)]) for x in X], axis=0) + + def discrete_generators(self) -> np.ndarray: + H = self.rep.H + if H.shape[0] == 0: + return np.empty((0, self.dim, self.dim)) + return np.stack([lie.utils.direct_sum(*[x for _ in range(self.mul)]) for x in H], axis=0) + + def create_trivial(self) -> Rep: + return self.rep.create_trivial() + + def __repr__(self) -> str: + return f"{self.mul}x{self.rep}" + + +class SumRep(Rep): + r"""Representation of the form + + .. math:: + \osum_i \rho_i + """ + reps: Tuple[Rep, ...] + + def __init__(self, reps: Tuple[Rep, ...], *, force=False): + if not force: + raise RuntimeError("Use lie_nn.direct_sum instead") + assert len(reps) >= 1 + self.reps = tuple(reps) + + @classmethod + def from_string(cls, string: str, cls_irrep: Type[TabulatedIrrep]) -> "SumRep": + return cls([MulRep.from_string(term, cls_irrep) for term in string.split("+")]) + + @property + def dim(self) -> int: + return sum(rep.dim for rep in self.reps) + + def algebra(self) -> np.ndarray: + return self.reps[0].algebra() + + def continuous_generators(self) -> np.ndarray: + if self.lie_dim == 0: + return np.empty((0, self.dim, self.dim)) + Xs = [] + for i in range(self.lie_dim): + Xs += [lie.utils.direct_sum(*[rep.X[i] for rep in self.reps])] + return np.stack(Xs) + + def discrete_generators(self) -> np.ndarray: + n = len(self.reps[0].H) + if n == 0: + return np.empty((0, self.dim, self.dim)) + Hs = [] + for i in range(n): + Hs += [lie.utils.direct_sum(*[rep.H[i] for rep in self.reps])] + return np.stack(Hs) + + def create_trivial(self) -> Rep: + return self.reps[0].create_trivial() + + def __repr__(self) -> str: + r = " + ".join(repr(rep) for rep in self.reps) + return r + + +class QRep(Rep): + r"""Change of basis of a representation + + .. math:: + + Q \rho Q^{-1} + + where :math:`Q^{-1}` is the pseudo-inverse of :math:`Q`. + """ + rep: Rep + Q: np.ndarray + + def __init__(self, Q: np.ndarray, rep: Rep, *, force=False): + if not force: + raise RuntimeError("Use lie_nn.change_basis instead") + + assert Q.shape[0] == Q.shape[1] + assert Q.shape[0] == rep.dim + self.Q = Q + self.rep = rep + + @property + def dim(self) -> int: + return self.Q.shape[0] + + def algebra(self) -> np.ndarray: + return self.rep.algebra() + + def continuous_generators(self) -> np.ndarray: + return np.einsum("ij,ajk,kl->ail", self.Q, self.rep.X, np.linalg.pinv(self.Q)) + + def discrete_generators(self) -> np.ndarray: + return np.einsum("ij,ajk,kl->ail", self.Q, self.rep.H, np.linalg.pinv(self.Q)) + + def create_trivial(self) -> "Rep": + return self.rep.create_trivial() + + def __repr__(self) -> str: + return f"Q({self.rep})Q^{{-1}}" + + +class ConjRep(Rep): + rep: Rep + + def __init__(self, rep: Rep, *, force=False): + if not force: + raise RuntimeError("Use lie_nn.conjugate instead") + self.rep = rep + + @property + def dim(self) -> int: + return self.rep.dim + + @property + def lie_dim(self) -> int: + return self.rep.lie_dim + + def algebra(self) -> np.ndarray: + return self.rep.algebra() + + def continuous_generators(self) -> np.ndarray: + return np.conjugate(self.rep.X) + + def discrete_generators(self) -> np.ndarray: + return np.conjugate(self.rep.H) + + def create_trivial(self) -> "Rep": + return self.rep.create_trivial() + + def __repr__(self) -> str: + return f"({self.rep})*" + + +class ReducedRep(Rep): + _A: np.ndarray + num_H: int + Q: np.ndarray + reps: Tuple[Tuple[int, Rep], ...] + + def __init__( + self, + A: np.ndarray, + num_H: int, + Q: np.ndarray, + reps: Tuple[Tuple[int, Rep], ...], + *, + force=False, + ): + if not force: + raise RuntimeError("Use lie_nn.reduce instead") + self._A = A + self.num_H = num_H + self.Q = Q + self.reps = reps + + dim = 0 + for mul, rep in reps: + dim += mul * rep.dim + np.testing.assert_allclose(rep.A, A) + assert len(rep.H) == num_H + assert dim == Q.shape[0] + + def split_Q(self) -> Tuple[np.ndarray, ...]: + Qs = [] + i = 0 + for mul, rep in self.reps: + Qs.append(self.Q[:, i : i + mul * rep.dim]) + i += mul * rep.dim + return tuple(Qs) + + def _into(self) -> Rep: + if len(self.reps) == 0: + return GenericRep( + self.A, np.empty((self.A.shape[0], 0, 0)), np.empty((self.num_H, 0, 0)) + ) + return QRep( + self.Q, + SumRep([MulRep(mul, rep, force=True) for mul, rep in self.reps], force=True), + force=True, + ) + + @property + def dim(self) -> int: + return self.Q.shape[0] + + def algebra(self) -> np.ndarray: + return self._A + + def continuous_generators(self) -> np.ndarray: + return self._into().continuous_generators() + + def discrete_generators(self) -> np.ndarray: + return self._into().discrete_generators() + + def create_trivial(self) -> "Rep": + return self._into().create_trivial() + + def __repr__(self) -> str: + return f"ReducedRep({self._into()})" diff --git a/lie_nn/_src/symmetric_tensor_power.py b/lie_nn/_src/symmetric_tensor_power.py new file mode 100644 index 0000000..12ca83c --- /dev/null +++ b/lie_nn/_src/symmetric_tensor_power.py @@ -0,0 +1,135 @@ +import itertools + +import numpy as np +from multimethod import multimethod + +from .utils import kron, permutation_base, permutation_base_to_matrix +import lie_nn as lie + + +def _symmetric_perm_repr(n: int): + return frozenset((1, p) for p in itertools.permutations(range(n))) + + +def _symmetric_perm_matrix(d: int, n: int): + base = permutation_base(_symmetric_perm_repr(n), (d,) * n) + P = permutation_base_to_matrix(base, (d,) * n) # [symmetric, d, d, ... d] + P = np.reshape(P, (P.shape[0], -1)) # [symmetric, d**n] + return P + + +@multimethod +def symmetric_tensor_power(rep: lie.QRep, n: int) -> lie.Rep: # noqa: F811 + # out = P @ Q @ tensorpower(rep, n) @ Q^-1 @ P^T + Q = kron(*[rep.Q] * n) # [d**n, d**n] + P = _symmetric_perm_matrix(rep.dim, n) # [symmetric, d**n] + S = P @ Q @ P.T + + # out = S @ P @ tensorpower(rep, n) @ P^T @ S^-1 + return lie.change_basis(S, symmetric_tensor_power(rep.rep, n)) + + +@multimethod +def symmetric_tensor_power(rep: lie.SumRep, n: int) -> lie.Rep: # noqa: F811 + # for all subreps in rep.reps + # run symmetric_tensor_power(subrep, i) for i = 1, ..., n + raise NotImplementedError + + +@multimethod +def symmetric_tensor_power(rep: lie.MulRep, n: int) -> lie.Rep: # noqa: F811 + # stp = [symmetric_tensor_power(rep.rep, i) for i in range(0, n + 1)] + # i j if rep.mul == 2 and n == 3 + # 3 0 + # 2 1 + # 1 2 + # 0 3 + raise NotImplementedError + + +@multimethod +def symmetric_tensor_power(rep: lie.ConjRep, n: int) -> lie.Rep: # noqa: F811 + return lie.conjugate(symmetric_tensor_power(rep.rep, n)) + + +@multimethod +def symmetric_tensor_power(rep: lie.Rep, n: int) -> lie.Rep: # noqa: F811 + if n == 1: + return rep + + a = n // 2 + b = n - a + + Pa = _symmetric_perm_matrix(rep.dim, a) + Pb = _symmetric_perm_matrix(rep.dim, b) + + Ra = symmetric_tensor_power(rep, a) # Pa @ tensorpower(rep, n // 2) @ Pa^T + Rb = symmetric_tensor_power(rep, b) # Pb @ tensorpower(rep, n - n // 2) @ Pb^T + + Pab = kron(Pa, Pb) + Rab = lie.reduce(lie.tensor_product(Ra, Rb)) # Pab @ tensorpower(rep, n) @ Pab^T + + Pn = _symmetric_perm_matrix(rep.dim, n) + S = Pn @ Pab.T # [sym(n), sym(a) * sym(b)] + + print(f"project {Rab} of dim {Rab.dim} ") + print(f"with S of dim {S.shape} ") + print(S) + + # return _project(S, Rab) + raise NotImplementedError + + +# @multimethod +# def _project(S: np.ndarray, rep: QRep) -> Rep: # noqa: F811 +# # Assume S S^T = I but not S^T S = I +# # out = S @ Q @ rep @ Q^-1 @ S^T +# # u, s, vh = np.linalg.svd(S @ rep.Q, full_matrices=False) +# # np.testing.assert_allclose(u @ np.diag(s) @ vh, S @ rep.Q, atol=1e-10) +# # print(f"u {u.shape}\n{u}") +# # print(f"s {s.shape}\n{s}") +# # print(f"vh {vh.shape}\n{vh}") +# # return change_basis(u @ np.diag(s), _project(vh, rep.rep)) +# P = S @ rep.Q +# print(f"P {P.shape}\n{P}") +# return _project(P, rep.rep) + + +# @multimethod +# def _project(S: np.ndarray, rep: SumRep) -> Rep: # noqa: F811 +# # Assume each subrep is independently projected +# i = 0 +# reps = [] +# Qs = [] +# for subrep in rep.reps: +# j = i + subrep.dim +# Sij = S[:, i:j] +# i = j + +# print(f"subrep={subrep}, Sij {Sij.shape}\n{Sij}") + +# if np.linalg.norm(Sij) > 1e-10: +# reps.append(subrep) +# Qs.append(Sij) + +# Q = np.concatenate(Qs, axis=1) +# return lie.change_basis(Q, lie.direct_sum(*reps)) + + +# @multimethod +# def _project(S: np.ndarray, rep: MulRep) -> Rep: # noqa: F811 +# # Assume rep.rep is an irrep +# S = np.reshape(S, (S.shape[0], rep.mul, rep.dim)) +# x = S[:, :, 0].T # [mul, d_out] +# _, u = lie.util.gram_schmidt_with_change_of_basis(x) +# mul = u.shape[0] # new mul +# S = np.einsum("uv , dvi -> dui", u, S) # [d_out, mul', d] +# S = np.reshape(S, (S.shape[0], -1)) # [d_out, mul' * d] +# assert S.shape[0] == S.shape[1], S.shape +# return lie.change_basis(S, lie.MulRep(mul, rep.rep, force=True)) + + +# @multimethod +# def _project(S: np.ndarray, rep: Rep): # noqa: F811 +# assert S.shape[0] == S.shape[1], (S, rep.dim) +# return lie.change_basis(S, rep) diff --git a/lie_nn/_src/tensor_product.py b/lie_nn/_src/tensor_product.py index 603f91e..67cbe44 100644 --- a/lie_nn/_src/tensor_product.py +++ b/lie_nn/_src/tensor_product.py @@ -1,21 +1,17 @@ import numpy as np -from multipledispatch import dispatch +from multimethod import multimethod import lie_nn as lie -from .irrep import TabulatedIrrep -from .reduced_rep import MulIrrep, ReducedRep -from .rep import GenericRep, Rep - -@dispatch(Rep, Rep) -def tensor_product(rep1: Rep, rep2: Rep) -> GenericRep: +@multimethod +def tensor_product(rep1: lie.Rep, rep2: lie.Rep) -> lie.GenericRep: assert np.allclose(rep1.A, rep2.A) # same lie algebra X1, H1, I1 = rep1.X, rep1.H, np.eye(rep1.dim) X2, H2, I2 = rep2.X, rep2.H, np.eye(rep2.dim) assert H1.shape[0] == H2.shape[0] # same discrete dimension d = rep1.dim * rep2.dim - return GenericRep( + return lie.GenericRep( A=rep1.A, X=(np.einsum("aij,kl->aikjl", X1, I2) + np.einsum("ij,akl->aikjl", I1, X2)).reshape( X1.shape[0], d, d @@ -24,128 +20,123 @@ def tensor_product(rep1: Rep, rep2: Rep) -> GenericRep: ) -@dispatch(TabulatedIrrep, TabulatedIrrep) -def tensor_product(irrep1: TabulatedIrrep, irrep2: TabulatedIrrep) -> ReducedRep: # noqa: F811 +@multimethod +def tensor_product(irrep1: lie.TabulatedIrrep, irrep2: lie.TabulatedIrrep) -> lie.Rep: # noqa: F811 assert np.allclose(irrep1.A, irrep2.A) # same lie algebra CG_list = [] irreps_list = [] for ir_out in irrep1 * irrep2: - CG = np.moveaxis(lie.clebsch_gordan(irrep1, irrep2, ir_out), 3, 1) # [sol, ir3, ir1, ir2] + CG = np.moveaxis( + irrep1.clebsch_gordan(irrep1, irrep2, ir_out), 3, 1 + ) # [sol, ir3, ir1, ir2] mul = CG.shape[0] CG = CG.reshape(CG.shape[0] * CG.shape[1], CG.shape[2] * CG.shape[3]) CG_list.append(CG) - irreps_list.append(MulIrrep(mul=mul, rep=ir_out)) + irreps_list.append(lie.multiply(mul, ir_out)) CG = np.concatenate(CG_list, axis=0) Q = np.linalg.inv(CG) - return ReducedRep(A=irrep1.A, irreps=tuple(irreps_list), Q=Q) - - -@dispatch(MulIrrep, MulIrrep) -def tensor_product(mulirrep1: MulIrrep, mulirrep2: MulIrrep) -> ReducedRep: # noqa: F811 - assert np.allclose(mulirrep1.A, mulirrep2.A) # same lie algebra - m1, m2 = mulirrep1.mul, mulirrep2.mul - tp_irreps = tensor_product(mulirrep1.rep, mulirrep2.rep) - Q = tp_irreps.Q - irreps = tp_irreps.irreps - Q_out = [] - irreps_out = [] - s = 0 - for mul_ir in irreps: - irreps_out.append(MulIrrep(mul=m1 * m2 * mul_ir.mul, rep=mul_ir.rep)) - q = Q[:, s : s + mul_ir.dim].reshape( - mulirrep1.rep.dim, mulirrep2.rep.dim, mul_ir.mul, mul_ir.rep.dim - ) - q = np.einsum("ijsk,ur,vt->uivjrtsk", q, np.eye(m1), np.eye(m2)) - q = q.reshape( - q.shape[0] * q.shape[1] * q.shape[2] * q.shape[3], - q.shape[4] * q.shape[5] * q.shape[6] * q.shape[7], - ) - Q_out.append(q) - s += mul_ir.dim - Q_out = np.concatenate(Q_out, axis=-1) - return ReducedRep(A=mulirrep1.A, irreps=tuple(irreps_out), Q=Q_out) - - -@dispatch(ReducedRep, ReducedRep) -def tensor_product(rep1: ReducedRep, rep2: ReducedRep) -> ReducedRep: # noqa: F811 - q1 = np.eye(rep1.dim) if rep1.Q is None else rep1.Q - q2 = np.eye(rep2.dim) if rep2.Q is None else rep2.Q - Q_tp = np.einsum("ij,kl->ikjl", q1, q2).reshape(rep1.dim * rep2.dim, rep1.dim * rep2.dim) - mulir_list = [] - - Q = np.zeros((rep1.dim, rep2.dim, rep1.dim * rep2.dim), dtype=np.complex128) - k = 0 - i = 0 - for mulirrep1 in rep1.irreps: - j = 0 - for mulirrep2 in rep2.irreps: - reducedrep = tensor_product(mulirrep1, mulirrep2) - mulir_list += reducedrep.irreps - q = reducedrep.Q.reshape(mulirrep1.dim, mulirrep2.dim, reducedrep.dim) - Q[i : i + mulirrep1.dim, j : j + mulirrep2.dim, k : k + reducedrep.dim] = q - k += reducedrep.dim - j += mulirrep2.dim - i += mulirrep1.dim - assert k == rep1.dim * rep2.dim - Q = Q.reshape(rep1.dim * rep2.dim, rep1.dim * rep2.dim) - Q = Q_tp @ Q + return lie.change_basis(Q, lie.direct_sum(*irreps_list)) - if np.allclose(Q.imag, 0): - Q = Q.real - if np.allclose(Q, np.eye(Q.shape[0])): - Q = None - return ReducedRep(A=rep1.A, irreps=tuple(mulir_list), Q=Q) +@multimethod +def tensor_product(mulrep: lie.MulRep, rep: lie.Rep) -> lie.Rep: # noqa: F811 + assert np.allclose(mulrep.A, rep.A) # same lie algebra + return lie.multiply(mulrep.mul, tensor_product(mulrep.rep, rep)) -@dispatch(MulIrrep, TabulatedIrrep) -def tensor_product(mulirrep1: MulIrrep, irrep2: TabulatedIrrep) -> ReducedRep: # noqa: F811 - return tensor_product(mulirrep1, MulIrrep(mul=1, rep=irrep2)) +@multimethod +def tensor_product(mulrep: lie.MulRep, rep: lie.MulRep) -> lie.Rep: # noqa: F811 + assert np.allclose(mulrep.A, rep.A) # same lie algebra + return lie.multiply(mulrep.mul, tensor_product(mulrep.rep, rep)) -@dispatch(TabulatedIrrep, MulIrrep) -def tensor_product(irrep1: TabulatedIrrep, mulirrep2: MulIrrep) -> ReducedRep: # noqa: F811 - return tensor_product(MulIrrep(mul=1, rep=irrep1), mulirrep2) +@multimethod +def tensor_product(rep: lie.Rep, mulrep: lie.MulRep) -> lie.Rep: # noqa: F811 + assert np.allclose(rep.A, mulrep.A) # same lie algebra -@dispatch(MulIrrep, ReducedRep) -def tensor_product(mulirrep1: MulIrrep, rep2: ReducedRep) -> ReducedRep: # noqa: F811 - return tensor_product(ReducedRep(A=mulirrep1.A, irreps=(mulirrep1,), Q=None), rep2) + Q = np.reshape( + np.einsum( + "ij,mn,uv->ium vjn", + np.eye(rep.dim), + np.eye(mulrep.rep.dim), + np.eye(mulrep.mul), + ), + ( + rep.dim * mulrep.rep.dim * mulrep.mul, + rep.dim * mulrep.rep.dim * mulrep.mul, + ), + ) + return lie.change_basis(Q, lie.multiply(mulrep.mul, tensor_product(rep, mulrep.rep))) -@dispatch(ReducedRep, MulIrrep) -def tensor_product(rep1: ReducedRep, mulirrep2: MulIrrep) -> ReducedRep: # noqa: F811 - return tensor_product(ReducedRep(A=mulirrep2.A, irreps=(mulirrep2,), Q=None), rep1) +@multimethod +def tensor_product(sumrep: lie.SumRep, rep: lie.Rep) -> lie.Rep: # noqa: F811 + return lie.direct_sum(*[tensor_product(subrep, rep) for subrep in sumrep.reps]) -@dispatch(ReducedRep, TabulatedIrrep) -def tensor_product(rep1: ReducedRep, irrep2: TabulatedIrrep) -> ReducedRep: # noqa: F811 - return tensor_product(rep1, MulIrrep(mul=1, rep=irrep2)) +@multimethod +def tensor_product(sumrep: lie.SumRep, rep: lie.SumRep) -> lie.Rep: # noqa: F811 + return lie.direct_sum(*[tensor_product(subrep, rep) for subrep in sumrep.reps]) + + +@multimethod +def tensor_product(rep: lie.Rep, sumrep: lie.SumRep) -> lie.Rep: # noqa: F811 + list = [] + Q = np.zeros((rep.dim, sumrep.dim, rep.dim * sumrep.dim)) + k = 0 + j = 0 + for subrep in sumrep.reps: + tp = tensor_product(rep, subrep) + list += [tp] + q = np.eye(tp.dim).reshape(rep.dim, subrep.dim, tp.dim) + Q[:, j : j + subrep.dim, k : k + tp.dim] = q + k += tp.dim + j += subrep.dim + + Q = Q.reshape(rep.dim * sumrep.dim, rep.dim * sumrep.dim) + return lie.change_basis(Q, lie.direct_sum(*list)) -@dispatch(TabulatedIrrep, ReducedRep) -def tensor_product(irrep1: TabulatedIrrep, rep2: ReducedRep) -> ReducedRep: # noqa: F811 - return tensor_product(MulIrrep(mul=1, rep=irrep1), rep2) +@multimethod +def tensor_product(qrep: lie.QRep, rep: lie.Rep) -> lie.Rep: # noqa: F811 + dim = qrep.dim * rep.dim + Q = np.einsum("ij,kl->ikjl", qrep.Q, np.eye(rep.dim)).reshape(dim, dim) + return lie.change_basis(Q, tensor_product(qrep.rep, rep)) -# @dispatch(Rep, int) -def tensor_power(rep: Rep, n: int) -> Rep: - result = rep.create_trivial() +@multimethod +def tensor_product(rep: lie.Rep, qrep: lie.QRep) -> lie.Rep: # noqa: F811 + dim = rep.dim * qrep.dim + Q = np.einsum("ij,kl->ikjl", np.eye(rep.dim), qrep.Q).reshape(dim, dim) + return lie.change_basis(Q, tensor_product(rep, qrep.rep)) + + +@multimethod +def tensor_product(qrep1: lie.QRep, qrep2: lie.QRep) -> lie.Rep: # noqa: F811 + dim = qrep1.dim * qrep2.dim + Q = np.einsum("ij,kl->ikjl", qrep1.Q, qrep2.Q).reshape(dim, dim) + return lie.change_basis(Q, tensor_product(qrep1.rep, qrep2.rep)) + + +@multimethod +def tensor_product(rep1: lie.ConjRep, rep2: lie.ConjRep) -> lie.Rep: # noqa: F811 + return lie.conjugate(tensor_product(rep1.rep, rep2.rep)) + + +def tensor_power(rep: lie.Rep, n: int) -> lie.Rep: + result = None while True: if n & 1: - result = tensor_product(rep, result) + if result is None: + result = rep + else: + result = tensor_product(rep, result) n >>= 1 if n == 0: - return result + if result is None: + return rep.create_trivial() + else: + return result rep = tensor_product(rep, rep) - - -# @dispatch(ReducedRep, int) -# def tensor_power(rep: ReducedRep, n: int) -> ReducedRep: -# # TODO reduce into irreps and wrap with the change of basis that -# maps to the usual tensor product -# # TODO as well reduce into irreps of S_n -# # and diagonalize irreps of S_n in the same basis that diagonalizes -# irreps of S_{n-1} (unclear how to do this) -# raise NotImplementedError diff --git a/lie_nn/_src/tensor_product_test.py b/lie_nn/_src/tensor_product_test.py deleted file mode 100644 index cba22bc..0000000 --- a/lie_nn/_src/tensor_product_test.py +++ /dev/null @@ -1,41 +0,0 @@ -import itertools - -import numpy as np -import pytest -from lie_nn import GenericRep, TabulatedIrrep, MulIrrep, ReducedRep, tensor_product, tensor_power -from lie_nn.irreps import O3, SL2C, SO3, SO13, SU2 - - -def first_reps(IR: TabulatedIrrep, n: int): - return list(itertools.islice(IR.iterator(), n)) - - -REPRESENTATIONS = [O3, SU2, SO3, SL2C, SO13] -# TODO: add SU2Real (or remove it completely) -# Note: SU2Real are not Irreps and this might be a problem -# TODO: resolve tensor_product_consistency for SU3, SU4 - - -@pytest.mark.parametrize( - "ir1, ir2", - sum((list(itertools.product(first_reps(IR, 4), repeat=2)) for IR in REPRESENTATIONS), []), -) -def test_tensor_product_consistency(ir1, ir2): - rep1 = ReducedRep.from_irreps([(2, ir1), ir2]) - rep2 = ReducedRep.from_irreps([(3, ir1)]) - - tp1 = tensor_product(rep1, rep2) - tp2 = tensor_product(GenericRep.from_rep(rep1), GenericRep.from_rep(rep2)) - - np.testing.assert_allclose(tp1.X, tp2.X, atol=1e-10) - - -def test_tensor_product_types(): - assert isinstance(tensor_product(O3(l=1, p=1), O3(l=1, p=1)), ReducedRep) - assert isinstance(tensor_product(O3(l=1, p=1), MulIrrep(mul=2, rep=O3(l=1, p=1))), ReducedRep) - - -def test_tensor_power_types(): - assert isinstance(tensor_power(O3(l=1, p=1), 2), ReducedRep) - assert isinstance(tensor_power(MulIrrep(mul=2, rep=O3(l=1, p=1)), 2), ReducedRep) - assert isinstance(tensor_power(GenericRep.from_rep(O3(l=1, p=1)), 2), GenericRep) diff --git a/lie_nn/_src/tests.py b/lie_nn/_src/tests.py new file mode 100644 index 0000000..e34a227 --- /dev/null +++ b/lie_nn/_src/tests.py @@ -0,0 +1,60 @@ +import numpy as np +from .rep import Rep +from .utils import check_algebra_vs_generators + + +def check_representation_triplet(rep1: Rep, rep2: Rep, rep3: Rep, rtol=1e-10, atol=1e-10): + assert np.allclose(rep1.algebra(), rep2.algebra(), rtol=rtol, atol=atol) + assert np.allclose(rep1.algebra(), rep3.algebra(), rtol=rtol, atol=atol) + + check_algebra_vs_generators(rep1.A, rep1.X, rtol=rtol, atol=atol) + check_algebra_vs_generators(rep2.A, rep2.X, rtol=rtol, atol=atol) + check_algebra_vs_generators(rep3.A, rep3.X, rtol=rtol, atol=atol) + + X1 = rep1.continuous_generators() # (lie_group_dimension, rep1.dim, rep1.dim) + X2 = rep2.continuous_generators() # (lie_group_dimension, rep2.dim, rep2.dim) + X3 = rep3.continuous_generators() # (lie_group_dimension, rep3.dim, rep3.dim) + assert X1.shape[0] == X2.shape[0] == X3.shape[0] + + from .clebsch_gordan import clebsch_gordan + + cg = clebsch_gordan(rep1, rep2, rep3) + assert cg.ndim == 1 + 3, (rep1, rep2, rep3, cg.shape) + assert cg.shape == (cg.shape[0], rep1.dim, rep2.dim, rep3.dim) + + # Orthogonality + # left_side = np.einsum('zijk,wijl->zkwl', cg, np.conj(cg)) + # right_side = np.eye(cg.shape[0] * rep3.dim) + # .reshape((cg.shape[0], rep3.dim, cg.shape[0], rep3.dim)) + # np.testing.assert_allclose(left_side, right_side, rtol=rtol, atol=atol) + + # if rep3 in rep1 * rep2: + # assert cg.shape[0] > 0 + # else: + # assert cg.shape[0] == 0 + + left_side = np.einsum("zijk,dlk->zdijl", cg, X3) + right_side = np.einsum("dil,zijk->zdljk", X1, cg) + np.einsum("djl,zijk->zdilk", X2, cg) + + for solution in range(cg.shape[0]): + for i in range(X1.shape[0]): + if not np.allclose( + left_side[solution][i], right_side[solution][i], rtol=rtol, atol=atol + ): + np.set_printoptions(precision=3, suppress=True) + print(rep1, rep2, rep3) + print('Left side: einsum("zijk,dlk->zdijl", cg, X3)') + print(left_side[solution][i]) + print( + "Right side: " + 'einsum("dil,zijk->zdljk", X1, cg) + einsum("djl,zijk->zdilk", X2, cg)' + ) + print(right_side[solution][i]) + diff = left_side[solution][i] - right_side[solution][i] + print("Difference:", np.abs(diff).max()) + np.set_printoptions(precision=8, suppress=False) + raise AssertionError( + f"Solution {solution}/{cg.shape[0]} for {rep1} * {rep2} = {rep3} " + "is not correct." + f"Clebsch-Gordan coefficient is not correct for Lie algebra generator {i}." + ) diff --git a/lie_nn/_src/util.py b/lie_nn/_src/utils.py similarity index 81% rename from lie_nn/_src/util.py rename to lie_nn/_src/utils.py index 93ec850..2d8de0c 100644 --- a/lie_nn/_src/util.py +++ b/lie_nn/_src/utils.py @@ -1,8 +1,9 @@ +import functools +import itertools from functools import reduce -from typing import List, Optional, Tuple, Union +from typing import FrozenSet, List, Optional, Tuple, Union import numpy as np -import sympy as sp def prod(list_of_numbers: List[Union[int, float]]) -> Union[int, float]: @@ -88,6 +89,8 @@ def _round_to_sqrt_rational(x, max_denominator): def _round_to_sqrt_rational_sympy(x, max_denominator): + import sympy as sp + sign = np.sign(x) n, d = as_approx_integer_ratio(x**2) n, d = limit_denominator(n, d, max_denominator**2 + 1) @@ -210,9 +213,7 @@ def direct_sum(A, *BCD): def gram_schmidt(A: np.ndarray, *, epsilon=1e-4, round_fn=lambda x: x) -> np.ndarray: - """ - Orthogonalize a matrix using the Gram-Schmidt process. - """ + """Orthogonalize a matrix using the Gram-Schmidt process.""" assert A.ndim == 2, "Gram-Schmidt process only works for matrices." assert A.dtype in [ np.float64, @@ -230,6 +231,38 @@ def gram_schmidt(A: np.ndarray, *, epsilon=1e-4, round_fn=lambda x: x) -> np.nda return np.stack(Q) if len(Q) > 0 else np.empty((0, A.shape[1])) +def gram_schmidt_with_change_of_basis(A: np.ndarray, *, epsilon=1e-4, round_fn=lambda x: x): + """Gram-Schmidt process returning the change of basis matrix. + + Q, U = gram_schmidt_with_change_of_basis(A) + Q = U @ A + """ + assert A.ndim == 2, "Gram-Schmidt process only works for matrices." + assert A.dtype in [ + np.float64, + np.complex128, + ], "Gram-Schmidt process only works for float64 matrices." + Q = [] + U = [] + for i in range(A.shape[0]): + v = np.copy(A[i]) + u = np.zeros_like(v, shape=(A.shape[0],)) + u[i] = 1.0 + for v_, u_ in zip(Q, U): + c = np.dot(np.conj(v_), v) + v -= c * v_ + u -= c * u_ + norm = np.linalg.norm(v) + if norm > epsilon: + v = round_fn(v / norm) + u = u / norm + Q += [v] + U += [u] + if len(Q) > 0: + return np.stack(Q), np.stack(U) + return np.empty((0, A.shape[1])), np.empty((0, A.shape[0])) + + def extend_basis(A: np.ndarray, *, epsilon=1e-4, round_fn=lambda x: x, returns="Q") -> np.ndarray: """Add rows to A to make it full rank. @@ -560,6 +593,16 @@ def eigenspaces( return [(val, vec[:, i == j]) for j, val in enumerate(unique_val)] +def are_isomorphic(X1: np.array, X2: np.array, *, epsilon: float = 1e-10) -> bool: + """Checks if representations are isomorphic.""" + if X1.shape != X2.shape: + return False + Q = infer_change_of_basis(X1, X2, epsilon=epsilon) + w = np.random.rand(len(Q)) + M = np.einsum("n,nij->ij", w, Q) + return np.linalg.matrix_rank(M) == len(M) + + def decompose_rep_into_irreps( X: np.array, *, epsilon: float = 1e-10, round_fn=lambda x: x ) -> List[np.array]: @@ -567,7 +610,8 @@ def decompose_rep_into_irreps( Input: X: np.array [num_gen, d, d] - generators of a representation. Output: - Ys: List[np.array] - list of generators of irreducible representations. + List of (multiplicity, irreducible representation) pairs. + int , np.array [num_gen, d, d] """ Q = infer_change_of_basis(X, X, epsilon=epsilon, round_fn=round_fn) # X @ Q == Q @ X w = np.random.rand(len(Q)) @@ -581,7 +625,21 @@ def decompose_rep_into_irreps( B = gram_schmidt(B.T.conj() @ B, epsilon=epsilon, round_fn=round_fn) # Make it sparse!! Ys += [B @ X @ B.T.conj()] - return Ys + + Ys = sorted(Ys, key=lambda x: x.shape[1]) + + Zs = [] + while len(Ys) > 0: + Y = Ys.pop(0) + mul = 1 + for i in range(len(Ys) - 1, -1, -1): + if are_isomorphic(Y, Ys[i], epsilon=epsilon): + Ys.pop(i) + mul += 1 + + Zs += [(mul, Y)] + + return Zs def is_irreducible(X: np.array, *, epsilon: float = 1e-10) -> bool: @@ -604,3 +662,60 @@ def regular_representation(table: np.array) -> np.array: reg_rep = np.zeros((n, n, n)) reg_rep[g, gh, h] = 1 return reg_rep + + +@functools.lru_cache(maxsize=None) +def full_base_fn(dims: Tuple[int, ...]) -> List[Tuple[int, ...]]: + return list(itertools.product(*(range(d) for d in dims))) + + +@functools.lru_cache(maxsize=None) +def permutation_base( + perm_repr: FrozenSet[Tuple[int, Tuple[int, ...]]], dims: Tuple[int, ...] +) -> FrozenSet[FrozenSet[FrozenSet[Tuple[int, Tuple[int, ...]]]]]: + full_base = full_base_fn(dims) # (0, 0, 0), (0, 0, 1), (0, 0, 2), ... (3, 3, 3) + # len(full_base) degrees of freedom in an unconstrained tensor + + # but there is constraints given by the group `formulas` + # For instance if `ij=-ji`, then 00=-00, 01=-01 and so on + base = set() + for x in full_base: + # T[x] is a coefficient of the tensor T and is related to other coefficient T[y] + # if x and y are related by a formula + xs = {(s, tuple(x[i] for i in p)) for s, p in perm_repr} + # s * T[x] are all equal for all (s, x) in xs + # if T[x] = -T[x] it is then equal to 0 and we lose this degree of freedom + if not (-1, x) in xs: + # the sign is arbitrary, put both possibilities + base.add(frozenset({frozenset(xs), frozenset({(-s, x) for s, x in xs})})) + + # len(base) is the number of degrees of freedom in the tensor. + + return frozenset(base) + + +@functools.lru_cache(maxsize=None) +def permutation_base_to_matrix( + base: FrozenSet[FrozenSet[FrozenSet[Tuple[int, Tuple[int, ...]]]]], + dims: Tuple[int, ...], +) -> np.ndarray: + base = sorted( + [sorted([sorted(xs) for xs in x]) for x in base] + ) # requested for python 3.7 but not for 3.8 (probably a bug in 3.7) + + # First we compute the change of basis (projection) between full_base and base + d_sym = len(base) + Q = np.zeros((d_sym, prod(dims)), np.float64) + + for i, x in enumerate(base): + x = max(x, key=lambda xs: sum(s for s, x in xs)) + for s, e in x: + j = 0 + for k, d in zip(e, dims): + j *= d + j += k + Q[i, j] = s / len(x) ** 0.5 + + np.testing.assert_allclose(Q @ Q.T, np.eye(d_sym)) + + return Q.reshape(d_sym, *dims) diff --git a/lie_nn/finite.py b/lie_nn/finite.py index 995bad9..6867fe0 100644 --- a/lie_nn/finite.py +++ b/lie_nn/finite.py @@ -1,3 +1,3 @@ -from ._src.finite import Sn_natural, Sn_trivial, Sn_standard +from ._src.finite.perm import Sn_natural, Sn_trivial, Sn_standard __all__ = ["Sn_natural", "Sn_trivial", "Sn_standard"] diff --git a/lie_nn/irreps/__init__.py b/lie_nn/irreps/__init__.py index 97195b0..dc70e14 100644 --- a/lie_nn/irreps/__init__.py +++ b/lie_nn/irreps/__init__.py @@ -1,3 +1,4 @@ +from .._src.irreps.u1 import U1 from .._src.irreps.su2 import SU2 from .._src.irreps.su2_real import SU2Real from .._src.irreps.o3_real import O3 @@ -9,6 +10,7 @@ __all__ = [ + "U1", "SU2", "SU2Real", "O3", diff --git a/lie_nn/test.py b/lie_nn/test.py new file mode 100644 index 0000000..1d7e880 --- /dev/null +++ b/lie_nn/test.py @@ -0,0 +1,4 @@ +from ._src.utils import check_algebra_vs_generators +from ._src.tests import check_representation_triplet + +__all__ = ["check_algebra_vs_generators", "check_representation_triplet"] diff --git a/lie_nn/util.py b/lie_nn/utils.py similarity index 87% rename from lie_nn/util.py rename to lie_nn/utils.py index ddb11ae..fbfb21f 100644 --- a/lie_nn/util.py +++ b/lie_nn/utils.py @@ -1,4 +1,4 @@ -from ._src.util import ( +from ._src.utils import ( as_approx_integer_ratio, limit_denominator, round_to_sqrt_rational, @@ -9,18 +9,19 @@ kron, direct_sum, gram_schmidt, + gram_schmidt_with_change_of_basis, extend_basis, nullspace, sequential_nullspace, infer_change_of_basis, basis_intersection, - check_algebra_vs_generators, infer_algebra_from_generators, permutation_sign, unique_with_tol, decompose_rep_into_irreps, is_irreducible, regular_representation, + are_isomorphic, ) __all__ = [ @@ -34,16 +35,17 @@ "kron", "direct_sum", "gram_schmidt", + "gram_schmidt_with_change_of_basis", "extend_basis", "nullspace", "sequential_nullspace", "infer_change_of_basis", "basis_intersection", - "check_algebra_vs_generators", "infer_algebra_from_generators", "permutation_sign", "unique_with_tol", "decompose_rep_into_irreps", "is_irreducible", "regular_representation", + "are_isomorphic", ] diff --git a/setup.cfg b/setup.cfg index bb0c2cd..5aa1f90 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,4 +26,4 @@ python_requires = >=3.7 install_requires = numpy scipy - multipledispatch + multimethod diff --git a/lie_nn/_src/change_algebra_test.py b/tests/change_algebra_test.py similarity index 66% rename from lie_nn/_src/change_algebra_test.py rename to tests/change_algebra_test.py index f639e25..fe2964e 100644 --- a/lie_nn/_src/change_algebra_test.py +++ b/tests/change_algebra_test.py @@ -9,5 +9,5 @@ def test_change_algebra(): Q = np.array([[1.0, 1.0, 0.0], [0.0, 1.0, 1.0], [1.0, 0.0, 1.0]]) / np.sqrt(2.0) rep = lie.change_algebra(rep, Q) - rep.check_algebra_vs_generators() - lie.check_representation_triplet(rep, rep, rep) + lie.test.check_algebra_vs_generators(rep.A, rep.X) + lie.test.check_representation_triplet(rep, rep, rep) diff --git a/tests/clebsch_gordan_test.py b/tests/clebsch_gordan_test.py new file mode 100644 index 0000000..1cd78c9 --- /dev/null +++ b/tests/clebsch_gordan_test.py @@ -0,0 +1,29 @@ +import numpy as np + +import lie_nn as lie + + +def test_cg_irrep(): + Q = np.array([[0, 1, 1], [2, 2, 1], [-3, 1, 0]]) + rep1 = lie.change_basis(Q, lie.irreps.SU2(2)) + + Q = np.array([[1, 2, 0], [2, 1, 0], [0, -1, 1.0]]) + rep2 = lie.change_basis(Q, lie.irreps.SU2(2)) + + Q = np.array([[1, 0, 0], [1, 1, -1], [0, 0, 1.0]]) + rep3 = lie.change_basis(Q, lie.irreps.SU2(2)) + + lie.test.check_representation_triplet(rep1, rep2, rep3) + + +def test_cg_generic(): + Q = np.array([[0, 1, 1], [2, 2, 1], [-3, 1, 0]]) + rep1 = lie.change_basis(Q, lie.irreps.SO3(1)) + + Q = np.random.randn(5, 5) + rep2 = lie.change_basis(Q, lie.irreps.SO3(2)) + + Q = np.array([[1, 0, 0], [1, 1, -1], [0, 0, 1.0]]) + rep3 = lie.change_basis(Q, lie.irreps.SO3(1)) + + lie.test.check_representation_triplet(rep1, rep2, rep3) diff --git a/lie_nn/_src/group_product_test.py b/tests/group_product_test.py similarity index 57% rename from lie_nn/_src/group_product_test.py rename to tests/group_product_test.py index 9ac5c33..6dc4d6c 100644 --- a/lie_nn/_src/group_product_test.py +++ b/tests/group_product_test.py @@ -6,12 +6,8 @@ def test_group_product_SO3_Z2_Z2(): r2eo = lie.group_product(lie.irreps.SO3(l=2), lie.irreps.Z2(p=1), lie.irreps.Z2(p=-1)) r2ee = lie.group_product(lie.irreps.SO3(l=2), lie.irreps.Z2(p=1), lie.irreps.Z2(p=1)) - r1oo.check_algebra_vs_generators() - r2eo.check_algebra_vs_generators() - r2ee.check_algebra_vs_generators() - - lie.check_representation_triplet(r1oo, r1oo, r2ee) - lie.check_representation_triplet(r1oo, r2eo, r2ee) + lie.test.check_representation_triplet(r1oo, r1oo, r2ee) + lie.test.check_representation_triplet(r1oo, r2eo, r2ee) def test_group_product_SU2_SU2(): @@ -24,9 +20,6 @@ def test_group_product_SU2_SU2(): j13 = lie.group_product(j1, j3) j22 = lie.group_product(j2, j2) - j11.check_algebra_vs_generators() - j12.check_algebra_vs_generators() - - lie.check_representation_triplet(j11, j11, j22) - lie.check_representation_triplet(j11, j12, j22) - lie.check_representation_triplet(j12, j13, j22) + lie.test.check_representation_triplet(j11, j11, j22) + lie.test.check_representation_triplet(j11, j12, j22) + lie.test.check_representation_triplet(j12, j13, j22) diff --git a/lie_nn/_src/infer_change_of_basis_test.py b/tests/infer_change_of_basis_test.py similarity index 52% rename from lie_nn/_src/infer_change_of_basis_test.py rename to tests/infer_change_of_basis_test.py index d066125..90cb0e2 100644 --- a/lie_nn/_src/infer_change_of_basis_test.py +++ b/tests/infer_change_of_basis_test.py @@ -3,23 +3,23 @@ def test_infer_change_of_basis(): - rep1 = lie.change_basis(lie.irreps.SU2(2), np.random.randn(3, 3)) - rep2 = lie.change_basis(lie.irreps.SU2(2), np.random.randn(3, 3)) + rep1 = lie.change_basis(np.random.randn(3, 3), lie.irreps.SU2(2)) + rep2 = lie.change_basis(np.random.randn(3, 3), lie.irreps.SU2(2)) Q = lie.infer_change_of_basis(rep1, rep2)[0] - rep3 = lie.change_basis(rep1, Q) + rep3 = lie.change_basis(Q, rep1) np.testing.assert_allclose(rep2.X, rep3.X) def test_infer_change_of_basis_generic(): - rep1 = lie.change_basis(lie.irreps.SU2(2), np.random.randn(3, 3)) - rep2 = lie.change_basis(lie.irreps.SU2(2), np.random.randn(3, 3)) + rep1 = lie.change_basis(np.random.randn(3, 3), lie.irreps.SU2(2)) + rep2 = lie.change_basis(np.random.randn(3, 3), lie.irreps.SU2(2)) rep1 = lie.GenericRep.from_rep(rep1) rep2 = lie.GenericRep.from_rep(rep2) Q = lie.infer_change_of_basis(rep1, rep2)[0] - rep3 = lie.change_basis(rep1, Q) + rep3 = lie.change_basis(Q, rep1) np.testing.assert_allclose(rep2.X, rep3.X) diff --git a/lie_nn/_src/irrep_test.py b/tests/irrep_test.py similarity index 63% rename from lie_nn/_src/irrep_test.py rename to tests/irrep_test.py index b55b647..0b674f1 100644 --- a/lie_nn/_src/irrep_test.py +++ b/tests/irrep_test.py @@ -2,10 +2,9 @@ import numpy as np import pytest -from lie_nn import TabulatedIrrep, clebsch_gordan, check_representation_triplet, GenericRep -from lie_nn.irreps import O3, SL2C, SO3, SO13, SU2Real, SU2, SU2_, SU3, SU4 -from lie_nn.util import round_to_sqrt_rational +import lie_nn as lie +from lie_nn.irreps import O3, SL2C, SO3, SO13, SU2, SU2_, SU3, SU4, SU2Real REPRESENTATIONS = [O3, SU2, SO3, SU2Real, SL2C, SO13, SU2_, SU3, SU4] @@ -25,32 +24,36 @@ def bunch_of_triplets(): @pytest.mark.parametrize("ir", bunch_of_reps()) -def test_algebra_vs_generators(ir: TabulatedIrrep): - ir.check_algebra_vs_generators() +def test_algebra_vs_generators(ir: lie.TabulatedIrrep): + lie.test.check_algebra_vs_generators(ir.A, ir.X) @pytest.mark.parametrize("ir1, ir2, ir3", bunch_of_triplets()) -def test_numerical_cg_vs_generators(ir1: TabulatedIrrep, ir2: TabulatedIrrep, ir3: TabulatedIrrep): - check_representation_triplet(GenericRep.from_rep(ir1), ir2, ir3) +def test_numerical_cg_vs_generators( + ir1: lie.TabulatedIrrep, ir2: lie.TabulatedIrrep, ir3: lie.TabulatedIrrep +): + lie.test.check_representation_triplet(lie.GenericRep.from_rep(ir1), ir2, ir3) @pytest.mark.parametrize("ir1, ir2, ir3", bunch_of_triplets()) def test_irreps_clebsch_gordan_vs_generators( - ir1: TabulatedIrrep, ir2: TabulatedIrrep, ir3: TabulatedIrrep + ir1: lie.TabulatedIrrep, ir2: lie.TabulatedIrrep, ir3: lie.TabulatedIrrep ): - check_representation_triplet(ir1, ir2, ir3) + lie.test.check_representation_triplet(ir1, ir2, ir3) @pytest.mark.parametrize("ir1, ir2, ir3", bunch_of_triplets()) -def test_recompute_clebsch_gordan(ir1: TabulatedIrrep, ir2: TabulatedIrrep, ir3: TabulatedIrrep): +def test_recompute_clebsch_gordan( + ir1: lie.TabulatedIrrep, ir2: lie.TabulatedIrrep, ir3: lie.TabulatedIrrep +): tol = 1e-14 - C1 = clebsch_gordan(ir1, ir2, ir3, round_fn=round_to_sqrt_rational) + C1 = lie.clebsch_gordan(ir1, ir2, ir3, round_fn=lie.utils.round_to_sqrt_rational) C2 = ir1.clebsch_gordan(ir1, ir2, ir3) assert np.allclose(C1, C2, atol=tol, rtol=tol) or np.allclose(C1, -C2, atol=tol, rtol=tol) @pytest.mark.parametrize("ir1, ir2, ir3", bunch_of_triplets()) -def test_selection_rule(ir1: TabulatedIrrep, ir2: TabulatedIrrep, ir3: TabulatedIrrep): +def test_selection_rule(ir1: lie.TabulatedIrrep, ir2: lie.TabulatedIrrep, ir3: lie.TabulatedIrrep): cg = ir1.clebsch_gordan(ir1, ir2, ir3) if ir3 in ir1 * ir2: diff --git a/tests/reduce_test.py b/tests/reduce_test.py new file mode 100644 index 0000000..0878ef0 --- /dev/null +++ b/tests/reduce_test.py @@ -0,0 +1,43 @@ +import numpy as np +import lie_nn as lie +import itertools + + +def test_reduce1(): + Q = np.array([[2.0, 0.0, 1.0], [1.0, -1.0, 0.0], [0.1, 0.1, 0.1]]) + rep1 = lie.change_basis(Q, lie.irreps.SU2(2)) + rep2 = lie.tensor_product(rep1, rep1) + rep2 = lie.GenericRep.from_rep(rep2) + + rep3 = lie.reduce(rep2) + + np.testing.assert_allclose(rep2.X, rep3.X, atol=1e-5) + + +def test_reduce2(): + rep = lie.tensor_product(lie.irreps.SU2(2), lie.irreps.SU2(2)) + rep = lie.direct_sum(rep, rep) + + Q = np.random.randn(rep.dim, rep.dim) + u, s, vh = np.linalg.svd(Q) + Q = u @ vh + + rep1 = lie.change_basis(Q, rep) + + rep1 = lie.reduce(rep1) + rep2 = lie.reduce(lie.GenericRep.from_rep(rep1)) + + np.testing.assert_allclose(rep1.X, rep2.X, atol=1e-5) + + +def test_reduce3(): + rep1 = lie.tensor_power(lie.irreps.SU2(1), 4) + + rep2 = lie.reduce(rep1) + + np.testing.assert_allclose(rep1.X, rep2.X, atol=1e-5) + + irs = [ir for mul, ir in rep2.reps] + + for ir1, ir2 in itertools.combinations(irs, 2): + assert not lie.are_isomorphic(ir1, ir2) diff --git a/tests/so13_test.py b/tests/so13_test.py new file mode 100644 index 0000000..5f07b81 --- /dev/null +++ b/tests/so13_test.py @@ -0,0 +1,7 @@ +import lie_nn as lie + + +def test_fourvector(): + vec = lie.irreps.SO13.four_vector() + + lie.test.check_representation_triplet(vec, vec, vec) diff --git a/lie_nn/_src/irreps/sun_test.py b/tests/sun_test.py similarity index 98% rename from lie_nn/_src/irreps/sun_test.py rename to tests/sun_test.py index 8f1456a..e747713 100644 --- a/lie_nn/_src/irreps/sun_test.py +++ b/tests/sun_test.py @@ -10,7 +10,7 @@ lower_ladder_matrices, upper_ladder_matrices, ) -from lie_nn.util import commutator +from lie_nn.utils import commutator j_max = 4 diff --git a/tests/tensor_product_test.py b/tests/tensor_product_test.py new file mode 100644 index 0000000..edf6e92 --- /dev/null +++ b/tests/tensor_product_test.py @@ -0,0 +1,41 @@ +import itertools + +import numpy as np +import pytest +from lie_nn.irreps import O3, SL2C, SO3, SO13, SU2 +import lie_nn as lie + + +def first_reps(IR: lie.TabulatedIrrep, n: int): + return list(itertools.islice(IR.iterator(), n)) + + +REPRESENTATIONS = [O3, SU2, SO3, SL2C, SO13] +# TODO: add SU2Real (or remove it completely) +# Note: SU2Real are not Irreps and this might be a problem +# TODO: resolve tensor_product_consistency for SU3, SU4 + + +@pytest.mark.parametrize( + "ir1, ir2", + sum((list(itertools.product(first_reps(IR, 3), repeat=2)) for IR in REPRESENTATIONS), []), +) +def test_tensor_product_consistency(ir1, ir2): + rep1 = lie.direct_sum(lie.multiply(2, ir1), ir2) + rep2 = lie.direct_sum(lie.multiply(3, ir1)) + + tp1 = lie.tensor_product(rep1, rep2) + tp2 = lie.tensor_product(lie.GenericRep.from_rep(rep1), lie.GenericRep.from_rep(rep2)) + + np.testing.assert_allclose(tp1.X, tp2.X, atol=1e-10) + + +def test_tensor_product_types(): + assert isinstance(lie.tensor_product(O3(l=1, p=1), O3(l=1, p=1)), lie.QRep) + assert isinstance(lie.tensor_product(O3(l=1, p=1), lie.multiply(2, O3(l=1, p=1))), lie.QRep) + + +def test_tensor_power_types(): + assert isinstance(lie.tensor_power(O3(l=1, p=1), 2), lie.QRep) + assert isinstance(lie.tensor_power(lie.multiply(2, O3(l=1, p=1)), 2), lie.QRep) + assert isinstance(lie.tensor_power(lie.GenericRep.from_rep(O3(l=1, p=1)), 2), lie.GenericRep) diff --git a/lie_nn/_src/util_test.py b/tests/utils_test.py similarity index 97% rename from lie_nn/_src/util_test.py rename to tests/utils_test.py index bbf20d9..283f9b9 100644 --- a/lie_nn/_src/util_test.py +++ b/tests/utils_test.py @@ -1,11 +1,11 @@ import numpy as np -from lie_nn.util import ( +from lie_nn.utils import ( nullspace, infer_change_of_basis, round_to_sqrt_rational, ) -from lie_nn._src.util import ( +from lie_nn._src.utils import ( as_approx_integer_ratio, limit_denominator, normalize_integer_ratio,