ENH: special: dispatch to array library by mdhaber · Pull Request #19023 · scipy/scipy · GitHub

ENH: special: dispatch to array library #19023


Merged (26 commits) on Sep 17, 2023

Conversation

@mdhaber (Contributor) commented Aug 7, 2023

Reference issue

#18668 (review)

What does this implement/fix?

Special functions are not currently covered by the Array API. Until they are, we could use a way to dispatch to the appropriate library for our special function calls. This is an idea of how we might do that. Thoughts on implementation are appreciated.

from scipy import special
from scipy._lib.array_api_compat.array_api_compat import numpy
import torch

x = torch.Tensor([1.])
print(special.ndtr(x))  # tensor([0.8413])

x = numpy.asarray([1.])
print(special.ndtr(x))  # [0.84134475]

Additional information

  1. Tests would be needed, of course. I'm just submitting this for feedback right now to see if it's anywhere close to the "right" approach. Done.
  2. Some (all?) special ufuncs already seem to support CPU PyTorch tensors. The example above produces the same output with or without this PR, but I've checked that this PR actually is doing the expected dispatching in the example above.
  3. Minor modifications would be needed to work correctly with CuPy (cupyx.scipy.special, not cupy.special) and JAX (jax.scipy.special, not jax.special). Done.
  4. We'd also need to raise an error or use the SciPy implementation if the library doesn't have the special function we're looking for. Done.
  5. We could define a mapping when there is a difference in signatures, but I think the other libraries were pretty careful to keep the signatures consistent.
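The dispatch idea sketched in this PR description could look roughly like the following (a minimal illustration only, not the PR's actual code; the function name `get_array_special_func` and the branching logic are assumptions based on the discussion below):

```python
import numpy as np
import scipy.special


def get_array_special_func(f_name, xp):
    """Return the implementation of a special function for namespace xp.

    Sketch: the real code would also handle JAX (jax.scipy.special)
    and fall back to SciPy when the library lacks the function.
    """
    if "torch" in xp.__name__:
        # PyTorch ships its own implementations under torch.special
        return getattr(xp.special, f_name, None)
    if "cupy" in xp.__name__:
        # CuPy's versions live in cupyx.scipy.special, not cupy.special
        import cupyx.scipy.special
        return getattr(cupyx.scipy.special, f_name, None)
    # Default: SciPy's own implementation, for NumPy arrays
    return getattr(scipy.special, f_name)


# With NumPy input, dispatch resolves to scipy.special.ndtr
f = get_array_special_func("ndtr", np)
print(f(np.asarray([1.0])))  # ndtr(1) ≈ 0.8413
```

With a PyTorch tensor, `xp` would be the torch namespace and the same lookup would return `torch.special.ndtr` instead.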

@mdhaber added the enhancement (A new feature or improvement) and scipy.special labels on Aug 7, 2023
@mdhaber mdhaber requested a review from rgommers August 7, 2023 16:09
@rgommers (Member) commented Aug 8, 2023

This is great, thanks for getting the ball rolling here Matt!

  1. Some (all?) special ufuncs already seem to support CPU PyTorch tensors. The example above produces the same output with or without this PR, but I've checked that this PR actually is doing the expected dispatching in the example above.

What's happening here is that currently the ufuncs will convert the PyTorch tensor to a NumPy array, call the scipy.special implementation, and then call __array_wrap__ to convert it back to a PyTorch tensor. With your changes, we call the corresponding torch.special functions instead. That should be better (assuming the signatures match) - for example, it will work for tensors that are part of an autograd graph. Also performance should be better I'd expect.

  1. We could define a mapping when there is a difference in signatures, but I think the other libraries were pretty careful to keep the signatures consistent.

I think we're going to need that mapping, since I don't think any other library implements all of the ~380 functions in scipy.special nor the ~280 in scipy.special._ufuncs. There's also a lot of cruft in naming. What I think we should do is define a much smaller set of special functions that we'd expect other libraries to have, with matching signatures and behavior. And then we forward those calls.

Special functions would in principle be a useful array API standard extension, a la fft and linalg. But I'd expect that to be a namespace of a reasonable size (30 functions maybe, or 50?) with the most commonly used functionality.
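A name mapping like the one suggested above could start as simply as a per-library dictionary keyed by SciPy name (purely illustrative; the structure and names here are assumptions, not the eventual design):

```python
# Illustrative only: identity is the common case, and a small table
# would handle libraries whose names differ from SciPy's.
NAME_MAP = {
    "torch": {},  # torch.special mostly matches scipy.special names
    "cupy": {},   # same names, but under cupyx.scipy.special
}


def translate(f_name, library):
    """Return the name to look up in `library`; default to the SciPy name."""
    return NAME_MAP.get(library, {}).get(f_name, f_name)


print(translate("ndtr", "torch"))   # ndtr
print(translate("gammaln", "jax"))  # gammaln (identity fallback)
```

Curating the smaller forwarded set would then amount to deciding which keys are allowed at all, rather than mapping each of the ~380 functions.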

@rgommers (Member) commented Aug 8, 2023

@mdhaber would you be interested in hashing out a proposal for a special extension module in the array API standard together? Most functions are element-wise with no keyword arguments, so hopefully it's not too difficult - more just scoping what to include based on usefulness and presence in CuPy/PyTorch/JAX.

@mdhaber (Contributor, author) commented Aug 8, 2023

Yes, I considered doing that. I looked into GitHub code search to collect usage data for the functions the libraries have in common and to determine what other functions should be prioritized. Are there other tools you'd recommend?

Also, IIRC the @array_api_compatible decorator currently tests only a few backends. Presumably we don't need to run tests of all these functions on all backends on every CI run. But how should they be tested? I can't easily do them all locally.

As for what the test would look like, I was thinking of something like https://github.com/mdhaber/mparray/blob/70b134492a0d6fc08dbf76093ca61368f681000d/mparray/special/tests/test_special.py#L41 but probably with Hypothesis, and certainly without fixed magic numbers. (Those were a shortcut to getting something working.) The idea is just to make sure the functions are supposed to be the same and that arguments are in the right order, not to check accuracy.

@lucascolley (Member) commented Aug 8, 2023

I looked into GitHub code search to collect usage data for the functions they have in common and to determine what other functions should be prioritized. Are there other tools you'd recommend?

May be worth giving https://github.com/data-apis/python-record-api a look.

An excerpt from Aaron's slides from SciPy 2023 (image omitted).

And for finding out which functions the libraries have in common: https://github.com/data-apis/array-api-comparison

@rgommers (Member) commented Aug 8, 2023

I looked into GitHub code search to collect usage data for the functions they have in common and to determine what other functions should be prioritized. Are there other tools you'd recommend?

May be worth giving https://github.com/data-apis/python-record-api a look.

That python-record-api tool is quite powerful, but not easy to use (and very slow to run). So in this case, I think manual curation first and then for the fairly small set of functions where usage data could be the deciding factor, grepping the code bases of scipy, scikit-learn, statsmodels, networkx & co by hand will likely be the more painless way of going about this.

@rgommers (Member) commented Aug 8, 2023

The idea is just to make sure the functions are supposed to be the same and that arguments are in the right order, not to check accuracy.

That sounds useful indeed. Using assert_allclose with a very loose tolerance would probably also be fine. To confirm that functions actually behave the same, something like rtol=1e-4 is probably fine? We don't need 1e-10 or better like in our own tests.
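Such a loose-tolerance agreement test might be shaped like this (a sketch only; the helper name is made up, and NumPy stands in for the alternative backend `xp` so the example is self-contained):

```python
import numpy as np
from numpy.testing import assert_allclose
from scipy import special


def check_special_agreement(f_name, args, xp, xp_special):
    """Compare a backend's f_name against scipy.special with loose rtol.

    Verifies only that the functions correspond and that arguments are
    in the right order -- not accuracy -- hence rtol=1e-4.
    """
    ref = getattr(special, f_name)(*args)
    res = getattr(xp_special, f_name)(*(xp.asarray(a) for a in args))
    assert_allclose(np.asarray(res), ref, rtol=1e-4)


# With NumPy itself as the "backend", scipy.special plays both roles:
check_special_agreement("ndtr", [np.asarray([0.0, 1.0, 2.0])], np, special)
```

For a real backend, `xp` and `xp_special` would be, e.g., `torch` and `torch.special`, and the inputs would be generated (perhaps by Hypothesis) rather than fixed.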

@mdhaber (Contributor, author) commented Aug 8, 2023

Thanks @rgommers.
There's just one part I wanted to check about - IIUC the only alternative backend available on CI right now is PyTorch CPU. I don't think we need all CI runs to test all backends, but what should the test plan be? I can test some locally, but maybe not all (e.g. JAX on Windows is tricky, and I don't think my Mac has a compatible GPU).

@rgommers (Member) commented Aug 8, 2023

I don't think we need all CI runs to test all backends, but what should the test plan be? I can test some locally, but maybe not all (e.g. JAX on Windows is tricky, and I don't think my Mac has a compatible GPU).

I think CI also tests numpy.array_api. We can add a JAX run once it's added to array-api-compat. CuPy and PyTorch GPU aren't testable in CI (with reasonable effort at least), because there are no CI services which offer free GPUs. I can test those locally, as can @lucascolley and probably several other folks. So I think until we get much further along, we should only test on GPU locally.

elif xp.__name__ == f"{array_api_compat_prefix}.jax":
f = getattr(xp.scipy.special, f_name, None)
else:
f = None
@mdhaber (Contributor, author):
I was going to return a wrapper function that converts to a NumPy array, evaluates the function, and converts back, but that would never happen: if the array is not recognized, array_namespace raises an error.

Reviewer (Member):

I don't think that's correct? There is:

        elif hasattr(x, '__array_namespace__'):
            namespaces.add(x.__array_namespace__(api_version=api_version))

so any object with a __array_namespace__ will be accepted by design.

Reviewer (Member):

Agreed that an error is not raised. That code-snippet is the case which handles numpy.array_api currently. Converting to NumPy and back is what the cluster and fft PRs do.

@mdhaber commented Aug 16, 2023:

I think I didn't try a real array type; I got the error when providing a list as the argument like @rgommers reported in #19023 (review). I'll try returning the wrapper.

@mdhaber commented Sep 7, 2023:

Now returning a wrapper that does the conversions and evaluates w/ SciPy normally.
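The wrapper described here (convert to NumPy, evaluate with SciPy, convert back) might look roughly like the following (a sketch under that description, not the merged code; the name `_numpy_fallback` is made up):

```python
import numpy as np
from scipy import special


def _numpy_fallback(f_name, xp):
    """Fall back to SciPy for a namespace with no native implementation."""
    scipy_func = getattr(special, f_name)

    def wrapped(*args):
        # Convert inputs to NumPy, evaluate with SciPy, convert back
        np_args = [np.asarray(arg) for arg in args]
        return xp.asarray(scipy_func(*np_args))

    return wrapped


ndtr = _numpy_fallback("ndtr", np)  # NumPy standing in for `xp`
print(ndtr([0.0]))  # [0.5]
```

Note the limitation discussed elsewhere in this thread: `np.asarray` raises for GPU arrays, so this fallback only helps for array types whose data is reachable from NumPy.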

xp = array_namespace(*args[:n_args])
f = get_array_special_func(f_name, xp)
return f(*args, **kwargs)
wrapped.__doc__ = getattr(_ufuncs, f_name).__doc__
@mdhaber commented Aug 9, 2023:

Should probably do something with the docs to mention the CuPy/PyTorch support.
Is it OK if it's something generic advising of experimental support that I could add programmatically with _docscrape.py? Should it be in the notes or in a more prominent Note admonition? Suggestions on what it should say?

Reviewer (Member):

Probably best not to do that in this decorator. Programmatically appending notes is what we tried in various flavors before, and it usually doesn't work well without a ton of effort. For now we can keep it in the module-level docstring for special I'd say.

@mdhaber commented Aug 16, 2023:

So at https://docs.scipy.org/doc/scipy/reference/special.html, put a list of functions that can be dispatched to the array API equivalents? Since I would personally put it into the function documentation (with docscrape), I don't have a vision for what that would look like. Would it be a note at the top? Mention that users can enable support for CuPy and PyTorch with the SCIPY_ARRAY_API environment variable?

@mdhaber commented Sep 7, 2023:

@rgommers let me know how you were thinking this documentation should go and I'll add it here. Or if you're open to seeing how docscrape would do, LMK what you'd like to see in the docs (notes section, or more prominently?) of each wrapped function.

Maybe something like

This function has preliminary support for CuPy, PyTorch, and JAX arrays. When the environment variable SCIPY_ARRAY_API=1 is set, arrays of these types are automatically passed to the appropriate function of the native library for calculation, and an array of the same type is returned. In this case, only the first positional argument is supported. Other arguments will be passed to the underlying function, but the behavior is not tested.

"first positional argument" would be replaced by "first X positional arguments" where appropriate.

Reviewer (Member):

That looks like a nice doc snippet to me. The last two sentences may perhaps be confusing to the average reader. In most cases there aren't more than 1 (or X) arguments, right? So everything is actually tested in those cases, and those sentences may be left out?

@mdhaber commented Sep 7, 2023:

There is usually an out argument, and some functions have a complex argument, but it's not consistent. These are not tested, and they won't work with all libraries, so I think it should be stated. The intent was to write something that is easy to adapt to all functions. If I use docscrape to add this comment to the notes, it's easy to customize the message to say "first positional argument" or "first X positional arguments" because the wrapper already has that information. It's less easy to customize statements about out, complex, or other parameters without manually adding a table with that information - but I could if it's important.

ref = f(*args_np)
res = f(*args_xp)
# having trouble with this locally
assert type(res) == type(xp.asarray([]))
@mdhaber (Contributor, author):

Is there a better way to check that the array type is correct?

Reviewer (Member):

I don't think type(x) is actually formally guaranteed to work, but I think it always will in practice. I'm not sure why this doesn't work for you (it does seem to for me), but maybe this does?

isinstance(x, type(xp.asarray([])))

@lucascolley commented Aug 16, 2023:

Are you just trying to check that an array of the correct namespace is returned here? If so, I am using this assertion written by @tylerjereddy, which I have added to _array_api.py in #19005.

def _assert_matching_namespace(actual, expected):
    expected_space = array_api_compat.array_namespace(expected)
    if isinstance(actual, tuple):
        for arr in actual:
            arr_space = array_api_compat.array_namespace(arr)
            assert arr_space == expected_space
    else:
        actual_space = array_api_compat.array_namespace(actual)
        assert actual_space == expected_space

@mdhaber (Contributor, author):

OK, this is another thing I'll change after gh-19005 merges.

args_xp = [xp.asarray(arg) for arg in args_np]

ref = f(*args_np)
res = f(*args_xp)
@mdhaber (Contributor, author):

Should we test the out argument?

Reviewer (Member):

out is tricky. It's not part of the array API standard. It will work in principle with PyTorch and CuPy, but not with JAX. It's not used a lot, and there are good reasons to avoid it. So I'd suggest to keep out= specific to numpy arrays.

@mdhaber (Contributor, author):

OK, so I won't test it, but hopefully it is covered by tests that were already written for scipy.special. (I've checked that it works, though.)

import numpy as np
from scipy import special
x = np.asarray([1])
y = np.asarray([1.])
special.ndtr(x, y)
print(y)  # [0.84134475]

@rgommers added the array types label (items related to array API support and input array validation; see gh-18286) on Aug 12, 2023
@mdhaber (Contributor, author) commented Aug 14, 2023

@rgommers does this look closer to what you had in mind?

@rgommers (Member) left a review:

This looks close to what we want I think, thanks @mdhaber.

One thing seems wrong with the array type handling:

>>> special.ndtri_exp([-3])
array([-1.64692172])
>>> special.ndtr([-3])
Traceback (most recent call last):
...
TypeError: The input is not a supported array type

For array-like input we should preserve the current behavior.





# having trouble with this locally
assert type(res) == type(xp.asarray([]))
assert res.shape == ref.shape
assert_allclose(np.asarray(res), ref)
@lucascolley commented Aug 16, 2023:

Unless I am missing something, this np.asarray will fail for CuPy and PyTorch GPU. There has been discussion in #19005 of how best to go about testing on GPU. I have added this utility to conftest.py:

def set_assert_allclose(xp=None):
    if xp is None:
        return npt.assert_allclose
    if 'cupy' in xp.__name__:
        return xp.testing.assert_allclose
    elif 'torch' in xp.__name__:
        return xp.testing.assert_close
    return npt.assert_allclose

This utility is called before every use of assert_allclose:

 def test_ifft(self, xp):
     x = xp.asarray(random(30) + 1j*random(30))
     _assert_allclose = set_assert_allclose(xp)
     _assert_allclose(x, fft.ifft(fft.fft(x)))

@mdhaber commented Aug 16, 2023:

OK, I'll also implement this if gh-19005 merges first.

But what will be the canonical way to convert any Array API array to a different Array API array (if not xp.asarray, where xp is the desired type)?

In gh-19005 I see:

    xp = array_namespace(x)
    x = np.asarray(x)
    ...
    return xp.asarray(y)

and this is what was suggested in gh-18668, right?

Reviewer (Member):

As far as I can tell, this shouldn't be needed outside of tests. For tests, the use of set_assert_allclose is a deliberate attempt to avoid this as well.

The conversion to np and back is just so that we can make use of SciPy's existing functions, which only work for NumPy arrays. In general, there shouldn't be a need to convert between different array libraries.

@mdhaber commented Aug 16, 2023:

I think there was confusion here because I was lumping np in with other Array API arrays. I guess that is not correct.

You wrote

Unless I am missing something, this np.asarray will fail for CuPy and PyTorch GPU

In that case, my question was how should we be converting from any xp array to a np array outside of tests? (And until gh-19005 merges, I will do the conversion this way in tests before using np.testing.assert_allclose, too.)

Reviewer (Member):

Ah, sorry for the confusion. We just use np.asarray, with import numpy as np. This is not the array API compatible version, but that doesn't matter since this is our 'last resort'. We expect it to raise exceptions, with GPU arrays for example. This can't be fixed without a general dispatch system, but that is out of scope for the array API work, hence your efforts to call to some particular libraries here.

Reviewer (Member):

Do you have an example of a conversion which you think should be possible? As far as I understand, the array API is supposed to be completely agnostic of the underlying library. Are you thinking of something like a general to_numpy function? This is something unsuitable for the array API, as it should hold up fine, regardless of the status of one particular project. I.e. forcing libraries to interact with other libraries in order to comply is undesirable.

@mdhaber commented Aug 16, 2023:

Do you have an example of a conversion which you think should be possible?

I would expect torch.asarray(np.asarray([1, 2, 3])) to return a PyTorch array, and np.asarray(torch.asarray([1, 2, 3])) to return a NumPy array. Indeed, they do if they are both imported from array_api_compat, so I would expect the same if we were to replace np and torch with any two Array libraries imported from array_api_compat.

Are you thinking of something like a general to_numpy function?

I was expecting that np.asarray, where np is an Array API compatible version of NumPy, would convert any Array API array to a NumPy array.

Note that in the referenced code, np is defined by:

from scipy._lib.array_api_compat.array_api_compat import numpy as np

in case that makes a difference.

@lucascolley commented Aug 16, 2023:

Got you 👍. I am not so sure, but I believe that the blocker for this is device transfers. Imagine two array libraries which use different devices, both of which do not exist yet, but do in this hypothetical future.

How could we standardise the ability to transfer between these devices? It seems to me that you are hoping for the ability to transfer between any devices which could possibly be used by libraries which support the standard. Since the standard does not place any restrictions on the devices a library supports, I don't think that this is possible.

I think that we're talking about the same thing here, but apologies if I've misunderstood.

cc @rgommers for clarification as I may also be misunderstanding the standard.

Reviewer (Member):

I would expect torch.asarray(np.asarray([1, 2, 3])) to return a PyTorch array, and np.asarray(torch.asarray([1, 2, 3])) to return a NumPy array

Yes indeed - when the data is on the same device.

How could we standardise the ability to transfer between these devices?

We can't really. At least, we haven't quite figured out how to best do this for testing purposes specifically. We do know that we never want to do this for actual library code.

Are we on the same page here now? The current PR diff says:

# TODO: use `set_assert_allclose` when gh-19005 merges

which sounds right (and if not, we'll have figured this out on gh-19005).

@mdhaber (Contributor, author):

We do know that we never want to do this for actual library code.

You asked that the else case be to convert to NumPy, perform the computation, and convert back (#19023 (comment)). So even for some Array API compatible arrays, this will just fail, and users will get whatever error np.asarray gives them?

@lucascolley (Member) left a review:

TypeError: The input is not a supported array type

For array-like input we should preserve the current behavior.

This error is coming from array_api_compat: https://github.com/data-apis/array-api-compat/blob/546fa3d65dde1cc20c68b2d22241f7977a8c08fa/array_api_compat/common/_helpers.py#L102-L104

@rgommers I believe that, following the cluster PR, we actually do not support array-likes when SCIPY_ARRAY_API is set. SciPy's array_namespace calls compliance_scipy, which calls is_array_api_obj from array_api_compat, which returns False for array-likes:

def is_array_api_obj(x):
    """
    Check if x is an array API compatible array object.
    """
    return _is_numpy_array(x) \
        or _is_cupy_array(x) \
        or _is_torch_array(x) \
        or hasattr(x, '__array_namespace__')

compliance_scipy raises TypeError: Only support Array API compatible arrays in this situation.

So, if we do want to support array-likes (at least for now), I think a change is needed to compliance_scipy.
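A hedged sketch of what such a change to compliance_scipy might look like, under the assumption that array-likes should simply be coerced to NumPy (this is not the actual SciPy helper, only an illustration of the branch under discussion):

```python
import numpy as np


def compliance_scipy(arrays):
    """Sketch: accept array API objects, coerce array-likes to NumPy.

    The real helper does more (e.g. rejecting masked arrays and object
    dtypes); this only illustrates handling array-like input.
    """
    out = []
    for x in arrays:
        if hasattr(x, "__array_namespace__") or isinstance(x, np.ndarray):
            out.append(x)               # already array API compatible
        else:
            out.append(np.asarray(x))   # list/tuple/scalar -> ndarray
    return out


checked, = compliance_scipy([[1, 2, 3]])
print(type(checked).__name__)  # ndarray
```

The alternative discussed below is to refuse array-likes outright when SCIPY_ARRAY_API is set, in which case no such coercion branch is needed.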


@mdhaber (Contributor, author) commented Aug 16, 2023

For array-like input we should preserve the current behavior.

This is the error you get when you pass a list to array_namespace. Whether that's the desired behavior is debatable (I'd be happy with lists being converted to NumPy arrays), but my understanding is that we are supposed to pass the input into array_namespace per the top post of gh-18668.

Maybe what we need to do is skip all this stuff unless the SCIPY_ARRAY_API environment variable is set?

@lucascolley (Member):

Maybe what we need to do is skip all this stuff unless the SCIPY_ARRAY_API environment variable is set?

This already happens in the array_namespace you can import from _lib._array_api, so importing from the correct location will fix this.

However, my comments above pertain to how to treat array-likes when the environment variable is set. This is what may involve a change to compliance_scipy, although maybe we do want to refuse array-likes when the environment variable is set.

@mdhaber (Contributor, author) commented Sep 8, 2023

Erg, well, as written in a11333c, this also means that the docstring is only modified when SCIPY_ARRAY_API=1. So the note is not added to the rendered documentation in that commit. And AttributeError: attribute '__doc__' of 'numpy.ufunc' objects is not writable, so I can't seem to change the documentation and maintain ufunc-ness. So I'll remove the notes for now and save documentation for a follow-up.

@mdhaber (Contributor, author) commented Sep 9, 2023

OK, I think the suggestions will resolve the lint and MyPy issues. Maybe we can figure out how to do the documentation in a followup.

@rgommers (Member) commented Sep 9, 2023

Erg, well, as written in a11333c, this also means that the docstring is only modified when SCIPY_ARRAY_API=1.

That could perhaps be handled by setting SCIPY_ARRAY_API=1 in the doc build job. But that's not a great solution.

So the note is not added to the rendered documentation in that commit. And AttributeError: attribute '__doc__' of 'numpy.ufunc' objects is not writable, so I can't seem to change the documentation and maintain ufunc-ness.

Amending docstrings in scipy/special/_add_newdocs.py would be the way to do that.

So I'll remove the notes for now and save documentation for a follow-up.

That seems like a good idea indeed.

@mdhaber (Contributor, author) commented Sep 9, 2023

OK. Well, aside from needing #19023 (comment), CI seems to be happy. Then I guess we can just wait for #19186 to get the stricter array checks. After this merges, I can open an issue about what to do about the ufuncness of these functions - whether we are OK with these no longer being ufuncs, or if there is some way to get the best of both worlds.

@izaid (Contributor) commented Sep 9, 2023

@izaid I was in the same situation as you and am trying to make a push to improve things in special. I plan to make a tracking issue and would like to ping you on that to get your opinion on which shortcomings it would be good to address sooner than later.

Thanks @steppi (and @mdhaber)! Just circling back here, I'm still keen to help. Did this tracking issue get created?

Could SciPy just have a lot of header files for special functions that other projects can easily include?

That's an interesting idea. Can you submit an issue requesting this? This isn't really the place to discuss that at the depth I'd like to.

Certainly, will soon create said issue. I agree, it's at least worth a proper discussion. Thanks for taking the time to respond earlier!

@mdhaber (Contributor, author) commented Sep 9, 2023

Did this tracking issue get created?

Not yet, but I was actually going to email you this weekend about a real-time meeting, if you're up for that. I'll also add you to the spreadsheet I made.

@mdhaber (Contributor, author) commented Sep 11, 2023

Updated with xp_assert_close now that gh-19186 has merged.

@steppi (Contributor) left a review:

On the whole I think this looks good. I have some minor questions but see no major issues here.

Comment on lines 26 to 27
elif xp.__name__ == f"{array_api_compat_prefix}.jax":
f = getattr(xp.scipy.special, f_name, None)
@steppi (Contributor):

I just want to confirm I'm on the right page regarding jax support. It seems at the moment jax is not treated as array api compatible. Currently array_namespace considers jax arrays to be numpy arrays and they end up getting dispatched to the regular scipy.special version of a supported function, which converts them to a numpy array and returns numpy array output.

Am I right that the code path above is for future proofing, so that once jax is added to array_api_compat, things should just work? I guess when jax is added to array_api_compat, there will be an is_jax helper function that can be used here, but until then it is good to have the correct way of dispatching to jax already in the code.

@mdhaber (Contributor, author):

Yes - perhaps I wrote it originally not knowing that JAX wasn't supported. It was decided to leave it in for future-proofing (#19023 (comment)).

@steppi (Contributor) left a review:

Modulo the question of the name support_cupy_torch_jax, I think this looks good.

@mdhaber (Contributor, author) commented Sep 16, 2023

Renamed.
