From 6b216affb48248ae4f5c30033fd8467d215af2f8 Mon Sep 17 00:00:00 2001 From: "L. F. Pereira" Date: Tue, 2 Aug 2022 10:10:45 +0200 Subject: [PATCH 1/5] Add first draft of generate_vec_tests --- tests/data_generation.py | 74 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/tests/data_generation.py b/tests/data_generation.py index 1a664cccb..b3183ca83 100644 --- a/tests/data_generation.py +++ b/tests/data_generation.py @@ -1,3 +1,4 @@ +import copy import itertools import pytest @@ -12,6 +13,18 @@ def better_squeeze(array): return array +def _expand_point(point): + return gs.expand_dims(point, 0) + + +def _repeat_point(point, n_reps=2): + return gs.repeat(_expand_point(point), n_reps, axis=0) + + +def _expand_and_repeat_point(point, n_reps=2): + return _expand_point(point), _repeat_point(point, n_reps=n_reps) + + class TestData: """Class for TestData objects.""" @@ -47,6 +60,67 @@ def generate_tests(self, smoke_test_data, random_test_data=[]): return tests + def _generate_datum_vec_tests( + self, datum, arg_names, expected_name, check_expand=True, n_reps=2 + ): + + # TODO: improving handling of expected + if expected_name is not None: + has_expected = True + expected = datum.get(expected_name) + expected_rep = _repeat_point(expected, n_reps=n_reps) + else: + has_expected = False + + args_combs = [] + for arg_name in arg_names: + arg = datum.get(arg_name) + arg_combs = [arg] + if check_expand: + arg_combs.extend(_expand_and_repeat_point(arg, n_reps=n_reps)) + else: + arg_combs.append(_repeat_point(arg, n_reps=n_reps)) + + args_combs.append(arg_combs) + + n_indices = 2 + int(check_expand) + comb_indices = list(itertools.product(*[range(n_indices)] * len(arg_names))) + + new_data = [] + for indices in comb_indices: + new_datum = copy.copy(datum) + + if has_expected: + new_datum[expected_name] = ( + expected_rep if (1 + int(check_expand)) in indices else expected + ) + + for arg_i, (index, arg_name) in enumerate(zip(indices, arg_names)): + new_datum[arg_name] = args_combs[arg_i][index] + + new_data.append(new_datum) + + return new_data + + def generate_vec_tests( + self, data, arg_names, expected_name=None, check_expand=True, n_reps=2 + ): + # TODO: vectorization type (e.g. symmetric) + new_data = [] + for datum in data: + new_data.extend( + self._generate_datum_vec_tests( + datum, + arg_names, + expected_name=expected_name, + check_expand=check_expand, + n_reps=n_reps, + ) + ) + + # TODO: mark as vec? + return self.generate_tests(new_data) + class _ManifoldTestData(TestData): """Class for ManifoldTestData: data to test manifold properties.""" From 12a0ed0bdde87c62798dbff5fb6b4161fb08cbb8 Mon Sep 17 00:00:00 2001 From: "L. F. Pereira" Date: Tue, 2 Aug 2022 11:25:18 +0200 Subject: [PATCH 2/5] Add possibility of choosing vectorization type --- tests/data_generation.py | 46 ++++++++++++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/tests/data_generation.py b/tests/data_generation.py index b3183ca83..5ebfd8d37 100644 --- a/tests/data_generation.py +++ b/tests/data_generation.py @@ -4,6 +4,7 @@ import pytest import geomstats.backend as gs +from geomstats.errors import check_parameter_accepted_values def better_squeeze(array): @@ -60,8 +61,22 @@ def generate_tests(self, smoke_test_data, random_test_data=[]): return tests + def _filter_combs(self, combs, vec_type, threshold): + MAP_VEC_TYPE = { + "asym-left": 1, + "asym-right": 0, + } + index = MAP_VEC_TYPE[vec_type] + other_index = (index + 1) % 2 + + for comb in combs.copy(): + if comb[index] >= threshold and comb[index] != comb[other_index]: + combs.remove(comb) + + return combs + def _generate_datum_vec_tests( - self, datum, arg_names, expected_name, check_expand=True, n_reps=2 + self, datum, comb_indices, arg_names, expected_name, check_expand=True, n_reps=2 ): # TODO: improving handling of expected @@ -83,9 +98,6 @@ def _generate_datum_vec_tests( args_combs.append(arg_combs) - n_indices = 2 + int(check_expand) - comb_indices = list(itertools.product(*[range(n_indices)] * len(arg_names))) - new_data = [] for indices in comb_indices: new_datum = copy.copy(datum) @@ -103,14 +115,36 @@ def _generate_datum_vec_tests( return new_data def generate_vec_tests( - self, data, arg_names, expected_name=None, check_expand=True, n_reps=2 + self, + data, + arg_names, + expected_name=None, + check_expand=True, + n_reps=2, + vec_type="sym", ): - # TODO: vectorization type (e.g. symmetric) + + check_parameter_accepted_values( + vec_type, "vec_type", ["sym", "asym-left", "asym-right"] + ) + + n_args = len(arg_names) + if n_args != 2 and vec_type != "sym": + raise NotImplementedError(f"`{vec_type} only implemented for 2 arguments.") + + n_indices = 2 + int(check_expand) + comb_indices = list(itertools.product(*[range(n_indices)] * len(arg_names))) + if n_args == 2 and vec_type != "sym": + comb_indices = self._filter_combs( + comb_indices, vec_type, threshold=1 + int(check_expand) + ) + new_data = [] for datum in data: new_data.extend( self._generate_datum_vec_tests( datum, + comb_indices, arg_names, expected_name=expected_name, check_expand=check_expand, From a3386bb87710eab8b0093a952cccd69baea3878a Mon Sep 17 00:00:00 2001 From: "L. F. Pereira" Date: Tue, 2 Aug 2022 15:55:39 +0200 Subject: [PATCH 3/5] Update _repeat_point (improves robustness of both input and output handling) --- tests/data_generation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/data_generation.py b/tests/data_generation.py index 5ebfd8d37..eccc042b3 100644 --- a/tests/data_generation.py +++ b/tests/data_generation.py @@ -19,6 +19,9 @@ def _expand_point(point): def _repeat_point(point, n_reps=2): + if not gs.is_array(point): + return [point] * n_reps + return gs.repeat(_expand_point(point), n_reps, axis=0) @@ -79,7 +82,6 @@ def _generate_datum_vec_tests( self, datum, comb_indices, arg_names, expected_name, check_expand=True, n_reps=2 ): - # TODO: improving handling of expected if expected_name is not None: has_expected = True expected = datum.get(expected_name) From 65637f2a6cbde178defb8d53ed4fbfee34ae73f9 Mon Sep 17 00:00:00 2001 From: "L. F. Pereira" Date: Fri, 5 Aug 2022 09:44:31 +0200 Subject: [PATCH 4/5] Update naming and add docstrings to generate_vectorization_tests --- tests/data_generation.py | 48 +++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/tests/data_generation.py b/tests/data_generation.py index eccc042b3..1987500e3 100644 --- a/tests/data_generation.py +++ b/tests/data_generation.py @@ -66,8 +66,8 @@ def generate_tests(self, smoke_test_data, random_test_data=[]): def _filter_combs(self, combs, vec_type, threshold): MAP_VEC_TYPE = { - "asym-left": 1, - "asym-right": 0, + "repeat-first": 1, + "repeat-second": 0, } index = MAP_VEC_TYPE[vec_type] other_index = (index + 1) % 2 @@ -78,7 +78,7 @@ def _filter_combs(self, combs, vec_type, threshold): return combs - def _generate_datum_vec_tests( + def _generate_datum_vectorization_tests( self, datum, comb_indices, arg_names, expected_name, check_expand=True, n_reps=2 ): @@ -116,35 +116,61 @@ def _generate_datum_vec_tests( return new_data - def generate_vec_tests( + def generate_vectorization_tests( self, data, arg_names, expected_name=None, check_expand=True, n_reps=2, - vec_type="sym", + vectorization_type="sym", ): + """Create new data with vectorized version of inputs. + + Parameters + ---------- + data : list of dict + Data. Each to vectorize. + arg_names: list + Name of inputs to vectorize. + expected_name: str + Output name in case it needs to be repeated. + check_expand: bool + If `True`, expanded version of each input will be tested. + n_reps: int + Number of times the input points should be repeated. + vectorization_type: str + Possible values are `sym`, `repeat-first`, `repeat-second`. + `repeat-first` and `repeat-second` only valid for two argument case. + `repeat-first` and `repeat-second` test asymmetric cases, repeating + only first or second input, respectively. + + + """ check_parameter_accepted_values( - vec_type, "vec_type", ["sym", "asym-left", "asym-right"] + vectorization_type, + "vectorization_type", + ["sym", "repeat-first", "repeat-second"], ) n_args = len(arg_names) - if n_args != 2 and vec_type != "sym": - raise NotImplementedError(f"`{vec_type} only implemented for 2 arguments.") + if n_args != 2 and vectorization_type != "sym": + raise NotImplementedError( + f"`{vectorization_type} only implemented for 2 arguments." + ) n_indices = 2 + int(check_expand) comb_indices = list(itertools.product(*[range(n_indices)] * len(arg_names))) - if n_args == 2 and vec_type != "sym": + if n_args == 2 and vectorization_type != "sym": comb_indices = self._filter_combs( - comb_indices, vec_type, threshold=1 + int(check_expand) + comb_indices, vectorization_type, threshold=1 + int(check_expand) ) new_data = [] for datum in data: new_data.extend( - self._generate_datum_vec_tests( + self._generate_datum_vectorization_tests( datum, comb_indices, arg_names, From 2c86088496566d0a5c2875dc7ab3c0f0885d1f10 Mon Sep 17 00:00:00 2001 From: "L. F. Pereira" Date: Fri, 5 Aug 2022 09:47:28 +0200 Subject: [PATCH 5/5] Fix lint issues --- tests/data_generation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/data_generation.py b/tests/data_generation.py index 1987500e3..b32b6be29 100644 --- a/tests/data_generation.py +++ b/tests/data_generation.py @@ -144,10 +144,7 @@ def generate_vectorization_tests( `repeat-first` and `repeat-second` only valid for two argument case. `repeat-first` and `repeat-second` test asymmetric cases, repeating only first or second input, respectively. - - """ - check_parameter_accepted_values( vectorization_type, "vectorization_type", @@ -205,7 +202,8 @@ def random_point_belongs_test_data( return self.generate_tests([], random_data) def projection_belongs_test_data(self, belongs_atol=gs.atol): - """Generate data to check that a point projected on a manifold belongs to the manifold. + """Generate data to check that a point projected on a manifold belongs + to the manifold. Parameters ---------- @@ -300,7 +298,8 @@ def to_tangent_is_tangent_in_ambient_space_test_data(self): class _LevelSetTestData(_ManifoldTestData): def intrinsic_after_extrinsic_test_data(self): - """Generate data to check that changing coordinate system twice gives back the point. + """Generate data to check that changing coordinate system twice gives + back the point. Assumes that random_point generates points in extrinsic coordinates. """ @@ -316,7 +315,8 @@ def intrinsic_after_extrinsic_test_data(self): return self.generate_tests([], random_data) def extrinsic_after_intrinsic_test_data(self): - """Generate data to check that changing coordinate system twice gives back the point. + """Generate data to check that changing coordinate system twice gives + back the point. Assumes that the first elements in space_args is the dimension of the space. """