Refactor by mariogeiger · Pull Request #3 · lie-nn/lie-nn · GitHub
Skip to content

Refactor #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -18,25 +18,17 @@ jobs:
python-version: [3.8, '3.11']

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install flake8
run: |
pip install flake8
- name: Lint with flake8
run: |
flake8 . --count --show-source --statistics
- name: Install dependencies
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
python -m pip install --upgrade pip
pip install wheel
pip install jax jaxlib
pip install torch
pip install .
- name: Install pytest
run: |
Expand Down
10 changes: 10 additions & 0 deletions IDEAS
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

add conjugate class

add tensor product class?



to implement:
- symmetric_tensor_power
- anti_symmetric_tensor_power
66 changes: 41 additions & 25 deletions lie_nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,69 @@
__version__ = "0.0.0"

from ._src.rep import Rep, GenericRep, check_representation_triplet
from ._src.irrep import TabulatedIrrep
from ._src.reduced_rep import MulIrrep, ReducedRep
from ._src.rep import (
Rep,
GenericRep,
Irrep,
TabulatedIrrep,
MulRep,
SumRep,
QRep,
ConjRep,
ReducedRep,
)

from lie_nn import utils as utils

from ._src.change_basis import change_basis
from ._src.reduce import reduce
from ._src.change_algebra import change_algebra
from ._src.tensor_product import tensor_product, tensor_power
from ._src.multiply import multiply
from ._src.direct_sum import direct_sum
from ._src.reduced_tensor_product import (
reduced_tensor_product_basis, # TODO: find a better API
reduced_symmetric_tensor_product_basis,
)
from ._src.clebsch_gordan import clebsch_gordan
from ._src.infer_change_of_basis import infer_change_of_basis
from ._src.conjugate import conjugate

from ._src.infer_change_of_basis import infer_change_of_basis
from ._src.properties import is_unitary, is_irreducible, are_isomorphic

from ._src.reduce import reduce
from ._src.tensor_product import tensor_product, tensor_power
from ._src.clebsch_gordan import clebsch_gordan
from ._src.real import make_explicitly_real, is_real
from ._src.properties import is_unitary
from ._src.group_product import group_product
from ._src.is_irreducible import is_irreducible
from ._src.symmetric_tensor_power import symmetric_tensor_power

from lie_nn import irreps as irreps
from lie_nn import util as util
from lie_nn import finite as finite
from lie_nn import test as test


__all__ = [
"Rep",
"GenericRep",
"check_representation_triplet",
"Irrep",
"TabulatedIrrep",
"MulIrrep",
"MulRep",
"SumRep",
"QRep",
"ConjRep",
"ReducedRep",
"utils",
"change_basis",
"reduce",
"change_algebra",
"multiply",
"direct_sum",
"conjugate",
"infer_change_of_basis",
"is_unitary",
"is_irreducible",
"are_isomorphic",
"reduce",
"tensor_product",
"tensor_power",
"direct_sum",
"reduced_tensor_product_basis",
"reduced_symmetric_tensor_product_basis",
"clebsch_gordan",
"infer_change_of_basis",
"conjugate",
"make_explicitly_real",
"is_real",
"is_unitary",
"group_product",
"is_irreducible",
"symmetric_tensor_power",
"irreps",
"util",
"finite",
"test",
]
23 changes: 20 additions & 3 deletions lie_nn/_src/change_algebra.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import numpy as np
from multimethod import multimethod

from .rep import GenericRep, Rep
import lie_nn as lie


