Geometric median by shubhamtalbar96 · Pull Request #1565 · geomstats/geomstats · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Geometric median #1565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Jul 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion geomstats/_backend/autograd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
cumsum,
diag_indices,
diagonal,
divide,
dot,
)
from autograd.numpy import dtype as _ndtype
Expand Down Expand Up @@ -430,6 +429,12 @@ def mat_from_diag_triu_tril(diag, tri_upp, tri_low):
return mat


def divide(a, b, ignore_div_zero=False):
if ignore_div_zero is False:
return _np.divide(a, b)
return _np.divide(a, b, out=_np.zeros_like(a), where=b != 0)


def ravel_tril_indices(n, k=0, m=None):
if m is None:
size = (n, n)
Expand Down
7 changes: 6 additions & 1 deletion geomstats/_backend/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
cumsum,
diag_indices,
diagonal,
divide,
dot,
)
from numpy import dtype as _ndtype # NOQA
Expand Down Expand Up @@ -428,6 +427,12 @@ def mat_from_diag_triu_tril(diag, tri_upp, tri_low):
return mat


def divide(a, b, ignore_div_zero=False):
if ignore_div_zero is False:
return _np.divide(a, b)
return _np.divide(a, b, out=_np.zeros_like(a), where=b != 0)


def ravel_tril_indices(n, k=0, m=None):
if m is None:
size = (n, n)
Expand Down
8 changes: 7 additions & 1 deletion geomstats/_backend/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
cos,
cosh,
cross,
divide,
empty_like,
erf,
exp,
Expand Down Expand Up @@ -719,6 +718,13 @@ def mat_from_diag_triu_tril(diag, tri_upp, tri_low):
return mat


def divide(a, b, ignore_div_zero=False):
if ignore_div_zero is False:
return _torch.divide(a, b)
quo = _torch.divide(a, b)
return _torch.nan_to_num(quo, nan=0.0, posinf=0.0, neginf=0.0)


