Make default_point_type a property by luisfpereira · Pull Request #1644 · geomstats/geomstats · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Make default_point_type a property #1644

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/learning_graph_embedding_and_predicting.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def main():
group_1 = mpatches.Patch(color=colors[1], label="Group 1")
group_2 = mpatches.Patch(color=colors[2], label="Group 2")

circle = visualization.PoincareDisk(point_type="ball")
circle = visualization.PoincareDisk(coords_type="ball")

_, ax = plt.subplots(figsize=(8, 8))
ax.axes.xaxis.set_visible(False)
Expand Down Expand Up @@ -59,7 +59,7 @@ def main():
labels = kmeans.predict(X=embeddings)

colors = ["g", "c", "m"]
circle = visualization.PoincareDisk(point_type="ball")
circle = visualization.PoincareDisk(coords_type="ball")
_, ax2 = plt.subplots(figsize=(8, 8))
circle.set_ax(ax2)
circle.draw(ax=ax2)
Expand Down Expand Up @@ -107,7 +107,7 @@ def main():
labels = kmedoid.predict(data=embeddings)

colors = ["g", "c", "m"]
circle = visualization.PoincareDisk(point_type="ball")
circle = visualization.PoincareDisk(coords_type="ball")
_, ax2 = plt.subplots(figsize=(8, 8))
circle.set_ax(ax2)
circle.draw(ax=ax2)
Expand Down
2 changes: 1 addition & 1 deletion examples/learning_graph_structured_data_h2.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def main():
"iteration %d loss_value %f", epoch, sum(total_loss, 0) / len(total_loss)
)

circle = visualization.PoincareDisk(point_type="ball")
circle = visualization.PoincareDisk(coords_type="ball")
plt.figure()
ax = plt.subplot(111)
circle.add_points(gs.array([[0, 0]]))
Expand Down
2 changes: 1 addition & 1 deletion examples/plot_geodesics_h2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import geomstats.visualization as visualization
from geomstats.geometry.hyperboloid import Hyperboloid

H2 = Hyperboloid(dim=2, coords_type="extrinsic")
H2 = Hyperboloid(dim=2, default_coords_type="extrinsic")
METRIC = H2.metric


Expand Down
2 changes: 1 addition & 1 deletion examples/plot_geodesics_poincare_polydisk.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

N_DISKS = 4

POINCARE_POLYDISK = PoincarePolydisk(n_disks=N_DISKS, coords_type="extrinsic")
POINCARE_POLYDISK = PoincarePolydisk(n_disks=N_DISKS, default_coords_type="extrinsic")
METRIC = POINCARE_POLYDISK.metric


Expand Down
6 changes: 3 additions & 3 deletions examples/plot_kmeans_manifolds.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def kmean_poincare_ball():
space="H2_poincare_disk",
marker=".",
color="black",
point_type=manifold.point_type,
coords_type=manifold.default_coords_type,
)

for i in range(n_clusters):
Expand All @@ -54,7 +54,7 @@ def kmean_poincare_ball():
space="H2_poincare_disk",
marker=".",
color=colors[i],
point_type=manifold.point_type,
coords_type=manifold.default_coords_type,
)

ax = visualization.plot(
Expand All @@ -64,7 +64,7 @@ def kmean_poincare_ball():
marker="*",
color="green",
s=100,
point_type=manifold.point_type,
coords_type=manifold.default_coords_type,
)

ax.set_title("Kmeans on Poincaré Ball Manifold")
Expand Down
6 changes: 3 additions & 3 deletions examples/plot_kmedoids_manifolds.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def kmedoids_poincare_ball():
space="H2_poincare_disk",
marker=".",
color="black",
point_type=manifold.point_type,
coords_type=manifold.default_coords_type,
)

for i in range(n_clusters):
Expand All @@ -54,7 +54,7 @@ def kmedoids_poincare_ball():
space="H2_poincare_disk",
marker=".",
color=colors[i],
point_type=manifold.point_type,
coords_type=manifold.default_coords_type,
)

ax = visualization.plot(
Expand All @@ -64,7 +64,7 @@ def kmedoids_poincare_ball():
marker="*",
color="green",
s=100,
point_type=manifold.point_type,
coords_type=manifold.default_coords_type,
)

ax.set_title("Kmedoids on Poincaré Ball Manifold")
Expand Down
2 changes: 1 addition & 1 deletion examples/plot_square_h2_poincare_half_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def main():
edge_points,
ax=ax,
space="H2_poincare_half_plane",
point_type="extrinsic",
coords_type="extrinsic",
marker=".",
color="black",
)
Expand Down
38 changes: 17 additions & 21 deletions geomstats/geometry/_hyperbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,14 @@ class _Hyperbolic:

