Add strategy to vectorize tests during data generation by luisfpereira · Pull Request #1613 · geomstats/geomstats · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add strategy to vectorize tests during data generation #1613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 5, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 139 additions & 3 deletions tests/data_generation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import copy
import itertools

import pytest

import geomstats.backend as gs
from geomstats.errors import check_parameter_accepted_values


def better_squeeze(array):
Expand All @@ -12,6 +14,21 @@ def better_squeeze(array):
return array


def _expand_point(point):
    """Return ``point`` with an extra leading axis.

    Thin wrapper around ``gs.expand_dims(point, 0)``.
    """
    expanded = gs.expand_dims(point, 0)
    return expanded


def _repeat_point(point, n_reps=2):
    """Build ``n_reps`` copies of ``point``.

    When ``gs.is_array(point)`` holds, the point gets a leading axis and is
    repeated ``n_reps`` times along it. Otherwise a plain Python list
    holding ``n_reps`` references to the same object is returned.
    """
    if gs.is_array(point):
        return gs.repeat(_expand_point(point), n_reps, axis=0)

    # Non-array inputs (e.g. scalars, lists) are simply listed n_reps times.
    return [point] * n_reps


def _expand_and_repeat_point(point, n_reps=2):
    """Return the ``(expanded, repeated)`` variants of ``point``.

    The expanded variant comes from ``_expand_point``. The repeated variant
    comes from ``_repeat_point`` with the given ``n_reps``.
    """
    expanded = _expand_point(point)
    repeated = _repeat_point(point, n_reps=n_reps)
    return expanded, repeated


class TestData:
"""Class for TestData objects."""

Expand Down Expand Up @@ -47,6 +64,122 @@ def generate_tests(self, smoke_test_data, random_test_data=[]):

return tests

def _filter_combs(self, combs, vec_type, threshold):
    """Filter index combinations for asymmetric two-argument vectorization.

    Each combination is a pair of variant indices, one per argument. An
    index ``>= threshold`` selects a repeated variant. A combination is
    dropped when the argument constrained by ``vec_type`` is at or above
    ``threshold`` while its index differs from the other argument's index.

    Parameters
    ----------
    combs : list of tuple
        Pairs of variant indices, one per argument.
    vec_type : str
        Either `repeat-first` (constrains the second argument) or
        `repeat-second` (constrains the first argument).
    threshold : int
        Index from which an argument variant counts as repeated.

    Returns
    -------
    filtered_combs : list of tuple
        New list with the surviving combinations, in their original order.
        ``combs`` is not modified.

    Raises
    ------
    KeyError
        If ``vec_type`` is not `repeat-first` or `repeat-second`.
    """
    # `repeat-first` forbids repeating the second argument on its own and
    # vice versa, so map each type to the argument it constrains.
    constrained_index = {"repeat-first": 1, "repeat-second": 0}[vec_type]
    other_index = 1 - constrained_index

    # Build a new list instead of calling `remove` on the caller's list
    # (which mutated the input and was quadratic).
    return [
        comb
        for comb in combs
        if comb[constrained_index] < threshold
        or comb[constrained_index] == comb[other_index]
    ]

def _generate_datum_vectorization_tests(
    self, datum, comb_indices, arg_names, expected_name, check_expand=True, n_reps=2
):
    """Expand a single datum into several data with vectorized inputs.

    Parameters
    ----------
    datum : dict
        Test datum to expand.
    comb_indices : list of tuple
        Each tuple gives, for each name in ``arg_names``, the index of the
        variant to use. Variants are, in order: the original value, the
        expanded value (only if ``check_expand``), and the repeated value.
    arg_names : list of str
        Keys of ``datum`` whose values are vectorized.
    expected_name : str or None
        Key of the expected output. When not None, every new datum where at
        least one argument uses the repeated variant receives the repeated
        expected output instead of the original one.
    check_expand : bool
        Whether the expanded variant is included for each argument.
    n_reps : int
        Number of repetitions used to build repeated variants.

    Returns
    -------
    new_data : list of dict
        One shallow copy of ``datum`` per entry of ``comb_indices``.
    """
    # The repeated variant is always the last one of each variant list.
    rep_index = 1 + int(check_expand)

    if expected_name is not None:
        expected = datum.get(expected_name)
        expected_rep = _repeat_point(expected, n_reps=n_reps)

    variants = []
    for arg_name in arg_names:
        value = datum.get(arg_name)
        if check_expand:
            variants.append(
                [value, *_expand_and_repeat_point(value, n_reps=n_reps)]
            )
        else:
            variants.append([value, _repeat_point(value, n_reps=n_reps)])

    new_data = []
    for indices in comb_indices:
        new_datum = copy.copy(datum)

        if expected_name is not None:
            new_datum[expected_name] = (
                expected_rep if rep_index in indices else expected
            )

        for arg_variants, index, arg_name in zip(variants, indices, arg_names):
            new_datum[arg_name] = arg_variants[index]

        new_data.append(new_datum)

    return new_data

