ENH: special: dispatch to array library #19023
Conversation
This is great, thanks for getting the ball rolling here Matt!
What's happening here is that currently the ufuncs will convert the PyTorch tensor to a NumPy array, call the NumPy implementation, and convert back.
I think we're going to need that mapping, since I don't think any other library implements all of the ~380 functions. Special functions would in principle be a useful array API standard extension, a la
@mdhaber would you be interested in hashing out a proposal for a
Yes, I considered doing that. I looked into GitHub code search to collect usage data for the functions the libraries have in common and to determine what other functions should be prioritized. Are there other tools you'd recommend? Also, IIRC the @array_api_compatible decorator currently tests only a few backends. Presumably we don't need to run tests of all these functions on all backends on every CI run. But how should they be tested? I can't easily do them all locally. As for what the test would look like, I was thinking of something like https://github.com/mdhaber/mparray/blob/70b134492a0d6fc08dbf76093ca61368f681000d/mparray/special/tests/test_special.py#L41 but probably with Hypothesis, and certainly without fixed magic numbers. (Those were a shortcut to getting something working.) The idea is just to make sure the functions that are supposed to be the same really match and that arguments are in the right order, not to check accuracy.
May be worth giving https://github.com/data-apis/python-record-api a look. An excerpt from Aaron's slides from SciPy 2023: And for finding out which functions the libraries have in common: https://github.com/data-apis/array-api-comparison
That
That sounds useful indeed. Using
Thanks @rgommers.
I think CI also tests
elif xp.__name__ == f"{array_api_compat_prefix}.jax":
    f = getattr(xp.scipy.special, f_name, None)
else:
    f = None
I was going to return a wrapper function that converts to a NumPy array, evaluates the function, and converts back, but that would never happen: if the array is not recognized, array_namespace raises an error.
I don't think that's correct? There is:
elif hasattr(x, '__array_namespace__'):
namespaces.add(x.__array_namespace__(api_version=api_version))
so any object with a __array_namespace__
will be accepted by design.
Agreed that an error is not raised. That code-snippet is the case which handles numpy.array_api currently. Converting to NumPy and back is what the cluster and fft PRs do.
I think I didn't try a real array type; I got the error when providing a list as the argument like @rgommers reported in #19023 (review). I'll try returning the wrapper.
Now returning a wrapper that does the conversions and evaluates w/ SciPy normally.
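The round-trip wrapper described above can be sketched as follows. Everything here is illustrative: `_SPECIAL_FUNCS` and `get_array_special_func` as written are stand-ins (with `np.exp` in place of a real `scipy.special` ufunc to keep the sketch dependency-free), not the PR's actual code.

```python
import numpy as np

# Stand-in table for scipy.special ufuncs; np.exp keeps the sketch
# free of a SciPy dependency.
_SPECIAL_FUNCS = {'exp': np.exp}

def get_array_special_func(f_name, xp):
    # Hypothetical sketch of the fallback branch: when xp has no native
    # implementation, return a wrapper that converts the inputs to NumPy,
    # evaluates with the (stand-in) SciPy function, and converts back.
    def wrapped(*args, **kwargs):
        args_np = [np.asarray(arg) for arg in args]          # xp -> NumPy
        result = _SPECIAL_FUNCS[f_name](*args_np, **kwargs)  # evaluate
        return xp.asarray(result)                            # NumPy -> xp
    return wrapped

# With xp = NumPy itself, the round trip is a no-op:
f = get_array_special_func('exp', np)
print(float(f(np.asarray(0.0))))  # 1.0
```

With a real alternative namespace (e.g. `array_api_compat.torch`), `xp.asarray(result)` is what hands the data back to the caller's library.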
xp = array_namespace(*args[:n_args])
f = get_array_special_func(f_name, xp)
return f(*args, **kwargs)
wrapped.__doc__ = getattr(_ufuncs, f_name).__doc__
Should probably do something with the docs to mention CuPy/ support.
Is it OK if it's something generic advising of experimental support that I could add programmatically with _docscrape.py? Should it be in the notes or in a more prominent Note admonition? Suggestions on what it should say?
Probably best not to do that in this decorator. Programmatically appending notes is what we tried in various flavors before, and it usually doesn't work well without a ton of effort. For now we can keep it in the module-level docstring for special, I'd say.
So at:
https://docs.scipy.org/doc/scipy/reference/special.html
put a list of functions that can be dispatched to the array API equivalents? Since I would personally put it into the function documentation (with docscrape), I don't have a vision for what that would look like. Would it be a note at the top? Mention that users can enable support for CuPy and PyTorch with the SCIPY_ARRAY_API environment variable?
@rgommers let me know how you were thinking this documentation should go and I'll add it here. Or if you're open to seeing how docscrape would do, LMK what you'd like to see in the docs (notes section, or more prominently?) of each wrapped function.
Maybe something like
This function has preliminary support for CuPy, PyTorch, and JAX arrays. When the environment variable SCIPY_ARRAY_API=1 is set, arrays of these types are automatically passed to the appropriate function of the native library for calculation, and it will return an array of the same type. In this case, only the first positional argument is supported. Other arguments will be passed to the underlying function, but the behavior is not tested.
"first positional argument" would be replaced by "first X positional arguments" where appropriate.
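As a rough illustration of customizing such a note per function, here is a plain-textwrap sketch (not SciPy's _docscrape.py machinery; `_append_support_note` and `_NOTE_TEMPLATE` are hypothetical names):

```python
import textwrap

_NOTE_TEMPLATE = """
This function has preliminary support for CuPy, PyTorch, and JAX arrays.
When the environment variable SCIPY_ARRAY_API=1 is set, arrays of these
types are passed to the native library and an array of the same type is
returned. Only the first {n} positional argument(s) are supported.
"""

def _append_support_note(func, n_args):
    # Append the snippet, customized with the n_args the wrapper
    # already knows about.
    note = _NOTE_TEMPLATE.format(n=n_args)
    func.__doc__ = (func.__doc__ or "") + textwrap.dedent(note)
    return func

def demo(x):
    """Compute something."""
    return x

demo = _append_support_note(demo, 1)
print("preliminary support" in demo.__doc__)  # True
```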
That looks like a nice doc snippet to me. The last two sentences may perhaps be confusing to the average reader. In most cases there aren't more than one (or X) arguments, right? So everything is actually tested in those cases, and those sentences may be left out?
There is usually an out argument, and some functions have a complex argument, but it's not consistent. These are not tested, and they won't work with all libraries, so I think it should be stated. The intent was to write something that is easy to adapt to all functions. If I use docscrape to add this comment to the notes, it's easy to customize the message to say "first positional argument" or "first X positional arguments" because the wrapper already has that information. It's less easy to customize statements about out, complex, or other parameters without manually adding a table with that information - but I could if it's important.
ref = f(*args_np) | ||
res = f(*args_xp) | ||
# having trouble with this locally | ||
assert type(res) == type(xp.asarray([])) |
Is there a better way to check that the array type is correct?
I don't think type(x) is actually formally guaranteed to work, but I think it always will in practice. I'm not sure why this doesn't work for you (it does seem to for me), but maybe this does?
isinstance(x, type(xp.asarray([])))
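A quick NumPy-only illustration of the two checks (NumPy stands in here for an arbitrary array namespace xp):

```python
import numpy as np

xp = np  # stand-in for any array-API namespace
res = xp.asarray([1.0, 2.0])

# Exact-type comparison, as in the test above:
assert type(res) == type(xp.asarray([]))

# isinstance-based variant; also accepts subclasses of the array type:
assert isinstance(res, type(xp.asarray([])))
print("both checks pass")
```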
Are you just trying to check that an array of the correct namespace is returned here? If so, I am using this assertion written by @tylerjereddy, which I have added to _array_api.py in #19005.
def _assert_matching_namespace(actual, expected):
    expected_space = array_api_compat.array_namespace(expected)
    if isinstance(actual, tuple):
        for arr in actual:
            arr_space = array_api_compat.array_namespace(arr)
            assert arr_space == expected_space
    else:
        actual_space = array_api_compat.array_namespace(actual)
        assert actual_space == expected_space
OK, this is another thing I'll change after gh-19005 merges.
args_xp = [xp.asarray(arg) for arg in args_np]

ref = f(*args_np)
res = f(*args_xp)
Should we test the out argument?
out is tricky. It's not part of the array API standard. It will work in principle with PyTorch and CuPy, but not with JAX. It's not used a lot, and there are good reasons to avoid it. So I'd suggest to keep out= specific to numpy arrays.
OK, so I won't test it, but hopefully it is covered by tests that were already written for scipy.special. (I've checked that it works, though.)
import numpy as np
from scipy import special
x = np.asarray([1])
y = np.asarray([1.])
special.ndtr(x, y)
print(y) # [0.84134475]
@rgommers does this look closer to what you had in mind?
This looks close to what we want I think, thanks @mdhaber.
One thing seems wrong with the array type handling:
>>> special.ndtri_exp([-3])
array([-1.64692172])
>>> special.ndtr([-3])
Traceback (most recent call last):
...
TypeError: The input is not a supported array type
For array-like input we should preserve the current behavior.
# having trouble with this locally
assert type(res) == type(xp.asarray([]))
assert res.shape == ref.shape
assert_allclose(np.asarray(res), ref)
Unless I am missing something, this np.asarray will fail for CuPy and PyTorch GPU. There has been discussion in #19005 of how best to go about testing on GPU. I have added this utility to conftest.py:
def set_assert_allclose(xp=None):
    if xp is None:
        return npt.assert_allclose
    if 'cupy' in xp.__name__:
        return xp.testing.assert_allclose
    elif 'torch' in xp.__name__:
        return xp.testing.assert_close
    return npt.assert_allclose
This utility is called before every use of assert_allclose:
def test_ifft(self, xp):
    x = xp.asarray(random(30) + 1j*random(30))
    _assert_allclose = set_assert_allclose(xp)
    _assert_allclose(x, fft.ifft(fft.fft(x)))
OK, I'll also implement this if gh-19005 merges first. But what will be the canonical way to convert any Array API array to a different Array API array (if not xp.asarray, where xp is the desired type)?
In gh-19005 I see:
xp = array_namespace(x)
x = np.asarray(x)
...
return xp.asarray(y)
and this is what was suggested in gh-18668, right?
As far as I can tell, this shouldn't be needed outside of tests. For tests, the use of set_assert_allclose is a deliberate attempt to avoid this as well. The conversion to np and back is just so that we can make use of SciPy's existing functions, which only work for NumPy arrays. In general, there shouldn't be a need to convert between different array libraries.
I think there was confusion here because I was lumping np in with other Array API arrays. I guess that is not correct.
You wrote
Unless I am missing something, this np.asarray will fail for CuPy and PyTorch GPU
In that case, my question was how should we be converting from any xp array to a np array outside of tests? (And until gh-19005 merges, I will do the conversion this way in tests before using np.testing.assert_allclose, too.)
Ah, sorry for the confusion. We just use np.asarray, with import numpy as np. This is not the array API compatible version, but that doesn't matter since this is our 'last resort'. We expect it to raise exceptions, with GPU arrays for example. This can't be fixed without a general dispatch system, but that is out of scope for the array API work, hence your efforts to call to some particular libraries here.
Do you have an example of a conversion which you think should be possible? As far as I understand, the array API is supposed to be completely agnostic of the underlying library. Are you thinking of something like a general to_numpy function? This is something unsuitable for the array API, as it should hold up fine regardless of the status of one particular project. I.e. forcing libraries to interact with other libraries in order to comply is undesirable.
Do you have an example of a conversion which you think should be possible?
I would expect torch.asarray(np.asarray([1, 2, 3])) to return a PyTorch array, and np.asarray(torch.asarray([1, 2, 3])) to return a NumPy array. Indeed, they do if they are both imported from array_api_compat, so I would expect the same if we were to replace np and torch with any two Array libraries imported from array_api_compat.
Are you thinking of something like a general to_numpy function?
I was expecting that np.asarray, where np is an Array API compatible version of NumPy, would convert any Array API array to a NumPy array. Note that in the referenced code, np is defined by:
from scipy._lib.array_api_compat.array_api_compat import numpy as np
in case that makes a difference.
Got you 👍. I am not so sure, but I believe that the blocker for this is device transfers. Imagine two array libraries which use different devices, both of which do not exist yet, but do in this hypothetical future.
How could we standardise the ability to transfer between these devices? It seems to me that you are hoping for the ability to transfer between any devices which could possibly be used by libraries which support the standard. Since the standard does not place any restrictions on the devices a library supports, I don't think that this is possible.
I think that we're talking about the same thing here, but apologies if I've misunderstood.
cc @rgommers for clarification as I may also be misunderstanding the standard.
I would expect torch.asarray(np.asarray([1, 2, 3])) to return a PyTorch array, and np.asarray(torch.asarray([1, 2, 3])) to return a NumPy array
Yes indeed - when the data is on the same device.
How could we standardise the ability to transfer between these devices?
We can't really. At least, we haven't quite figured out how to best do this for testing purposes specifically. We do know that we never want to do this for actual library code.
Are we on the same page here now? The current PR diff says:
# TODO: use `set_assert_allclose` when gh-19005 merges
which sounds right (and if not, we'll have figured this out on gh-19005).
We do know that we never want to do this for actual library code.
You asked that the else case be to convert to NumPy, perform the computation, and convert back (#19023 (comment)). So even for some Array API compatible arrays, this will just fail, and users will get whatever error np.asarray gives them?
TypeError: The input is not a supported array type
For array-like input we should preserve the current behavior.
This error is coming from array_api_compat: https://github.com/data-apis/array-api-compat/blob/546fa3d65dde1cc20c68b2d22241f7977a8c08fa/array_api_compat/common/_helpers.py#L102-L104
@rgommers I believe that, following the cluster PR, we actually do not support array-likes when SCIPY_ARRAY_API is set. SciPy's array_namespace calls compliance_scipy, which calls is_array_api_obj from array_api_compat, which returns False for array-likes:
def is_array_api_obj(x):
    """
    Check if x is an array API compatible array object.
    """
    return _is_numpy_array(x) \
        or _is_cupy_array(x) \
        or _is_torch_array(x) \
        or hasattr(x, '__array_namespace__')
compliance_scipy raises TypeError: Only support Array API compatible arrays in this situation. So, if we do want to support array-likes (at least for now), I think a change is needed to compliance_scipy.
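One way the array-like fallback could look is sketched below. This is hypothetical: `supports_array_api` is a simplified stand-in for array_api_compat's is_array_api_obj, and np.exp stands in for a scipy.special ufunc.

```python
import numpy as np

def supports_array_api(x):
    # Simplified stand-in: NumPy arrays and objects exposing
    # __array_namespace__ pass; lists and other array-likes do not.
    return isinstance(x, np.ndarray) or hasattr(x, '__array_namespace__')

def dispatching_exp(x):
    # Hypothetical dispatch shim: coerce array-likes to NumPy instead of
    # raising, preserving the pre-array-API behavior for list input.
    if not supports_array_api(x):
        x = np.asarray(x)
    return np.exp(x)

print(dispatching_exp([0.0]))  # [1.]
```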
This is the error you get when you pass a list. Maybe what we need to do is skip all this stuff unless the environment variable is set.
This already happens. However, my comments above pertain to how to treat array-likes when the environment variable is set. This is what may involve a change to compliance_scipy.
Erg, well, as written in a11333c, this also means that the docstring is only modified when the environment variable is set.
OK, I think the suggestions will resolve the lint and MyPy issues. Maybe we can figure out how to do the documentation in a follow-up.
That could perhaps be handled by setting
Amending docstrings in
That seems like a good idea indeed.
OK. Well, aside from needing #19023 (comment), CI seems to be happy. Then I guess we can just wait for #19186 to get the stricter array checks. After this merges, I can open an issue about what to do about the ufunc-ness of these functions - whether we are OK with these no longer being ufuncs, or if there is some way to get the best of both worlds.
Thanks @steppi (and @mdhaber)! Just circling back here, I'm still keen to help. Did this tracking issue get created?
Certainly, will soon create said issue. I agree, it's at least worth a proper discussion. Thanks for taking the time to respond earlier!
Not yet, but I was actually going to email you this weekend about a real-time meeting, if you're up for that. I'll also add you to the spreadsheet I made.
Updated with
On the whole I think this looks good. I have some minor questions but see no major issues here.
elif xp.__name__ == f"{array_api_compat_prefix}.jax":
    f = getattr(xp.scipy.special, f_name, None)
I just want to confirm I'm on the right page regarding jax support. It seems at the moment jax is not treated as array API compatible. Currently array_namespace considers jax arrays to be numpy arrays and they end up getting dispatched to the regular scipy.special version of a supported function, which converts them to a numpy array and returns numpy array output.
Am I right that the code path above is for future-proofing, so that once jax is added to array_api_compat, things should just work? I guess when jax is added to array_api_compat, there will be an is_jax helper function that can be used here, but until then it is good to have the correct way of dispatching to jax already in the code.
Yes - perhaps I wrote it originally not knowing that JAX wasn't supported. It was decided to leave it in for future-proofing (#19023 (comment)).
Modulo the question of the name support_cupy_torch_jax, I think this looks good.
Renamed.
Reference issue
#18668 (review)
What does this implement/fix?
Special functions are not currently covered by the Array API. Until they are, we could use a way to dispatch to the appropriate library for our special function calls. This is an idea of how we might do that. Thoughts on implementation are appreciated.
Additional information
Tests would be needed, of course. I'm just submitting this for feedback right now to see if it's anywhere close to the "right" approach. Done.
Minor modifications would be needed to work correctly with CuPy (cupyx.scipy.special, not cupy.special) and JAX (jax.scipy.special, not jax.special). Done.
We'd also need to raise an error or use the SciPy implementation if the library doesn't have the special function we're looking for. Done.