SRV changes by shubhamtalbar96 · Pull Request #1525 · geomstats/geomstats · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

SRV changes #1525

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
May 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion geomstats/geometry/discrete_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,10 +423,11 @@ def __init__(self, ambient_manifold, metric=None, translation_invariant=True):
)
else:
self.ambient_metric = metric
self.default_point_type = "matrix"
self.l2_metric = L2CurvesMetric(ambient_manifold, metric)
self.translation_invariant = translation_invariant

def srv_transform(self, curve):
def srv_transform(self, curve, tol=gs.atol):
"""Square Root Velocity Transform (SRVT).

Compute the square root velocity representation of a curve. The
Expand All @@ -440,11 +441,23 @@ def srv_transform(self, curve):
curve : array-like, shape=[..., n_sampling_points, ambient_dim]
Discrete curve.

tol : float
Tolerance value to decide duplicity of two consecutive sample
points on a given Discrete Curve.

Returns
-------
srv : array-like, shape=[..., n_sampling_points - 1, ambient_dim]
Square-root velocity representation of a discrete curve.
"""
if gs.any(
self.ambient_metric.norm(curve[..., 1:, :] - curve[..., :-1, :]) < tol
):
raise AssertionError(
"The square root velocity framework "
"is only defined for discrete curves "
"with distinct consecutive sample points."
)
curve_ndim = gs.ndim(curve)
curve = gs.to_ndarray(curve, to_ndim=3)
n_curves, n_sampling_points, n_coords = curve.shape
Expand Down
9 changes: 8 additions & 1 deletion geomstats/learning/frechet_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import geomstats.backend as gs
import geomstats.errors as error
import geomstats.vectorization
from geomstats.geometry.discrete_curves import SRVMetric
from geomstats.geometry.hypersphere import Hypersphere

EPSILON = 1e-4
Expand Down Expand Up @@ -138,7 +139,13 @@ def _default_gradient_descent(
sq_dists_between_iterates.append(sq_dist)

var_is_0 = gs.isclose(var, 0.0)
sq_dist_is_small = gs.less_equal(sq_dist, epsilon * metric.dim)

metric_dim = metric.dim
if isinstance(metric, SRVMetric):
metric_dim = tangent_mean.shape[-2] * tangent_mean.shape[-1]

sq_dist_is_small = gs.less_equal(sq_dist, epsilon * metric_dim)

condition = ~gs.logical_or(var_is_0, sq_dist_is_small)
if not (condition or iteration == 0):
break
Expand Down
13 changes: 13 additions & 0 deletions geomstats/learning/riemannian_mean_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,16 @@ def predict(self, points):
indices = self.metric.closest_neighbor_index(points, self.centers)

return gs.take(self.centers, indices, axis=0)

def predict_labels(self, points):
    """Predict, for each point, the index of its closest cluster center.

    Parameters
    ----------
    points : array-like, shape=[..., n_features]
        Points to assign to the fitted clusters.

    Returns
    -------
    labels : array-like
        Indices into ``self.centers`` of the closest center for each point,
        exactly as returned by ``self.metric.closest_neighbor_index``.
        These are the same indices ``predict`` uses to look up the centers
        themselves.

    Raises
    ------
    RuntimeError
        If the estimator has not been fitted yet (``self.centers`` is None).
    """
    # RuntimeError (a subclass of Exception) keeps existing
    # ``except Exception`` handlers working while being more specific.
    if self.centers is None:
        raise RuntimeError("Not fitted")

    return self.metric.closest_neighbor_index(points, self.centers)
Loading
0