Fixes and improvements coming from test refactoring (part 3) #1848
…sing new test framework
A first batch of comments!
@@ -4,6 +4,8 @@
The following functions return error messages.
"""

from geomstats.exceptions import AutodiffNotImplementedError
Why do we need a special exception, given that the MSG is not in the exceptions file?
Either:
- Put the message in the exceptions file
- OR: only use RuntimeError with this MSG?
I'd prefer the second one, as it is easier for the reader.
If we keep that tailored exception, maybe change the name to:
UseAutodiffBackend?
Or
AutodiffNotInNumpy?
The goal of having a special exception is that we can run numpy tests under a context:
try:
    # test
except AutodiffNotImplementedError:
    pass
This way we don't have to skip tests that require autodiff (which removes a lot of boilerplate code in the tests and is much nicer for the user).
Regarding naming, I think `AutodiffNotInNumpy` is a better name than `UseAutodiffBackend`. Between `AutodiffNotInNumpy` and `AutodiffNotImplementedError`, I probably prefer the latter, because if we ever add another backend that does not support autodiff, we will not have to change the exception name.
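To make the pattern above concrete, here is a minimal, self-contained sketch. The exception class mirrors the one in the diff; `value_and_grad_stub`, `run_check` and the two check functions are hypothetical names used only for illustration and are not part of geomstats.

```python
class AutodiffNotImplementedError(RuntimeError):
    """Raised when autodiff is not implemented."""


def value_and_grad_stub(func):
    """Hypothetical stand-in for an autodiff entry point on a backend without autodiff."""
    raise AutodiffNotImplementedError(
        "Automatic differentiation is not supported with this backend."
    )


def run_check(check_func):
    """Run a check, treating missing autodiff as 'not applicable' instead of a failure."""
    try:
        check_func()
    except AutodiffNotImplementedError:
        # Only the dedicated exception is swallowed; any other
        # RuntimeError still propagates and fails the check.
        return "not applicable"
    return "passed"


def check_needs_autodiff():
    # Would compute a gradient on an autodiff backend; raises here.
    value_and_grad_stub(lambda x: x**2)


def check_plain():
    assert 1 + 1 == 2


def check_real_bug():
    raise RuntimeError("unrelated failure")
```

Because the handler catches the subclass rather than `RuntimeError` itself, an unrelated `RuntimeError` is not silently ignored, which is one argument for keeping a tailored exception.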
class AutodiffNotImplementedError(RuntimeError):
    """Raised when autodiff is not implemented."""
I preferred the RuntimeError option, which seems less "over-engineered", but there might be something I am missing there?
see above.
@@ -117,7 +118,7 @@ def _squared_dist(point_a, point_b, metric):
     _ : array-like, shape=[...,]
         Geodesic distance between point_a and point_b.
     """
-    return metric.private_squared_dist(point_a, point_b)
+    return metric._squared_dist(point_a, point_b)
Ah nice! How did you solve this? The very ugly "private_squared_dist" was introduced to get automatic differentiation with custom gradients to work with all backends.
Maybe it was not needed anymore since we dropped TF?
This is really just a naming change: I'm marking private methods with a leading `_`. Nothing changes from a logical perspective.
Codecov Report
@@ Coverage Diff @@
## master #1848 +/- ##
==========================================
- Coverage 90.10% 87.22% -2.87%
==========================================
Files 131 126 -5
Lines 13180 12402 -778
==========================================
- Hits 11874 10816 -1058
- Misses 1306 1586 +280
... and 22 files with indirect coverage changes
"""Reshape diagonal metric matrix to a symmetric matrix of size n.
@property
def reshaped_metric_matrix(self):
"""Diagonal metric matrix reshaped to a symmetric matrix of size n.
Does the rule "Docstring starts with a verb in the infinitive" not apply to @property?
I think not, because we can think of properties as being a bit like attributes.
flat_bp = gs.reshape(base_point, (-1, sphere_embedding_dim))
flat_pt = gs.reshape(point, (-1, sphere_embedding_dim))
flat_log = sphere.metric.log(flat_pt, flat_bp)
batch_shape = get_batch_shape(self._space, point, base_point)
Why "get_batch_shape" and not "batch_shape"? Is this because it is defined as a property somewhere?
Even in that case, I don't think we use "get" for the other properties defined in the library: I would remove it to keep the coding style consistent and concise.
https://stackoverflow.com/questions/374763/should-i-use-get-set-prefixes-in-python-method-names
And we generally don't use "get" in the whole codebase, thus if we can avoid it altogether, it'll be better.
This comes from another PR (see our discussion).
I still think having a verb is better for a function (for me, a noun names an object, not a callable). Also, as you can see in this example, I write `batch_shape = get_batch_shape(self._space, point, base_point)`. If the function didn't have the verb, I would have to write `batch_shape_ = batch_shape(self._space, point, base_point)`, which is more cumbersome.
Maybe we can change this in another PR if you really don't want it, since it is unrelated.
-    return trace_a + trace_b - 2 * trace_prod
+    squared_dist = trace_a + trace_b - 2.0 * trace_prod
+
+    return gs.where(squared_dist < 0.0, 0.0, squared_dist)
Why would the squared_dist be < 0?
Are we sure that it is a numerical issue, and not a problem in the math in the code?
If it is the latter, then the gs.where would hide that bug, which would otherwise be caught by unit tests.
In my tests the negative values were always extremely close to zero (e.g. -1e-12). The problem seems to be the tolerance in `gs.linalg.sqrtm`.
Your point is very good though!
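One way to reconcile both views is to clamp only values that are negative within a tolerance and to fail loudly otherwise. The sketch below uses plain NumPy rather than the `gs` backend; `clamp_small_negatives` and its `atol` default are hypothetical and are not what this PR implements.

```python
import numpy as np


def clamp_small_negatives(squared_dist, atol=1e-8):
    """Set tiny negative values to zero, but refuse clearly negative ones.

    Hypothetical helper: round-off (e.g. -1e-12) is clamped, while a value
    below ``-atol`` is treated as a sign of a real bug and raises.
    """
    squared_dist = np.asarray(squared_dist, dtype=float)
    if np.any(squared_dist < -atol):
        raise ValueError(
            "Squared distance is negative beyond numerical tolerance."
        )
    # Same clamping expression as in the diff, written with NumPy.
    return np.where(squared_dist < 0.0, 0.0, squared_dist)
```

With such a guard, a value like -1e-12 is silently set to zero, whereas a value like -0.1 would still surface in unit tests.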
Awesome!
There seem to be vectorization errors in the examples: fix these?
Other high-level comments:
- Add small explanations of the vectorization logic in docstrings
- Why "get" and "_get"? If we can avoid using "get", that would be great; otherwise, what is the clear logic justifying its use? Anything that is "engineering related" and not "math related"?
@@ -106,6 +111,36 @@ def repeat_point(point, n_reps=2, expand=False):
    return gs.repeat(gs.expand_dims(point, 0), n_reps, axis=0)


def _is_not_none(value):
Why do we need this?
To use with `filter` in some of the functions of this module.
geomstats/vectorization.py
out : array-like
If no batch, then input is returned. Otherwise it is broadcasted.
"""
points = filter(_is_not_none, points)
I have trouble reading through this because some private functions do not have docstrings (as is conventional) and the public functions have short docstrings.
Any chance you could add details + 1-2 examples in the docstrings, even if private ones, to explain the logic?
I was also confused by `_get_max_ndim` and the line `point_max_ndim = point[0]`: why is the ndim an element of point?
Docstrings would help with this.
I'll improve the docstrings.
The main idea is that I want these methods to work when `point` is `None`, because in a lot of our methods we have something like `base_point=None`, and then I want to use something like `repeat_out(self.space, out, base_point)` without having to check whether there's a `None` (it greatly simplifies the vectorization logic there).
Therefore, before checking `batch_shape`, I remove the `None` entries from the list of points (that's the role of `filter`).
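The idea described above can be sketched with plain NumPy. This is only an illustration under stated assumptions: `batch_shape_sketch` and `repeat_out_sketch` are hypothetical names, and an integer `point_ndim` (the number of trailing axes that make up a single point) stands in for the `space` argument of the real functions, whose signatures may differ.

```python
import numpy as np


def _is_not_none(value):
    """Return True for anything except None (used as a ``filter`` predicate)."""
    return value is not None


def batch_shape_sketch(point_ndim, *points):
    """Broadcast the leading (batch) axes of the given points."""
    if not points:
        return ()
    batch_shapes = [point.shape[: point.ndim - point_ndim] for point in points]
    return np.broadcast_shapes(*batch_shapes)


def repeat_out_sketch(point_ndim, out, *points):
    """Broadcast a scalar-per-point ``out`` to the batch shape of the points.

    ``None`` entries are dropped first, so callers can pass optional
    arguments such as ``base_point=None`` without checking them.
    """
    points = list(filter(_is_not_none, points))
    batch_shape = batch_shape_sketch(point_ndim, *points)
    return np.broadcast_to(out, batch_shape)


# A scalar result, one missing point and a batch of three 2D base points.
out = np.array(2.0)
base_point = np.zeros((3, 2))
repeated = repeat_out_sketch(1, out, None, base_point)
```

Here `repeated` has one entry per base point, even though `out` itself did not depend on `base_point`; when every point is `None`, the batch shape is empty and `out` keeps its shape.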
geomstats/vectorization.py
If no batch, then input is returned. Otherwise it is broadcasted.
"""
points = filter(_is_not_none, points)
batch_shape = get_batch_shape(space, *points)
Why the "get_" wording?
see above.
@@ -26,7 +26,9 @@ class SasakiMetricTestData(TestData):
    def inner_product_test_data(self):
        _sqrt2 = 1.0 / gs.sqrt(2.0)
        base_point = gs.array([[_sqrt2, -_sqrt2, 0], [_sqrt2, _sqrt2, 1]])
        _log = self.sas_sphere_metric.log(gs.array([self.pu0, self.pu1]), base_point)
        end_point = gs.stack([self.pu0, self.pu1])
NIT: what is pu0, pu1? what is sqrt2? the names could be more self explanatory.
This comes from a very old PR. I'll take that into account in the new tests.
Follows #1828.
Main additions are:
- `repeat_out`: quick way of repeating an output for vectorization consistency when some of the inputs are not used in the computation
- `AutodiffNotImplementedError` exception: will be used later in the tests in a `try...except` statement to avoid having to skip tests that depend on autodiff when testing for numpy.
Several vectorization inconsistencies are also fixed.