Fix discrete curves vectorization and remove empty quotient metrics by luisfpereira · Pull Request #1935 · geomstats/geomstats · GitHub

Fix discrete curves vectorization and remove empty quotient metrics #1935


Merged: @luisfpereira merged 2 commits into geomstats:main on Jan 22, 2024.


@luisfpereira (Collaborator) commented on Jan 19, 2024:

This PR follows up on #1924 to:

  • fix the vectorization of `IterativeHorizontalGeodesicAligner.discrete_horizontal_geodesic`

  • raise `NotImplementedError` in the new fiber bundles to avoid recursion errors when calling `log` on the quotient metrics

  • improve `SRVReparametrizationBundle` tests:

    • by removing skips
    • by testing both even and odd `k_sampling_points`, since the Euler step behaves differently in each case (this should probably be moved to a dedicated aligner test later)
  • improve discrete curves tests by adding tests for all quotient metrics

  • remove empty quotient metrics (closes #1931, "On the creation of empty quotient metrics")
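The second bullet refers to a recursion trap that appears when two methods delegate to each other. The sketch below is a hypothetical, self-contained illustration, not geomstats code: the classes `LoopingBundle`, `GuardedBundle`, and `QuotientMetric` and their toy `align`/`log` bodies are assumptions. It shows why raising `NotImplementedError` in a bundle fails fast with a clear message instead of overflowing the stack.

```python
class LoopingBundle:
    """Hypothetical bundle whose default align delegates back to the quotient metric."""

    def __init__(self):
        self.quotient_metric = None

    def align(self, point, base_point):
        # Naive default: align by shooting the quotient log from base_point.
        # The quotient log itself calls align, so the two recurse forever.
        return base_point + self.quotient_metric.log(point, base_point)


class GuardedBundle(LoopingBundle):
    """Hypothetical bundle that fails fast instead of recursing."""

    def align(self, point, base_point):
        raise NotImplementedError("align is not implemented for this fiber bundle")


class QuotientMetric:
    """Hypothetical quotient metric whose log aligns in the total space first."""

    def __init__(self, bundle):
        self.bundle = bundle
        bundle.quotient_metric = self

    def log(self, point, base_point):
        return self.bundle.align(point, base_point) - base_point


def outcome(bundle_cls):
    """Report which error, if any, calling log on the quotient metric raises."""
    metric = QuotientMetric(bundle_cls())
    try:
        metric.log(1.0, 0.0)
    except RecursionError:
        return "RecursionError"
    except NotImplementedError:
        return "NotImplementedError"
    return "ok"
```

With the looping default, `outcome(LoopingBundle)` only surfaces a `RecursionError` after exhausting the stack; `outcome(GuardedBundle)` reports the missing implementation immediately.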

…ectorization; remove empty quotient metrics; improve discrete curves testing
@codecov bot commented on Jan 19, 2024:

Codecov Report

Attention: 2 lines in your changes are missing coverage. Please review.

Comparison: base (c887e2d) at 90.51% vs. head (2bc94fd) at 86.76%.

| Files | Patch % | Lines |
|---|---|---|
| geomstats/geometry/discrete_curves.py | 71.43% | 2 Missing ⚠️ |
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #1935      +/-   ##
==========================================
- Coverage   90.51%   86.76%   -3.74%
==========================================
  Files         143      137       -6
  Lines       13495    12898     -597
==========================================
- Hits        12214    11190    -1024
- Misses       1281     1708     +427
```
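The percentages in the report can be re-derived from the raw counts. This is a small sketch assuming coverage is simply hits divided by lines, which holds here because there are no partially covered lines (hits plus misses equals lines in both columns). The 7 patch lines are inferred, not stated: 2 missing lines at 71.43% patch coverage imply 5 of 7 covered.

```python
# Counts copied from the coverage diff above.
base = {"lines": 13495, "hits": 12214}
head = {"lines": 12898, "hits": 11190}


def coverage(report):
    """Coverage in percent, assuming no partial lines (hits / lines)."""
    return 100.0 * report["hits"] / report["lines"]


def misses(report):
    """Uncovered lines: every line that is not a hit."""
    return report["lines"] - report["hits"]


# Inferred from "71.43%, 2 Missing": 5 covered out of 7 patch lines.
patch_lines, patch_missing = 7, 2
patch_coverage = 100.0 * (patch_lines - patch_missing) / patch_lines
```

Rounded to two decimals, these give the 90.51%, 86.76%, and 71.43% figures in the report.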
| Flag | Coverage Δ |
|---|---|
| autograd | 86.76% <71.43%> (?) |
| numpy | ? |
| pytorch | ? |

Flags with carried-forward coverage won't be shown.


@alebrigant (Collaborator) left a comment:

Thanks, Luis!

Review thread on hunk `@@ -1314,6 +1314,9 @@ def discrete_horizontal_geodesic(`:

```python
if initial_point.ndim != end_point.ndim:
    initial_point, end_point = gs.broadcast_arrays(initial_point, end_point)

if callable(end_spline):
```
@alebrigant (Collaborator):

Isn't this the same condition as `not is_batch`?

@luisfpereira (Collaborator, Author) replied on Jan 22, 2024:

The issue is that the `check_is_batch` function in `align` only considers `point`. So, for the use case of one `point` and multiple `base_point`s, it creates a single spline instead of a list, which breaks the `for` loop within `discrete_horizontal_geodesic`.

I could have fixed it in `align`, but I didn't think it was worth creating multiple identical splines. The `callable` check therefore ensures the code works properly for the use case above.

To make the code more readable, we could either change `align` and keep the `callable` check:

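```python
if point.ndim == bundle.total_space.point_ndim:
    # single point: one spline, even when base_point is a batch
    spline = bundle.total_space.interpolate(point)
else:
    # batch of points: one spline per point
    spline = [bundle.total_space.interpolate(point_) for point_ in point]
```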

or change `align` and remove the `callable` check:

```python
if point.ndim == bundle.total_space.point_ndim:
    spline = bundle.total_space.interpolate(point)
    if base_point.ndim > bundle.total_space.point_ndim:
        # one point, many base points: reuse the same spline for each
        spline = [spline] * base_point.shape[0]
else:
    # batch of points: one spline per point
    spline = [bundle.total_space.interpolate(point_) for point_ in point]
```

Which solution do you prefer?
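The second option can be exercised outside geomstats with a small stand-in. In this sketch, `interpolate`, `POINT_NDIM`, and `build_splines` are hypothetical names: `interpolate` is a piecewise-linear replacement for `bundle.total_space.interpolate`, and `POINT_NDIM = 2` assumes a discrete curve of shape `(k_sampling_points, ambient_dim)`. It is an illustration of the branching logic, not the code merged in this PR.

```python
import numpy as np

POINT_NDIM = 2  # hypothetical: one curve has shape (k_sampling_points, ambient_dim)


def interpolate(curve):
    """Hypothetical stand-in for bundle.total_space.interpolate (piecewise linear)."""
    s = np.linspace(0.0, 1.0, curve.shape[0])

    def spline(t):
        # Interpolate each ambient coordinate independently at parameter t.
        return np.stack(
            [np.interp(t, s, curve[:, d]) for d in range(curve.shape[1])], axis=-1
        )

    return spline


def build_splines(point, base_point):
    """Second option: return a list whenever either input is a batch."""
    if point.ndim == POINT_NDIM:
        spline = interpolate(point)
        if base_point.ndim > POINT_NDIM:
            # One point, several base points: repeat references to one spline.
            spline = [spline] * base_point.shape[0]
    else:
        spline = [interpolate(point_) for point_ in point]
    return spline
```

For a single curve against a batch of three, `build_splines` returns a list of three references to one spline object, so no identical splines are recomputed while the downstream `for` loop still receives a list.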

@alebrigant (Collaborator):

Ah, I see! I prefer the second solution because it's easier to understand when reading the code.
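The hunk under review broadcasts `initial_point` and `end_point` to a common shape, and the thread explains that a single spline breaks the per-curve `for` loop. This toy sketch illustrates both points, assuming `np.broadcast_arrays` behaves like `gs.broadcast_arrays` for these shapes; `end_spline` is a hypothetical stand-in for a spline object.

```python
import numpy as np

# Hypothetical shapes: one initial curve vs. a batch of 3 end curves.
k_sampling_points, ambient_dim, n_curves = 5, 2, 3
initial_point = np.zeros((k_sampling_points, ambient_dim))
end_point = np.ones((n_curves, k_sampling_points, ambient_dim))

# Same pattern as the hunk, with NumPy standing in for gs.broadcast_arrays.
if initial_point.ndim != end_point.ndim:
    initial_point, end_point = np.broadcast_arrays(initial_point, end_point)


def end_spline(t):
    """Hypothetical single spline: a plain callable, not a list."""
    return t * np.ones((k_sampling_points, ambient_dim))


# A bare callable cannot drive a per-curve loop ...
try:
    iter(end_spline)
    single_callable_iterates = True
except TypeError:
    single_callable_iterates = False

# ... whereas one spline per curve can be zipped with the broadcast points.
end_splines = [end_spline] * end_point.shape[0]
n_iterations = sum(1 for _ in zip(initial_point, end_point, end_splines))
```

After broadcasting, both arrays have shape `(3, 5, 2)`, and only the list form of the splines yields one iteration per curve.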

Review thread on the `srv_reparameterization_bundles` fixture:

```python
        (random.randint(2, 3), random.choice([5, 7, 9])),
    ],
)
def srv_reparameterization_bundles(request):
```
@alebrigant (Collaborator):

Why is this necessary, given that no similar function is defined for the other bundles, e.g., the rotation-reparametrization bundle? I have to admit that, for me, this code and the way it is used in the next class are hard to read.

@luisfpereira (Collaborator, Author) replied on Jan 22, 2024:

This is a pytest trick to reduce test verbosity: for each entry in `params`, pytest creates a separate instance of the `TestSRVReparametrizationBundle` tests.

More concretely, the fixture parametrizes over `ambient_dim` and `k_sampling_points`. Each test run draws two random parameter pairs, one with an even and one with an odd `k_sampling_points` (each with a random `ambient_dim` of 2 or 3), and runs all the tests in `TestSRVReparametrizationBundle` twice, once per pair.

Without fixtures, we would have to create classes such as `TestSRVReparametrizationBundleOddKSamplingPoints` and `TestSRVReparametrizationBundleEvenKSamplingPoints`, which would differ only in the integer passed to `k_sampling_points`.
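A minimal sketch of the parametrized, class-scoped fixture pattern described above. The fixture `curve_params`, the class `TestCurveSketch`, and the even choices `[4, 6, 8]` are hypothetical (only the odd tuple is visible in the reviewed hunk); `request.param` and `request.cls` are standard pytest fixture attributes.

```python
import random

import pytest

# Hypothetical grid: one pair with even k_sampling_points, one with odd,
# each with ambient_dim drawn from {2, 3}.
PARAMS = [
    (random.randint(2, 3), random.choice([4, 6, 8])),
    (random.randint(2, 3), random.choice([5, 7, 9])),
]


@pytest.fixture(scope="class", params=PARAMS)
def curve_params(request):
    # request.param is one (ambient_dim, k_sampling_points) pair; attach it
    # to the test class so every test method can read it.
    request.cls.ambient_dim, request.cls.k_sampling_points = request.param


@pytest.mark.usefixtures("curve_params")
class TestCurveSketch:
    # pytest runs every test in this class once per entry of PARAMS.
    def test_ambient_dim(self):
        assert self.ambient_dim in (2, 3)

    def test_enough_sampling_points(self):
        assert self.k_sampling_points >= 4
```

Running `pytest` on this file reports four tests, two methods for each of the two parameter pairs, without any duplicated test class.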

This also justifies introducing the `TestCase` notion: the same tests can be reused for different parameter combinations, even without pytest's parametrization feature, simply by inheriting from the appropriate `TestCase` child class.

(Writing about all this is definitely in my todo list.)
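The inheritance-based reuse mentioned in the last paragraph can be sketched as follows. `CurveTestCase` and the two subclasses are hypothetical names, not geomstats's actual `TestCase` hierarchy. Because the base class name does not start with `Test`, pytest collects only the subclasses, each of which inherits the shared test.

```python
class CurveTestCase:
    """Hypothetical shared test case; pytest does not collect it directly."""

    k_sampling_points = None  # set by each concrete subclass

    def test_uniform_grid(self):
        # A uniform grid on [0, 1] with k_sampling_points nodes.
        step = 1.0 / (self.k_sampling_points - 1)
        grid = [i * step for i in range(self.k_sampling_points)]
        assert len(grid) == self.k_sampling_points
        assert abs(grid[-1] - 1.0) < 1e-12


class TestEvenKSamplingPoints(CurveTestCase):
    k_sampling_points = 6


class TestOddKSamplingPoints(CurveTestCase):
    k_sampling_points = 5
```

Each subclass differs only in the class attribute, which is exactly the duplication the parametrized fixture avoids when the set of combinations is generated rather than hand-written.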


@luisfpereira merged commit 80b188b into geomstats:main on Jan 22, 2024.
@luisfpereira deleted the fix-discrete branch on January 22, 2024 at 16:11.
Development: successfully merging this pull request may close the issue "On the creation of empty quotient metrics".
2 participants