Parameters
----------
dim : int
Dimension of the hyperbolic space.
point_type : str, {'extrinsic', 'intrinsic', etc}
Default coordinates to represent points in hyperbolic space.
Optional, default: 'extrinsic'.
scale : int
Scale of the hyperbolic space, defined as the set of points
in Minkowski space whose squared norm is equal to -scale.
Optional, default: 1.
"""

def __init__(self, dim, scale=1, **kwargs):
super(_Hyperbolic, self).__init__(dim=dim, **kwargs)
self.dim = dim
def __init__(self, scale=1, **kwargs):
super(_Hyperbolic, self).__init__(**kwargs)
self.scale = scale

@staticmethod
Expand Down Expand Up @@ -354,7 +348,8 @@ def change_coordinates_system(
-------
point_to : array-like, shape=[..., dim]
or shape=[n_sample, dim + 1]
Point in hyperbolic space in coordinates given by to_point_type.
Point in hyperbolic space in coordinates given by
to_coordinates_system.
"""
coords_transform = {
"ball-extrinsic": _Hyperbolic._ball_to_extrinsic_coordinates,
Expand Down Expand Up @@ -389,10 +384,10 @@ def to_coordinates(self, point, to_coords_type="ball"):
Returns
-------
point_to : array-like, shape=[..., {dim, dim + 1}]
Point in hyperbolic space in coordinates given by to_point_type.
Point in hyperbolic space in coordinates given by to_coords_type.
"""
return _Hyperbolic.change_coordinates_system(
point, self.coords_type, to_coords_type
point, self.default_coords_type, to_coords_type
)

def from_coordinates(self, point, from_coords_type):
Expand All @@ -404,7 +399,7 @@ def from_coordinates(self, point, from_coords_type):
Parameters
----------
point : array-like, shape=[..., {dim, dim + 1}]
Point in hyperbolic space in coordinates from_point_type.
Point in hyperbolic space in coordinates from_coords_type.
from_coords_type : str, {'ball', 'extrinsic', 'intrinsic', ...}
Coordinates type.

Expand All @@ -414,7 +409,7 @@ def from_coordinates(self, point, from_coords_type):
Point in hyperbolic space.
"""
return _Hyperbolic.change_coordinates_system(
point, from_coords_type, self.coords_type
point, from_coords_type, self.default_coords_type
)

def random_point(self, n_samples=1, bound=1.0):
Expand Down Expand Up @@ -444,7 +439,7 @@ def random_point(self, n_samples=1, bound=1.0):
samples = bound * 2.0 * (gs.random.rand(*size) - 0.5)

samples = _Hyperbolic.change_coordinates_system(
samples, "intrinsic", self.coords_type
samples, "intrinsic", self.default_coords_type
)

if n_samples == 1:
Expand All @@ -459,19 +454,20 @@ class HyperbolicMetric(RiemannianMetric):
----------
dim : int
Dimension of the hyperbolic space.
point_type : str, {'extrinsic', 'intrinsic', etc}, optional
default_coords_type : str, {'extrinsic', 'intrinsic', etc}, optional
Default coordinates to represent points in hyperbolic space.
scale : int, optional
Scale of the hyperbolic space, defined as the set of points
in Minkowski space whose squared norm is equal to -scale.
"""

default_point_type = "vector"
default_coords_type = "extrinsic"

def __init__(self, dim, scale=1):
super(HyperbolicMetric, self).__init__(dim=dim, signature=(dim, 0))
self.point_type = HyperbolicMetric.default_point_type
def __init__(self, dim, scale=1, default_coords_type="extrinsic"):
super(HyperbolicMetric, self).__init__(
dim=dim,
signature=(dim, 0),
shape=(dim + 1,),
default_coords_type=default_coords_type,
)

self.scale = scale

Expand Down
11 changes: 1 addition & 10 deletions geomstats/geometry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import geomstats.backend as gs
from geomstats.geometry.manifold import Manifold

POINT_TYPES = {1: "vector", 2: "matrix"}


class VectorSpace(Manifold, abc.ABC):
"""Abstract class for vector spaces.
Expand All @@ -19,9 +17,6 @@ class VectorSpace(Manifold, abc.ABC):
shape : tuple
Shape of the elements of the vector space. The dimension is the
product of these values by default.
default_point_type : str, {'vector', 'matrix'}
Point type.
Optional, default: 'vector'.
"""

def __init__(self, shape, **kwargs):
Expand Down Expand Up @@ -187,10 +182,7 @@ def __init__(
):
kwargs.setdefault("shape", embedding_space.shape)
super(LevelSet, self).__init__(
dim=dim,
default_point_type=embedding_space.default_point_type,
default_coords_type=default_coords_type,
**kwargs
dim=dim, default_coords_type=default_coords_type, **kwargs
)
self.embedding_space = embedding_space
self.embedding_metric = embedding_space.metric
Expand Down Expand Up @@ -335,7 +327,6 @@ class OpenSet(Manifold, abc.ABC):
"""

def __init__(self, dim, ambient_space, **kwargs):
kwargs.setdefault("default_point_type", ambient_space.default_point_type)
kwargs.setdefault("shape", ambient_space.shape)
super().__init__(dim=dim, **kwargs)
self.ambient_space = ambient_space
Expand Down
41 changes: 14 additions & 27 deletions geomstats/geometry/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from geomstats.integrator import integrate

N_STEPS = 100
POINT_TYPES = {1: "vector", 2: "matrix", 3: "matrix"}


class Connection(ABC):
Expand All @@ -25,32 +24,31 @@ class Connection(ABC):
shape : tuple of int
Shape of one element of the manifold.
Optional, default : (dim, ).
default_point_type : str, {\'vector\', \'matrix\'}
Point type.
Optional, default: \'vector\'.
default_coords_type : str, {\'intrinsic\', \'extrinsic\', etc}
Coordinate type.
Optional, default: \'intrinsic\'.
"""

def __init__(
self, dim, shape=None, default_point_type=None, default_coords_type="intrinsic"
):
def __init__(self, dim, shape=None, default_coords_type="intrinsic"):
geomstats.errors.check_integer(dim, "dim")

if shape is None:
shape = (dim,)
if default_point_type is None:
default_point_type = POINT_TYPES[len(shape)]

geomstats.errors.check_integer(dim, "dim")
geomstats.errors.check_parameter_accepted_values(
default_point_type, "default_point_type", ["vector", "matrix"]
)

self.dim = dim
self.shape = shape
self.default_point_type = default_point_type
self.default_coords_type = default_coords_type

@property
def default_point_type(self):
"""Point type.

`vector` or `matrix`.
"""
if len(self.shape) == 1:
return "vector"
return "matrix"

def christoffels(self, base_point):
"""Christoffel symbols associated with the connection.

Expand Down Expand Up @@ -89,15 +87,7 @@ def geodesic_equation(self, state, _time):
equation = -gs.einsum("...kj,...j->...k", equation, velocity)
return gs.stack([velocity, equation])

def exp(
self,
tangent_vec,
base_point,
n_steps=N_STEPS,
step="euler",
point_type=None,
**kwargs
):
def exp(self, tangent_vec, base_point, n_steps=N_STEPS, step="euler", **kwargs):
"""Exponential map associated to the affine connection.

Exponential map at base_point of tangent_vec computed by integration
Expand All @@ -116,9 +106,6 @@ def exp(
step : str, {'euler', 'rk4'}
The numerical scheme to use for integration.
Optional, default: 'euler'.
point_type : str, {'vector', 'matrix'}
Type of representation used for points.
Optional, default: None.

Returns
-------
Expand Down
19 changes: 14 additions & 5 deletions geomstats/geometry/discrete_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def __init__(
super(DiscreteCurves, self).__init__(
dim=dim,
shape=(k_sampling_points,) + ambient_manifold.shape,
default_point_type="matrix",
**kwargs,
)
self.ambient_manifold = ambient_manifold
Expand Down Expand Up @@ -281,7 +280,7 @@ def __init__(self, ambient_manifold, k_sampling_points=10):
dim = ambient_manifold.dim * (k_sampling_points - 1)
super(ClosedDiscreteCurves, self).__init__(
dim=dim,
shape=(),
shape=(k_sampling_points,) + ambient_manifold.shape,
submersion=None,
tangent_submersion=None,
value=None,
Expand Down Expand Up @@ -583,7 +582,11 @@ class L2CurvesMetric(RiemannianMetric):
"""

def __init__(self, ambient_manifold, ambient_metric=None):
super(L2CurvesMetric, self).__init__(dim=math.inf, signature=(math.inf, 0, 0))
super(L2CurvesMetric, self).__init__(
dim=math.inf,
signature=(math.inf, 0, 0),
shape=(None,) + ambient_manifold.shape,
)
if ambient_metric is None:
if hasattr(ambient_manifold, "metric"):
self.ambient_metric = ambient_manifold.metric
Expand Down Expand Up @@ -831,7 +834,9 @@ def __init__(
self, a, b, ambient_manifold=R2, ambient_metric=None, translation_invariant=True
):
super(ElasticMetric, self).__init__(
dim=math.inf, signature=(math.inf, 0, 0), default_point_type="matrix"
dim=math.inf,
signature=(math.inf, 0, 0),
shape=(None,) + ambient_manifold.shape,
)
self.ambient_metric = ambient_metric
if ambient_metric is None:
Expand Down Expand Up @@ -2130,7 +2135,11 @@ class SRVQuotientMetric(QuotientMetric):
def __init__(self, ambient_manifold, k_sampling_points=10):
dim = ambient_manifold.dim * k_sampling_points
bundle = SRVShapeBundle(ambient_manifold, dim)
super(SRVQuotientMetric, self).__init__(fiber_bundle=bundle, dim=dim)
super(SRVQuotientMetric, self).__init__(
fiber_bundle=bundle,
dim=dim,
shape=(k_sampling_points,) + ambient_manifold.shape,
)

def geodesic(self, initial_point, end_point, threshold=1e-3):
"""Geodesic for the quotient SRV Metric.
Expand Down
Loading
0