def change_algebra(rep: Rep, Q: np.ndarray) -> GenericRep:
@multimethod
def change_algebra(rep: lie.Rep, Q: np.ndarray) -> lie.GenericRep:
"""Apply change of basis to algebra.

.. math::
@@ -16,8 +18,23 @@ def change_algebra(rep: Rep, Q: np.ndarray) -> GenericRep:

"""
iQ = np.linalg.pinv(Q)
return GenericRep(
return lie.GenericRep(
A=np.einsum("ia,jb,abc,ck->ijk", Q, Q, rep.A, iQ),
X=np.einsum("ia,auv->iuv", Q, rep.X),
H=rep.H,
)


@multimethod
def change_algebra(rep: lie.QRep, Q: np.ndarray) -> lie.Rep: # noqa: F811
return lie.change_basis(rep.Q, change_algebra(rep.rep, Q))


@multimethod
def change_algebra(rep: lie.SumRep, Q: np.ndarray) -> lie.Rep: # noqa: F811
return lie.direct_sum(*[change_algebra(subrep, Q) for subrep in rep.reps])


@multimethod
def change_algebra(rep: lie.MulRep, Q: np.ndarray) -> lie.Rep: # noqa: F811
return lie.multiply(rep.mul, change_algebra(rep.rep, Q))
40 changes: 13 additions & 27 deletions lie_nn/_src/change_basis.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import numpy as np
from multipledispatch import dispatch
from multimethod import multimethod

from .irrep import TabulatedIrrep
from .reduced_rep import MulIrrep, ReducedRep
from .rep import GenericRep, Rep
import lie_nn as lie


@dispatch(Rep, object)
def change_basis(rep: Rep, Q: np.ndarray) -> GenericRep:
@multimethod
def change_basis(Q: np.ndarray, rep: lie.Rep) -> lie.QRep:
"""Apply change of basis to generators.

.. math::
@@ -19,29 +17,17 @@ def change_basis(rep: Rep, Q: np.ndarray) -> GenericRep:
v' = Q v

"""
iQ = np.linalg.pinv(Q)
return GenericRep(
A=rep.algebra(),
X=Q @ rep.continuous_generators() @ iQ,
H=Q @ rep.discrete_generators() @ iQ,
)
assert Q.shape == (rep.dim, rep.dim), (Q.shape, rep.dim)

if np.allclose(Q.imag, 0.0, atol=1e-10):
Q = Q.real

@dispatch(ReducedRep, object)
def change_basis(rep: ReducedRep, Q: np.ndarray) -> ReducedRep: # noqa: F811
Q = Q if rep.Q is None else Q @ rep.Q
return ReducedRep(A=rep.A, irreps=rep.irreps, Q=Q)
if np.allclose(Q, np.eye(rep.dim), atol=1e-10):
return rep

return lie.QRep(Q, rep, force=True)

@dispatch(MulIrrep, object)
def change_basis(rep: MulIrrep, Q: np.ndarray) -> ReducedRep: # noqa: F811
return ReducedRep(
A=rep.algebra(),
irreps=(rep,),
Q=Q,
)


@dispatch(TabulatedIrrep, object)
def change_basis(rep: TabulatedIrrep, Q: np.ndarray) -> ReducedRep: # noqa: F811
return change_basis(MulIrrep(mul=1, rep=rep), Q)
@multimethod
def change_basis(Q: np.ndarray, rep: lie.QRep) -> lie.Rep: # noqa: F811
return change_basis(Q @ rep.Q, rep.rep)
26 changes: 15 additions & 11 deletions lie_nn/_src/clebsch_gordan.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import numpy as np
from multipledispatch import dispatch
from multimethod import multimethod

from .infer_change_of_basis import infer_change_of_basis
from .irrep import TabulatedIrrep
from .rep import Rep
from .tensor_product import tensor_product

import lie_nn as lie

@dispatch(Rep, Rep, Rep)
def clebsch_gordan(rep1: Rep, rep2: Rep, rep3: Rep, *, round_fn=lambda x: x) -> np.ndarray:

@multimethod
def clebsch_gordan(
rep1: lie.Rep, rep2: lie.Rep, rep3: lie.Rep, *, round_fn=lambda x: x
) -> np.ndarray:
r"""Computes the Clebsch-Gordan coefficient of the triplet (rep1, rep2, rep3).

Args:
@@ -20,15 +20,19 @@ def clebsch_gordan(rep1: Rep, rep2: Rep, rep3: Rep, *, round_fn=lambda x: x) ->
The Clebsch-Gordan coefficient of the triplet (rep1, rep2, rep3).
It is an array of shape ``(number_of_paths, rep1.dim, rep2.dim, rep3.dim)``.
"""
tp = tensor_product(rep1, rep2)
cg = infer_change_of_basis(tp, rep3, round_fn=round_fn)
tp = lie.tensor_product(rep1, rep2)
cg = lie.infer_change_of_basis(tp, rep3, round_fn=round_fn)
cg = cg.reshape((-1, rep3.dim, rep1.dim, rep2.dim))
cg = np.moveaxis(cg, 1, 3)
return cg


@dispatch(TabulatedIrrep, TabulatedIrrep, TabulatedIrrep)
@multimethod
def clebsch_gordan( # noqa: F811
rep1: TabulatedIrrep, rep2: TabulatedIrrep, rep3: TabulatedIrrep, *, round_fn=lambda x: x
rep1: lie.TabulatedIrrep,
rep2: lie.TabulatedIrrep,
rep3: lie.TabulatedIrrep,
*,
round_fn=lambda x: x,
) -> np.ndarray:
return rep1.clebsch_gordan(rep1, rep2, rep3)
29 changes: 0 additions & 29 deletions lie_nn/_src/clebsch_gordan_test.py

This file was deleted.

29 changes: 19 additions & 10 deletions lie_nn/_src/conjugate.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
import numpy as np
from multipledispatch import dispatch
from multimethod import multimethod

from .rep import GenericRep, Rep
import lie_nn as lie

# TODO(mario): Implement conjugate for Irreps

@multimethod
def conjugate(rep: lie.Rep) -> lie.GenericRep:
return lie.ConjRep(rep, force=True)

@dispatch(Rep)
def conjugate(rep: Rep) -> GenericRep:
return GenericRep(
A=rep.A,
X=np.conjugate(rep.X),
H=np.conjugate(rep.H),
)

@multimethod
def conjugate(rep: lie.QRep) -> lie.Rep: # noqa: F811
return lie.change_basis(np.conjugate(rep.Q), conjugate(rep.rep))


@multimethod
def conjugate(rep: lie.SumRep) -> lie.Rep: # noqa: F811
return lie.direct_sum(*[conjugate(subrep) for subrep in rep.reps])


@multimethod
def conjugate(rep: lie.MulRep) -> lie.Rep: # noqa: F811
return lie.multiply(rep.mul, conjugate(rep.rep))
Loading