Hi,
I am interested in using Dask to parallelize an existing CPU/GPU code base written in JAX. I am fairly new to Dask, but I see there is documentation on using it with GPUs, and a blog post showing that the backend for array creation can be CuPy, for example. Is something similar planned or available for JAX?
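For reference, the CuPy mechanism mentioned above appears to be Dask's `array.backend` configuration option. Below is a minimal sketch, assuming Dask 2023.2 or newer. It uses the default `"numpy"` value so it runs without a GPU. Setting `"cupy"` instead should require CuPy and a CUDA device. Whether a `"jax"` value exists is the open question here, so it is not shown.

```python
import dask
import dask.array as da
import numpy as np

# Assumption: Dask >= 2023.2 reads "array.backend" when creating arrays with
# functions such as da.ones. "numpy" is the default; "cupy" would select
# CuPy-backed chunks (requires CuPy and a GPU).
with dask.config.set({"array.backend": "numpy"}):
    x = da.ones((1000, 1000), chunks=(250, 250))

# Each chunk materializes as an array of the selected backend's type.
first_chunk = x.blocks[0, 0].compute()
print(type(first_chunk))

# Reductions run chunk by chunk and are then combined.
total = x.sum().compute()
print(total)
```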
Thanks!