Fixes and improvements coming from test refactoring (part 3) by luisfpereira · Pull Request #1848 · geomstats/geomstats · GitHub

Fixes and improvements coming from test refactoring (part 3) #1848

Merged: 11 commits, May 5, 2023

Collaborator
@luisfpereira luisfpereira commented May 5, 2023

Follows #1828.

Main additions are:

  • repeat_out: a quick way of repeating an output for vectorization consistency when some of the inputs are not used in the computation

  • AutodiffNotImplementedError exception: will be used later in the tests in a try...except statement, to avoid having to skip tests that depend on autodiff when testing with numpy.

Several vectorization inconsistencies are also fixed.

@luisfpereira luisfpereira requested a review from ninamiolane May 5, 2023 15:00
Collaborator
@ninamiolane ninamiolane left a comment

A first batch of comments!

@@ -4,6 +4,8 @@
The following functions return error messages.
"""

from geomstats.exceptions import AutodiffNotImplementedError
Collaborator

Why do we need a special exception, since the MSG is not in the exception file?

Either:

  1. Put the message in the exception file
    OR:
  2. Only use RuntimeError with this MSG?
    I'd prefer the second one, as it is easier for the reader.

If we keep that tailored exception, maybe change the name to:
UseAutodiffBackend?
Or
AutodiffNotInNumpy?

Collaborator Author
@luisfpereira luisfpereira May 5, 2023

The goal of having a special exception is that we can run the numpy tests inside a block like this:

try:
    # test
except AutodiffNotImplementedError:
    pass

This way we don't have to skip tests that require autodiff (which removes a lot of boilerplate code in the tests and is much nicer for the user).

Regarding naming, I think AutodiffNotInNumpy is a better name than UseAutodiffBackend. Between AutodiffNotInNumpy and AutodiffNotImplementedError I probably prefer the latter: if we ever add another backend that does not support autodiff, we will not have to change the exception name.
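
To make the pattern above concrete, here is a minimal, self-contained sketch. Only the AutodiffNotImplementedError name and its RuntimeError base come from this PR; value_and_grad_stub, run_ignoring_missing_autodiff, and the two example cases are hypothetical stand-ins for illustration, not geomstats code.

```python
class AutodiffNotImplementedError(RuntimeError):
    """Raised when autodiff is not implemented."""


def value_and_grad_stub(func):
    """Hypothetical stand-in for an autodiff entry point on a backend without autodiff."""
    raise AutodiffNotImplementedError(
        "Automatic differentiation is not supported by this backend."
    )


def run_ignoring_missing_autodiff(test_func):
    """Run a test callable; report False if it needed autodiff, True otherwise.

    Any other exception (including a plain RuntimeError) still propagates,
    so real failures are not hidden.
    """
    try:
        test_func()
    except AutodiffNotImplementedError:
        return False
    return True


def needs_autodiff_case():
    # Hypothetical test body that relies on autodiff.
    value_and_grad_stub(lambda x: x**2)


def plain_case():
    # Hypothetical test body that does not rely on autodiff.
    assert 1 + 1 == 2
```

Because the handler catches only the dedicated subclass, an unrelated RuntimeError raised inside a test would still fail that test, which is the practical difference from reusing RuntimeError directly.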



class AutodiffNotImplementedError(RuntimeError):
    """Raised when autodiff is not implemented."""
Collaborator

I preferred the RuntimeError option, which seems less over-engineered, but there might be something I am missing here?

Collaborator Author

see above.

@@ -117,7 +118,7 @@ def _squared_dist(point_a, point_b, metric):
_ : array-like, shape=[...,]
Geodesic distance between point_a and point_b.
"""
- return metric.private_squared_dist(point_a, point_b)
+ return metric._squared_dist(point_a, point_b)
Collaborator

Ah nice! How did you solve this? The very ugly "private_squared_dist" was introduced to get automatic differentiation with custom gradients to work with all backends.

Maybe it was not needed anymore since we dropped TF?

Collaborator Author

This is really just a naming change: I'm marking the method as private with a leading _. Nothing changes from a logical perspective.

codecov bot commented May 5, 2023

Codecov Report

Merging #1848 (b8d49d3) into master (6f289d3) will decrease coverage by 2.87%.
The diff coverage is 91.33%.

@@            Coverage Diff             @@
##           master    #1848      +/-   ##
==========================================
- Coverage   90.10%   87.22%   -2.87%     
==========================================
  Files         131      126       -5     
  Lines       13180    12402     -778     
==========================================
- Hits        11874    10816    -1058     
- Misses       1306     1586     +280     
Flag Coverage Δ
autograd 87.22% <91.33%> (?)
numpy ?
pytorch ?


Impacted Files Coverage Δ
geomstats/_backend/__init__.py 82.26% <ø> (ø)
geomstats/exceptions.py 0.00% <0.00%> (ø)
geomstats/geometry/_hyperbolic.py 96.35% <ø> (+1.90%) ⬆️
geomstats/geometry/discrete_surfaces.py 93.34% <ø> (-4.00%) ⬇️
geomstats/geometry/stratified/wald_space.py 26.39% <0.00%> (-63.82%) ⬇️
geomstats/learning/aac.py 45.60% <0.00%> (-52.80%) ⬇️
.../learning/agglomerative_hierarchical_clustering.py 100.00% <ø> (ø)
geomstats/learning/exponential_barycenter.py 100.00% <ø> (ø)
geomstats/learning/geodesic_regression.py 83.77% <ø> (+3.90%) ⬆️
geomstats/visualization/hypersphere.py 69.92% <ø> (-0.26%) ⬇️
... and 75 more

... and 22 files with indirect coverage changes


-     """Reshape diagonal metric matrix to a symmetric matrix of size n.
+ @property
+ def reshaped_metric_matrix(self):
+     """Diagonal metric matrix reshaped to a symmetric matrix of size n.
Collaborator

Does the rule "Docstring starts with a verb in the infinitive" not apply to @property?

Collaborator Author

I think not, because we can think of properties a bit like attributes.
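
As a small illustration of the convention being discussed, the sketch below puts the two docstring styles side by side. DiagonalMetricSketch and its members are hypothetical names invented for this example; they are not part of geomstats.

```python
class DiagonalMetricSketch:
    """Hypothetical class used only to contrast the two docstring styles."""

    def __init__(self, diagonal):
        self._diagonal = list(diagonal)

    @property
    def diagonal(self):
        """Diagonal entries of the metric matrix."""
        # Noun phrase: a property reads like an attribute of the object.
        return list(self._diagonal)

    def scale(self, factor):
        """Scale the diagonal entries by a factor."""
        # Verb in the infinitive: a method performs an action.
        return DiagonalMetricSketch(d * factor for d in self._diagonal)
```

At the call site the property is accessed without parentheses (metric.diagonal), which is why describing what it is, rather than what it does, reads naturally.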

flat_bp = gs.reshape(base_point, (-1, sphere_embedding_dim))
flat_pt = gs.reshape(point, (-1, sphere_embedding_dim))
flat_log = sphere.metric.log(flat_pt, flat_bp)
batch_shape = get_batch_shape(self._space, point, base_point)
Collaborator

Why "get_batch_shape" and not "batch_shape"? Is this because it is defined as a property somewhere?

Even in that case, I don't think that we use "get" for the other properties defined in the library: I would remove it to keep the coding style consistent and concise.

https://stackoverflow.com/questions/374763/should-i-use-get-set-prefixes-in-python-method-names

Collaborator

And we generally don't use "get" in the whole codebase, thus if we can avoid it altogether, it'll be better.

Collaborator Author
@luisfpereira luisfpereira May 5, 2023

This comes from another PR (see our discussion).

I still think having a verb is better for a function (for me, a noun represents an object, not a callable). Also, as you can see in this example, I write batch_shape = get_batch_shape(self._space, point, base_point). If the function didn't have the verb, I would have to write batch_shape_ = batch_shape(self._space, point, base_point), which is more cumbersome.

Maybe we can change this in another PR if you really don't want it, since it is unrelated.

- return trace_a + trace_b - 2 * trace_prod
+ squared_dist = trace_a + trace_b - 2.0 * trace_prod
+
+ return gs.where(squared_dist < 0.0, 0.0, squared_dist)
Collaborator

Why would the squared_dist be < 0?

Are we sure that it is a numerical issue, and not a problem with the math in the code?

If it is the latter, then the gs.where would hide that bug, which would otherwise be caught by unit tests.

Collaborator Author

In my tests the negative values were always extremely close to zero (e.g. -1e-12). The problem seems to be the tolerance in gs.linalg.sqrtm.

Your point is very good though!
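
One way to reconcile both concerns is to clamp only values that are negative within a tolerance and fail loudly otherwise. The sketch below uses plain numpy rather than the gs backend; the helper name clip_roundoff_negatives and the default atol=1e-10 are assumptions made for illustration and do not appear in this PR.

```python
import numpy as np


def clip_roundoff_negatives(values, atol=1e-10):
    """Set tiny negative entries to zero, but refuse clearly negative ones.

    `atol` is an assumed tolerance: entries in [-atol, 0) are treated as
    round-off and clamped; anything below -atol is reported as an error
    instead of being silently hidden.
    """
    values = np.asarray(values, dtype=float)
    if np.any(values < -atol):
        raise ValueError(
            f"Found values below -{atol}; this looks like a bug, not round-off."
        )
    return np.where(values < 0.0, 0.0, values)
```

With this guard, a value such as -1e-12 is clamped to 0.0, while a value such as -0.1 raises, so a genuine error in the formula would still surface in unit tests.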

Collaborator
@ninamiolane ninamiolane left a comment

Awesome!

There seem to be vectorization errors in the examples: could you fix these?

Other high-level comments:

  • Add small explanations of the vectorization logic in docstrings
  • Why "get" and "_get"? If we can avoid using "get" that would be great; otherwise, what is the clear logic justifying its use? Anything that is "engineering related" and not "math related"?

@@ -106,6 +111,36 @@ def repeat_point(point, n_reps=2, expand=False):
return gs.repeat(gs.expand_dims(point, 0), n_reps, axis=0)


def _is_not_none(value):
Collaborator

Why do we need this?

Collaborator Author

It is used with filter in some of the functions of this module.

out : array-like
If no batch, then input is returned. Otherwise it is broadcasted.
"""
points = filter(_is_not_none, points)
Collaborator

I have trouble reading through this because some private functions do not have docstrings (as is conventional) and the public functions have short docstrings.

Any chance you could add details + 1-2 examples in the docstrings, even if private ones, to explain the logic?

Collaborator

I was also confused by _get_max_ndim and the line point_max_ndim = point[0]

--> why is the ndim an element of point?

Docstrings would help with this.

Collaborator Author

I'll improve the docstrings.

The main idea is that I want these methods to work when a point is None, because a lot of our methods have something like base_point=None, and I then want to call something like repeat_out(self.space, out, base_point) without having to check for None (it greatly simplifies the vectorization logic there).

Therefore, before checking batch_shape, I remove the None entries from the list of points (that's the role of filter).
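
The logic described here can be sketched in plain numpy as follows. This is an illustration under stated assumptions, not the geomstats implementation: the PR's functions take a space as first argument, whereas this sketch takes a hypothetical integer point_ndim (the number of dimensions of a single point), and it assumes that out is not yet batched.

```python
import numpy as np


def _is_not_none(value):
    """Predicate for filter: keep every input that is not None."""
    return value is not None


def get_batch_shape(point_ndim, *points):
    """Return the broadcast batch shape of the points, or () if there is no batch.

    `point_ndim` is an assumed stand-in for the information a space would
    provide: the number of trailing dimensions that belong to a single point.
    """
    batch_shapes = [np.shape(p)[: np.ndim(p) - point_ndim] for p in points]
    if not batch_shapes:
        return ()
    return np.broadcast_shapes(*batch_shapes)


def repeat_out(point_ndim, out, *points):
    """Broadcast an unbatched `out` to the batch shape of the non-None points."""
    points = filter(_is_not_none, points)  # drop optional inputs left as None
    batch_shape = get_batch_shape(point_ndim, *points)
    out = np.asarray(out)
    if not batch_shape:
        return out  # no batch: the input is returned unchanged
    return np.broadcast_to(out, batch_shape + out.shape)
```

For example, with point_ndim=1, calling repeat_out(1, np.array(1.0), None, np.ones((3, 2))) yields an array of shape (3,), while passing a single point np.ones(2) (or only None) leaves the scalar output unchanged.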

If no batch, then input is returned. Otherwise it is broadcasted.
"""
points = filter(_is_not_none, points)
batch_shape = get_batch_shape(space, *points)
Collaborator

Why the "get_" wording?

Collaborator Author

see above.

@@ -26,7 +26,9 @@ class SasakiMetricTestData(TestData):
def inner_product_test_data(self):
_sqrt2 = 1.0 / gs.sqrt(2.0)
base_point = gs.array([[_sqrt2, -_sqrt2, 0], [_sqrt2, _sqrt2, 1]])
_log = self.sas_sphere_metric.log(gs.array([self.pu0, self.pu1]), base_point)
end_point = gs.stack([self.pu0, self.pu1])
Collaborator

NIT: what are pu0 and pu1? What is _sqrt2? The names could be more self-explanatory.

Collaborator Author

This comes from a very old PR. I'll take that into account in the new tests.

@luisfpereira luisfpereira merged commit 62d0ab0 into geomstats:master May 5, 2023
@luisfpereira luisfpereira deleted the fixes-from-test branch May 9, 2023 07:32
2 participants