def generate_vectorization_tests(
    self,
    data,
    arg_names,
    expected_name=None,
    check_expand=True,
    n_reps=2,
    vectorization_type="sym",
):
    """Create new data with vectorized versions of the inputs.

    For each datum, every input named in ``arg_names`` is replaced by one of
    its variants: the original value, (optionally) the value with an extra
    leading axis, and the value repeated ``n_reps`` times. One new datum is
    generated per retained combination of variants.

    Parameters
    ----------
    data : list of dict
        Data to vectorize, one dict per test case.
    arg_names : list of str
        Names of the inputs to vectorize.
    expected_name : str, optional
        Name of the expected output. If given, the output is repeated in
        every new datum where at least one input is repeated.
    check_expand : bool
        If True, the expanded version of each input is also tested.
    n_reps : int
        Number of times the input points are repeated.
    vectorization_type : str
        Possible values are `sym`, `repeat-first`, `repeat-second`.
        `sym` keeps every combination of variants.
        `repeat-first` and `repeat-second` are only valid for two arguments.
        They test asymmetric cases: only the first (respectively second)
        input may be repeated on its own; both inputs may still be repeated
        together.

    Returns
    -------
    tests : list
        Result of ``self.generate_tests`` on the generated data.

    Raises
    ------
    NotImplementedError
        If `repeat-first` or `repeat-second` is requested with a number of
        arguments different from two.
    """
    check_parameter_accepted_values(
        vectorization_type,
        "vectorization_type",
        ["sym", "repeat-first", "repeat-second"],
    )

    n_args = len(arg_names)
    if n_args != 2 and vectorization_type != "sym":
        raise NotImplementedError(
            f"`{vectorization_type}` only implemented for 2 arguments."
        )

    # Variants per argument: original, [expanded,] repeated. The repeated
    # variant is always the last one, i.e. index `n_indices - 1`.
    n_indices = 2 + int(check_expand)
    comb_indices = list(itertools.product(range(n_indices), repeat=n_args))
    if vectorization_type != "sym":
        # The guard above ensures exactly two arguments here.
        comb_indices = self._filter_combs(
            comb_indices, vectorization_type, threshold=n_indices - 1
        )

    new_data = []
    for datum in data:
        new_data.extend(
            self._generate_datum_vectorization_tests(
                datum,
                comb_indices,
                arg_names,
                expected_name=expected_name,
                check_expand=check_expand,
                n_reps=n_reps,
            )
        )

    # TODO: mark as vec?
    return self.generate_tests(new_data)


class _ManifoldTestData(TestData):
"""Class for ManifoldTestData: data to test manifold properties."""
Expand All @@ -69,7 +202,8 @@ def random_point_belongs_test_data(
return self.generate_tests([], random_data)

def projection_belongs_test_data(self, belongs_atol=gs.atol):
"""Generate data to check that a point projected on a manifold belongs to the manifold.
"""Generate data to check that a point projected on a manifold belongs
to the manifold.

Parameters
----------
Expand Down Expand Up @@ -164,7 +298,8 @@ def to_tangent_is_tangent_in_ambient_space_test_data(self):

class _LevelSetTestData(_ManifoldTestData):
def intrinsic_after_extrinsic_test_data(self):
"""Generate data to check that changing coordinate system twice gives back the point.
"""Generate data to check that changing coordinate system twice gives
back the point.

Assumes that random_point generates points in extrinsic coordinates.
"""
Expand All @@ -180,7 +315,8 @@ def intrinsic_after_extrinsic_test_data(self):
return self.generate_tests([], random_data)

def extrinsic_after_intrinsic_test_data(self):
"""Generate data to check that changing coordinate system twice gives back the point.
"""Generate data to check that changing coordinate system twice gives
back the point.

Assumes that the first elements in space_args is the dimension of the space.
"""
Expand Down
0