def ravel_tril_indices(n, k=0, m=None):
if m is None:
size = (n, n)
Expand Down
7 changes: 6 additions & 1 deletion geomstats/_backend/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from tensorflow import (
cos,
cosh,
divide,
equal,
exp,
expand_dims,
Expand Down Expand Up @@ -924,6 +923,12 @@ def mat_from_diag_triu_tril(diag, tri_upp, tri_low):
return mat


def divide(a, b, ignore_div_zero=False):
    """Divide ``a`` by ``b`` element-wise.

    Parameters
    ----------
    a : tensor
        Numerator.
    b : tensor
        Denominator.
    ignore_div_zero : bool
        If False, plain division is used; otherwise the division is
        delegated to ``tf.math.divide_no_nan``.
        Optional, default: False.

    Returns
    -------
    quotient : tensor
        Element-wise quotient.
    """
    use_plain_division = ignore_div_zero is False
    operation = _tf.math.divide if use_plain_division else _tf.math.divide_no_nan
    return operation(a, b)


def _ravel_multi_index(multi_index, shape):
    """Combine per-axis indices into flat indices.

    Each row of ``multi_index`` is multiplied by the stride of its axis
    (the exclusive, reversed cumulative product of ``shape``) and the rows
    are summed along axis 0.

    Parameters
    ----------
    multi_index : tensor
        Indices to flatten; presumably one row per axis -- TODO confirm.
    shape : array-like
        Shape used to compute the strides.

    Returns
    -------
    flat_index : tensor
        Sum over axis 0 of the stride-weighted indices.
    """
    axis_strides = _tf.math.cumprod(shape, exclusive=True, reverse=True)
    weighted_index = multi_index * _tf.expand_dims(axis_strides, 1)
    return _tf.reduce_sum(weighted_index, axis=0)
Expand Down
116 changes: 116 additions & 0 deletions geomstats/learning/geometric_median.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Geometric Median Estimation."""

import logging

import geomstats.backend as gs


class GeometricMedian:
    r"""Geometric median on manifolds, computed with the Weiszfeld algorithm.

    Parameters
    ----------
    metric : RiemannianMetric
        Riemannian metric.
    max_iter : int
        Maximum number of iterations for the algorithm.
        Optional, default : 100
    lr : float
        Learning rate to be used for the algorithm.
        Optional, default : 1.0
    init : array-like
        Initialization to be used in the start.
        Optional, default : None, in which case the last sample
        passed to `fit` is used.
    print_every : int
        Log the updated median every print_every iterations.
        Optional, default : None

    Attributes
    ----------
    estimate_ : array-like
        Estimated geometric median; None until `fit` has been called.

    References
    ----------
    .. [FVJ2008] Fletcher, Venkatasubramanian, Joshi. "Robust statistics on
        Riemannian manifolds via the geometric median"
    """

    def __init__(self, metric, max_iter=100, lr=1.0, init=None, print_every=None):
        self.metric = metric
        self.max_iter = max_iter
        self.lr = lr
        self.init = init
        self.print_every = print_every
        self.estimate_ = None

    def _iterate_once(self, current_median, X, weights, lr):
        """Compute a single iteration of Weiszfeld Algorithm.

        Parameters
        ----------
        current_median : array-like, shape={representation shape}
            Current median.
        X : array-like, shape=[N, {representation shape}]
            Data for which the geometric median has to be found.
        weights : array-like, shape=[N]
            Weights for weighted sum.
        lr : float
            Learning rate for the current iteration.

        Returns
        -------
        updated_median : array-like, shape={representation shape}
            Updated median after a single iteration.
        """

        def _scalarmul(scalar, array):
            # Append one trailing axis per point dimension so the per-sample
            # scalars broadcast against points of any rank (vectors,
            # matrices, ...).
            trailing_axes = (None,) * (gs.ndim(array) - 1)
            return scalar[(slice(None),) + trailing_axes] * array

        dists = self.metric.dist(current_median, X)

        # Every sample coincides with the current median: nothing to update.
        if gs.allclose(dists, 0.0):
            return current_median

        logs = self.metric.log(X, current_median)
        # Samples at distance zero get a zero coefficient instead of inf.
        mul = gs.divide(weights, dists, ignore_div_zero=True)
        v_k = gs.sum(_scalarmul(mul, logs), axis=0) / gs.sum(mul)
        updated_median = self.metric.exp(lr * v_k, current_median)
        return updated_median

    def fit(self, X, weights=None):
        r"""Compute the Geometric Median.

        Parameters
        ----------
        X : array-like, shape=[n_samples, {representation shape}]
            Training input samples.
        weights : array-like, shape=[n_samples]
            Weights associated to the points.
            Optional, default: None, in which case
            it is equally weighted

        Returns
        -------
        self : object
            Returns self. The estimate is stored in `estimate_`.
        """
        n_points = X.shape[0]
        current_median = X[-1] if self.init is None else self.init
        if weights is None:
            weights = gs.ones(n_points) / n_points

        for iteration in range(1, self.max_iter + 1):
            new_median = self._iterate_once(current_median, X, weights, self.lr)
            shift = self.metric.dist(new_median, current_median)
            # Converged: the update no longer moves the estimate.
            if shift < gs.atol:
                break

            current_median = new_median
            # `iteration` is already 1-based, so no offset is needed to log
            # after every `print_every`-th iteration.
            if self.print_every and iteration % self.print_every == 0:
                logging.info(f"iteration: {iteration} curr_median: {current_median}")
        self.estimate_ = current_median

        return self
31 changes: 31 additions & 0 deletions tests/data/geometric_median_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import geomstats.backend as gs
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.geometry.spd_matrices import SPDMatrices, SPDMetricAffine
from geomstats.learning.geometric_median import GeometricMedian
from tests.data_generation import TestData


class GeometricMedianTestData(TestData):
    """Test cases for the GeometricMedian estimator."""

    def fit_test_data(self):
        """Build the case where both samples are the 2x2 identity matrix."""
        identity = [[1.0, 0.0], [0.0, 1.0]]
        case = dict(
            estimator=GeometricMedian(SPDMetricAffine(n=2)),
            X=gs.array([identity, identity]),
            expected=gs.array(identity),
        )
        return self.generate_tests([case])

    def fit_sanity_test_data(self):
        """Build (estimator, space) pairs for SPD matrices and the sphere."""
        spd_dim = 4
        spd_estimator = GeometricMedian(SPDMetricAffine(spd_dim))
        spd_space = SPDMatrices(spd_dim)

        sphere = Hypersphere(2)
        sphere_estimator = GeometricMedian(sphere.metric)

        cases = [
            dict(estimator=spd_estimator, space=spd_space),
            dict(estimator=sphere_estimator, space=sphere),
        ]
        return self.generate_tests(cases)
22 changes: 22 additions & 0 deletions tests/tests_geomstats/test_geometric_median.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Methods for testing the Geometric Median."""


from tests.conftest import Parametrizer, TestCase
from tests.data.geometric_median_data import GeometricMedianTestData


class TestGeometricMedian(TestCase, metaclass=Parametrizer):
    """Parametrized tests for the GeometricMedian estimator."""

    testing_data = GeometricMedianTestData()

    def test_fit(self, estimator, X, expected):
        """Test the fitted estimate is close to the expected point."""
        estimator.fit(X)
        result = estimator.estimate_
        self.assertAllClose(result, expected)

    def test_fit_sanity(self, estimator, space):
        """Test estimate belongs to space."""
        samples = space.random_point(5)
        estimator.fit(samples)

        is_on_space = space.belongs(estimator.estimate_)
        self.assertTrue(is_on_space)
0