Start fixing inheritance in Product/NFold/Landmarks/Curves Spaces by ninamiolane · Pull Request #1581 · geomstats/geomstats · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Start fixing inheritance in Product/NFold/Landmarks/Curves Spaces #1581

8000
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 4 additions & 127 deletions geomstats/geometry/product_manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import geomstats.backend as gs
import geomstats.errors
from geomstats.geometry.manifold import Manifold
from geomstats.geometry.product_riemannian_metric import ProductRiemannianMetric
from geomstats.geometry.riemannian_metric import RiemannianMetric
from geomstats.geometry.product_riemannian_metric import (
NFoldMetric,
ProductRiemannianMetric,
)


class ProductManifold(Manifold):
Expand Down Expand Up @@ -487,128 +489,3 @@ def projection(self, point):
raise NotImplementedError(
"The base manifold does not implement a projection " "method."
)


class NFoldMetric(RiemannianMetric):
    r"""Riemannian metric on an n-fold product manifold :math:`M^n`.

    The metric is assembled from ``n_copies`` replicas of one base metric.
    Every copy of a point is handed to ``base_metric`` in a single vectorized
    call, and the per-copy results are regrouped afterwards.

    NOTE(review): results are passed through ``gs.squeeze``, which drops
    every size-1 axis. A batch of one, ``n_copies == 1``, or a base dimension
    of 1 therefore collapse. Batch axes are also flattened into one leading
    axis by the ``(-1, n_copies, ...)`` reshapes.

    Parameters
    ----------
    base_metric : RiemannianMetric
        Metric applied to each copy of the base manifold.
    n_copies : int
        Number of copies of the base metric.
    """

    def __init__(self, base_metric, n_copies):
        geomstats.errors.check_integer(n_copies, "n_copies")
        total_dim = n_copies * base_metric.dim
        copy_shape = base_metric.shape
        super().__init__(dim=total_dim, shape=(n_copies, *copy_shape))
        self.base_shape = copy_shape
        self.base_metric = base_metric
        self.n_copies = n_copies

    def _flatten_copies(self, array):
        """Merge batch and copy axes of ``array`` into one leading axis.

        An input of shape ``[..., n_copies, *base_shape]`` becomes
        ``[-1, *base_shape]``, so the base metric sees a flat stack of points.
        """
        return gs.reshape(array, (-1, *self.base_shape))

    def _unflatten_copies(self, array):
        """Regroup a flat stack of base-manifold arrays by copy, then squeeze.

        The result has shape ``[-1, n_copies, *base_shape]`` before
        ``gs.squeeze`` removes all size-1 axes.
        """
        grouped = gs.reshape(array, (-1, self.n_copies) + self.base_shape)
        return gs.squeeze(grouped)

    def metric_matrix(self, base_point=None):
        """Compute the inner-product matrix of every copy at ``base_point``.

        Parameters
        ----------
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Point on the product manifold.
            NOTE(review): the signature defaults to None, but the value is
            reshaped directly, so passing None does not look supported.
            Confirm before relying on the default.

        Returns
        -------
        matrix : array-like, shape=[..., n_copies, base_dim, base_dim]
            One inner-product matrix per copy, where ``base_dim`` is
            ``base_metric.dim``. Size-1 axes are removed by ``gs.squeeze``.
        """
        copies = self._flatten_copies(base_point)
        per_copy = self.base_metric.metric_matrix(copies)
        base_dim = self.base_metric.dim
        grouped = gs.reshape(per_copy, (-1, self.n_copies, base_dim, base_dim))
        return gs.squeeze(grouped)

    def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
        """Compute the inner product of two tangent vectors at a base point.

        The three inputs are broadcast together. The base metric evaluates one
        inner product per copy, and the per-copy values are summed.

        Parameters
        ----------
        tangent_vec_a : array-like, shape=[..., n_copies, *base_shape]
            First tangent vector at ``base_point``.
        tangent_vec_b : array-like, shape=[..., n_copies, *base_shape]
            Second tangent vector at ``base_point``.
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Point on the product manifold (required).

        Returns
        -------
        inner_prod : array-like, shape=[...,]
            Sum over copies of the base inner products, squeezed.
        """
        broadcast = gs.broadcast_arrays(tangent_vec_a, tangent_vec_b, base_point)
        flat_a, flat_b, flat_point = (self._flatten_copies(arr) for arr in broadcast)
        per_copy = self.base_metric.inner_product(flat_a, flat_b, flat_point)
        # The product metric is the sum of the per-copy inner products.
        per_point = gs.sum(gs.reshape(per_copy, (-1, self.n_copies)), axis=-1)
        return gs.squeeze(per_point)

    def exp(self, tangent_vec, base_point, **kwargs):
        """Compute the Riemannian exponential copy by copy.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., n_copies, *base_shape]
            Tangent vector at ``base_point``.
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Point on the product manifold (required).
        **kwargs
            Accepted for interface compatibility; not forwarded to
            ``base_metric.exp``.

        Returns
        -------
        exp : array-like, shape=[..., n_copies, *base_shape]
            Point reached by the base exponential applied to each copy.
        """
        vec, point = gs.broadcast_arrays(tangent_vec, base_point)
        per_copy = self.base_metric.exp(
            self._flatten_copies(vec), self._flatten_copies(point)
        )
        return self._unflatten_copies(per_copy)

    def log(self, point, base_point, **kwargs):
        """Compute the Riemannian logarithm copy by copy.

        Parameters
        ----------
        point : array-like, shape=[..., n_copies, *base_shape]
            Point on the product manifold.
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Base point on the product manifold (required).
        **kwargs
            Accepted for interface compatibility; not forwarded to
            ``base_metric.log``.

        Returns
        -------
        log : array-like, shape=[..., n_copies, *base_shape]
            Tangent vector at ``base_point`` obtained from the base
            logarithm of each copy.
        """
        target, origin = gs.broadcast_arrays(point, base_point)
        per_copy = self.base_metric.log(
            self._flatten_copies(target), self._flatten_copies(origin)
        )
        return self._unflatten_copies(per_copy)
