clean test_spd_matrices with parametrization #1263
Merged
80 commits
Commits by SaitejaUtpala:
451ced6 cleaning spd
256dfed fix bugs
ea31151 blacken
95c3e25 add numerical test cases for SPDAffine log,exp
fda53a1 fix bug in test_diff_cf
e250ca0 make tests pass
f445b41 make tests pass
49ec5bd make tests pass
6cd3720 add numerical tests for BW metric
63ae457 add more numerical tests
d5c86ca add more numerical tests
cdf8acd add more numerical tests
094dc50 BW squared dist
02a3d31 composition
89bc9c9 composition
6e222c7 checkpoint
9af13e5 add pytest.ini
c9f067b blacken
2bc9bc4 commit matrices
6638513 add
7a81f2d fix it
83adfef complete log_exp_composition
0e442d9 configure
94c0c47 configure
3c86c83 remove matrices stuff
73a7c19 reduce n
89f59e6 add documentation
4328d8c stash
b4e5e5c blacken
5704e80 add
4b08d80 add more tests
833920a add LT
92d169a add PLT
218668e add MatsV1
ff6e9fa add frm
c03c443 add Euclidean
d4c1245 Merge branch 'master' of https://github.com/geomstats/geomstats into …
038813e add tests in mats
6b2f9f3 add more tests in mats
caaa197 Test Matrices Metric
dadc1e0 more tests
970a713 add flatten
30ec418 add randomized tets
08cde33 remove _test.matrices.py
bab804f add more tests
b7abb6c make matrices work
68b70f9 remove euclidean
94c7903 fix frm and ltm
f3cb690 remove errors in PLT
a16f4ea make PLT tests pass
69dd367 test
e60bf75 add more tests
bbb09f5 add
5127db4 test_spd
94e92c8 make pytorch tests pas
e931bd2 reduce precision
b4d0f9a further reduce precision
acbe8fd reduce values
0e830f4 reduce precision
92e469c test_spd
f9c65de parallel testing
3f128dc revert
Jan 20, 2022
bd15739 testing
31bb95a clean
9e11817 blacken
daee562 add more tests
4876467 stash
983ec6d add more tests
8c41f0a add gs array
e1b5b97 reduce size
6b31bfa add more tests
8210442 add more tests
af2703b correct error
4a4892f addr comments
228ec91 remove code comments
5f7cf16 add data
52650e0 renaming
29d2792 add missing batch data
0b3a8ce reduce size of tets
11cba63 reduce size
@@ -0,0 +1,4 @@
+[pytest]
+markers =
+    smoke: simple and basic numerical tests.
+    random: tests that use randomized data.
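The two markers registered above can tag individual parametrized cases. A minimal sketch, assuming a hypothetical module-level `generate_tests` helper (the PR's own `TestData.generate_tests` may work differently) that wraps each case dict in `pytest.param` and attaches one of the two markers:

```python
import pytest


def generate_tests(smoke_test_data, random_test_data=()):
    """Hypothetical helper: wrap each case dict in pytest.param with a marker.

    Values are passed positionally, so each dict's key order must match the
    order of the test function's parameters.
    """
    params = [
        pytest.param(*case.values(), marks=pytest.mark.smoke)
        for case in smoke_test_data
    ]
    params += [
        pytest.param(*case.values(), marks=pytest.mark.random)
        for case in random_test_data
    ]
    return params


# Illustrative cases: the first is tagged "smoke", the second "random".
CASES = generate_tests(
    smoke_test_data=[dict(x=2, expected=4)],
    random_test_data=[dict(x=-3, expected=9)],
)


@pytest.mark.parametrize("x, expected", CASES)
def test_square(x, expected):
    assert x * x == expected
```

With the markers registered, `pytest -m smoke` runs only the smoke-tagged cases and `pytest -m "not random"` deselects the randomized ones. Registering them also keeps pytest from emitting `PytestUnknownMarkWarning`, which becomes an error under `--strict-markers`.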
@@ -1,51 +1,63 @@
 """Unit tests for full rank matrices."""

-import warnings
-
 import geomstats.backend as gs
-import geomstats.tests
-import tests.helper as helper
 from geomstats.geometry.full_rank_matrices import FullRankMatrices
+from tests.conftest import Parametrizer, TestCase, TestData


+class TestFullRankMatrices(TestCase, metaclass=Parametrizer):
+
+    cls = FullRankMatrices
+
+    class TestDataFullRankMatrices(TestData):
+        def belongs_data(self):
+            smoke_data = [
+                dict(
+                    m=3,
+                    n=2,
+                    mat=[
+                        [-1.6473486, -1.18240309],
+                        [0.1944016, 0.18169231],
+                        [-1.13933855, -0.64971248],
+                    ],
+                    expected=True,
+                ),
+                dict(
+                    m=3, n=2, mat=[[1.0, -1.0], [1.0, -1.0], [0.0, 0.0]], expected=False
+                ),
+            ]
+            return self.generate_tests(smoke_data)
+
+        def random_and_belongs_data(self):
+            smoke_data = [
+                dict(m=1, n=1, n_points=1),
+                dict(m=1, n=1, n_points=1000),
+                dict(m=2, n=2, n_points=1),
+                dict(m=2, n=2, n_points=100),
+                dict(m=10, n=5, n_points=100),
+            ]
+            return self.generate_tests(smoke_data)
+
+        def projection_and_belongs_data(self):
+            shapes = [(1, 1), (1, 1), (1, 10), (2, 2), (10, 5), (15, 15)]
+            sizes = [1, 10, 1, 1, 100, 10]
+            random_data = [
+                dict(m=m, n=n, mats=gs.random.normal(size=(size, m, n)))
+                for (m, n), size in zip(shapes, sizes)
+            ]
+
+            return self.generate_tests([], random_data)
+
-class TestFullRankMatrices(geomstats.tests.TestCase):
-    """Test of Full Rank Matrices methods."""
-
-    def setup_method(self):
-        """Set up the test."""
-        warnings.simplefilter("ignore", category=ImportWarning)
-
-        gs.random.seed(1234)
+    testing_data = TestDataFullRankMatrices()

-        self.m = 3
-        self.n = 2
-        self.space = FullRankMatrices(self.m, self.n)
+    def test_belongs(self, m, n, mat, expected):
+        self.assertAllClose(self.cls(m, n).belongs(gs.array(mat)), gs.array(expected))

-    def test_belongs(self):
-        """Test of belongs method."""
-        fr = self.space
-        mat_fr = gs.array(
-            [
-                [-1.6473486, -1.18240309],
-                [0.1944016, 0.18169231],
-                [-1.13933855, -0.64971248],
-            ]
+    def test_random_and_belongs(self, m, n, n_points):
+        cls = self.cls(m, n)
+        self.assertAllClose(
+            gs.all(cls.belongs(cls.random_point(n_points))), gs.array(True)
         )
-        mat_not_fr = gs.array([[1.0, -1.0], [1.0, -1.0], [0.0, 0.0]])
-        result = fr.belongs(mat_fr)
-        self.assertTrue(result)
-        result = fr.belongs(mat_not_fr)
-        self.assertFalse(result)

-    def test_projection_and_belongs(self):
-        """Test of projection method."""
-        shape = (2, self.m, self.n)
-        result = helper.test_projection_and_belongs(self.space, shape, atol=gs.atol)
-        for res in result:
-            self.assertTrue(res)
-
-    def test_random_and_belongs(self):
-        """Test of random point sampling method."""
-        mat = self.space.random_point(5)
-        result = self.space.belongs(mat)
-        self.assertTrue(gs.all(result))
+    def test_projection_and_belongs(self, m, n, mat):
+        self.assertAllClose(gs.all(self.cls(m, n).belongs(mat)), True)
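In the new version, `test_belongs`, `test_random_and_belongs` and `test_projection_and_belongs` take their arguments from the matching `belongs_data`, `random_and_belongs_data` and `projection_and_belongs_data` methods of `testing_data`, and the naming suggests the `Parametrizer` metaclass ties each pair together. The sketch below is a hypothetical, stdlib-only analogue of that naming convention, not the `tests.conftest` implementation. Its `TestData`, `Parametrizer`, `TestSquare` and `run_all` are illustrative, and it generates one plain method per case instead of using pytest parametrization:

```python
import inspect


class TestData:
    """Hypothetical stand-in for the TestData base class imported above."""

    def generate_tests(self, smoke_test_data, random_test_data=()):
        # Pair every case dict with the name of the marker it would carry.
        smoke_cases = [("smoke", case) for case in smoke_test_data]
        random_cases = [("random", case) for case in random_test_data]
        return smoke_cases + random_cases


class Parametrizer(type):
    """Hypothetical metaclass: expands test_<name> with testing_data.<name>_data()."""

    def __new__(mcs, name, bases, attrs):
        testing_data = attrs.get("testing_data")
        for attr_name, test_func in list(attrs.items()):
            if not attr_name.startswith("test_") or not callable(test_func):
                continue
            data_func = getattr(testing_data, attr_name[len("test_"):] + "_data", None)
            if data_func is None:
                continue  # No matching *_data method: keep the test unchanged.
            del attrs[attr_name]
            # One generated method per case, e.g. test_square_smoke_0.
            for index, (marker, case) in enumerate(data_func()):
                attrs[f"{attr_name}_{marker}_{index}"] = mcs._bind(test_func, case)
        return super().__new__(mcs, name, bases, attrs)

    @staticmethod
    def _bind(test_func, case):
        # Fail at class-creation time if the case keys do not match the signature.
        inspect.signature(test_func).bind(None, **case)

        def generated(self):
            return test_func(self, **case)

        return generated


class TestSquare(metaclass=Parametrizer):
    class TestDataSquare(TestData):
        def square_data(self):
            smoke_data = [dict(x=2, expected=4), dict(x=-3, expected=9)]
            random_data = [dict(x=0.5, expected=0.25)]
            return self.generate_tests(smoke_data, random_data)

    testing_data = TestDataSquare()

    def test_square(self, x, expected):
        assert x * x == expected


def run_all(test_cls):
    """Call every generated test method and return their names."""
    instance = test_cls()
    names = sorted(n for n in dir(test_cls) if n.startswith("test_"))
    for n in names:
        getattr(instance, n)()
    return names
```

Binding case values by keyword is a choice made for this sketch so that a misnamed key fails early; it says nothing about how the real `Parametrizer` maps data to arguments.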