125 changes: 125 additions & 0 deletions geomstats/geometry/product_riemannian_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,128 @@ def log(self, point, base_point=None, point_type=None, **kwargs):
axis=-2,
)
return logs


class NFoldMetric(RiemannianMetric):
    r"""Riemannian metric on an n-fold product manifold :math:`M^n`.

    The metric is assembled from ``n_copies`` replicas of one base metric.
    Every copy of a point is handed to ``base_metric`` in a single vectorized
    call, and the per-copy results are regrouped afterwards.

    NOTE(review): results are passed through ``gs.squeeze``, which drops
    every size-1 axis. A batch of one, ``n_copies == 1``, or a base dimension
    of 1 therefore collapse. Batch axes are also flattened into one leading
    axis by the ``(-1, n_copies, ...)`` reshapes.

    Parameters
    ----------
    base_metric : RiemannianMetric
        Metric applied to each copy of the base manifold.
    n_copies : int
        Number of copies of the base metric.
    """

    def __init__(self, base_metric, n_copies):
        geomstats.errors.check_integer(n_copies, "n_copies")
        total_dim = n_copies * base_metric.dim
        copy_shape = base_metric.shape
        super().__init__(dim=total_dim, shape=(n_copies, *copy_shape))
        self.base_shape = copy_shape
        self.base_metric = base_metric
        self.n_copies = n_copies

    def _flatten_copies(self, array):
        """Merge batch and copy axes of ``array`` into one leading axis.

        An input of shape ``[..., n_copies, *base_shape]`` becomes
        ``[-1, *base_shape]``, so the base metric sees a flat stack of points.
        """
        return gs.reshape(array, (-1, *self.base_shape))

    def _unflatten_copies(self, array):
        """Regroup a flat stack of base-manifold arrays by copy, then squeeze.

        The result has shape ``[-1, n_copies, *base_shape]`` before
        ``gs.squeeze`` removes all size-1 axes.
        """
        grouped = gs.reshape(array, (-1, self.n_copies) + self.base_shape)
        return gs.squeeze(grouped)

    def metric_matrix(self, base_point=None):
        """Compute the inner-product matrix of every copy at ``base_point``.

        Parameters
        ----------
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Point on the product manifold.
            NOTE(review): the signature defaults to None, but the value is
            reshaped directly, so passing None does not look supported.
            Confirm before relying on the default.

        Returns
        -------
        matrix : array-like, shape=[..., n_copies, base_dim, base_dim]
            One inner-product matrix per copy, where ``base_dim`` is
            ``base_metric.dim``. Size-1 axes are removed by ``gs.squeeze``.
        """
        copies = self._flatten_copies(base_point)
        per_copy = self.base_metric.metric_matrix(copies)
        base_dim = self.base_metric.dim
        grouped = gs.reshape(per_copy, (-1, self.n_copies, base_dim, base_dim))
        return gs.squeeze(grouped)

    def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
        """Compute the inner product of two tangent vectors at a base point.

        The three inputs are broadcast together. The base metric evaluates one
        inner product per copy, and the per-copy values are summed.

        Parameters
        ----------
        tangent_vec_a : array-like, shape=[..., n_copies, *base_shape]
            First tangent vector at ``base_point``.
        tangent_vec_b : array-like, shape=[..., n_copies, *base_shape]
            Second tangent vector at ``base_point``.
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Point on the product manifold (required).

        Returns
        -------
        inner_prod : array-like, shape=[...,]
            Sum over copies of the base inner products, squeezed.
        """
        broadcast = gs.broadcast_arrays(tangent_vec_a, tangent_vec_b, base_point)
        flat_a, flat_b, flat_point = (self._flatten_copies(arr) for arr in broadcast)
        per_copy = self.base_metric.inner_product(flat_a, flat_b, flat_point)
        # The product metric is the sum of the per-copy inner products.
        per_point = gs.sum(gs.reshape(per_copy, (-1, self.n_copies)), axis=-1)
        return gs.squeeze(per_point)

    def exp(self, tangent_vec, base_point, **kwargs):
        """Compute the Riemannian exponential copy by copy.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., n_copies, *base_shape]
            Tangent vector at ``base_point``.
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Point on the product manifold (required).
        **kwargs
            Accepted for interface compatibility; not forwarded to
            ``base_metric.exp``.

        Returns
        -------
        exp : array-like, shape=[..., n_copies, *base_shape]
            Point reached by the base exponential applied to each copy.
        """
        vec, point = gs.broadcast_arrays(tangent_vec, base_point)
        per_copy = self.base_metric.exp(
            self._flatten_copies(vec), self._flatten_copies(point)
        )
        return self._unflatten_copies(per_copy)

    def log(self, point, base_point, **kwargs):
        """Compute the Riemannian logarithm copy by copy.

        Parameters
        ----------
        point : array-like, shape=[..., n_copies, *base_shape]
            Point on the product manifold.
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Base point on the product manifold (required).
        **kwargs
            Accepted for interface compatibility; not forwarded to
            ``base_metric.log``.

        Returns
        -------
        log : array-like, shape=[..., n_copies, *base_shape]
            Tangent vector at ``base_point`` obtained from the base
            logarithm of each copy.
        """
        target, origin = gs.broadcast_arrays(point, base_point)
        per_copy = self.base_metric.log(
            self._flatten_copies(target), self._flatten_copies(origin)
        )
        return self._unflatten_copies(per_copy)
7 changes: 3 additions & 4 deletions tests/tests_geomstats/test_product_manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
import geomstats.tests
from geomstats.geometry.euclidean import Euclidean
from geomstats.geometry.minkowski import Minkowski
from geomstats.geometry.product_manifold import (
NFoldManifold,
from geomstats.geometry.product_manifold import NFoldManifold, ProductManifold
from geomstats.geometry.product_riemannian_metric import (
NFoldMetric,
ProductManifold,
ProductRiemannianMetric,
)
from geomstats.geometry.product_riemannian_metric import ProductRiemannianMetric
from tests.conftest import Parametrizer
from tests.data.product_manifold_data import (
NFoldManifoldTestData,
Expand Down
0