[WIP] Adding minimal infrastructure for transforming jax function to unit-aware by dfm · Pull Request #1 · dfm/jpu · GitHub



Open: dfm wants to merge 3 commits into main


dfm (Owner) commented Jun 1, 2022

This builds on conversations on Twitter to sketch an interface for transforming a raw JAX function (using the jax.numpy interface directly) into one that supports units. For example:

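```python
import jax.numpy as jnp
from jpu import UnitRegistry
# `units` is the decorator proposed in this PR; its import path isn't given.

@units
def func(x, y):
    return jnp.exp(x / (0.5 * y) + 2.3)

u = UnitRegistry()
func(jnp.ones(3) * u.m, jnp.array([5.6]) * u.km)
```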

or

```python
from functools import partial
import jax.numpy as jnp

@partial(units, input_units=["m", "km"])
def func(x, y):
    return jnp.exp(x / (0.5 * y) + 2.3)

func(jnp.ones(3), jnp.array([5.6]))
```

should both work.

This is so far (very!) incomplete. Some things to do / think about:

  • What to do about literals with units? I think the best bet would be to add an add_units or make_quantity primitive that tags the literal with units, which we can then use when transforming the jaxpr. This might also be useful for implementing correct derivative rules without overloading grad. Other ideas?
  • Implement some more primitives. What's a good list to start with?
  • etc.
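The make_quantity idea from the first bullet can be pictured with the toy sketch below. It uses a list of equations rather than a real jaxpr. A bare numeric literal is treated as dimensionless. A hypothetical make_quantity equation, which is numerically the identity, attaches a unit to a literal through its params. `Eqn`, `propagate` and the string unit labels are illustrative assumptions, not the PR's code.

```python
from typing import Any, Dict, List, NamedTuple, Tuple


class Eqn(NamedTuple):
    """Toy stand-in for a jaxpr equation (not JAX's real class)."""
    primitive: str
    inputs: Tuple[Any, ...]  # variable names (str) or raw numeric literals
    output: str
    params: Dict[str, Any]


def unit_of(x: Any, env: Dict[str, str]) -> str:
    # Raw numeric literals carry no units, so treat them as dimensionless ("").
    return env[x] if isinstance(x, str) else ""


def propagate(eqns: List[Eqn], env: Dict[str, str]) -> Dict[str, str]:
    """Propagate unit labels through eqns, raising on mismatched adds."""
    env = dict(env)
    for eqn in eqns:
        if eqn.primitive == "make_quantity":
            # Hypothetical primitive: identity on the value, but it tags the
            # literal with the unit stored in its params.
            env[eqn.output] = eqn.params["unit"]
        elif eqn.primitive == "add":
            a, b = (unit_of(x, env) for x in eqn.inputs)
            if a != b:
                raise TypeError(f"add: incompatible units {a!r} and {b!r}")
            env[eqn.output] = a
        else:
            raise NotImplementedError(eqn.primitive)
    return env


# x + 2.3 with x in metres: the bare literal is dimensionless, so this fails.
untagged = [Eqn("add", ("x", 2.3), "y", {})]
# x + make_quantity(2.3, "m"): the tagged literal now matches x.
tagged = [
    Eqn("make_quantity", (2.3,), "c", {"unit": "m"}),
    Eqn("add", ("x", "c"), "y", {}),
]
```

Here `propagate(tagged, {"x": "m"})["y"]` is `"m"`, while `propagate(untagged, {"x": "m"})` raises `TypeError`. In a real implementation the primitive would presumably show up as its own equation in the jaxpr, with the unit in its params.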

/cc @shoyer @mattjj @sschoenholz @patrick-kidger

patrick-kidger commented

I'm a bit skeptical about the implementation.

For one thing, I don't think it really needs a registry of operation-to-operation mappings. At the moment it's essentially a layer of checking for "is this jaxpr correct?". This can be done without operation substitution, simply by analysing the jaxpr and raising an error if, e.g., two disparate units are added together.
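Here is a minimal sketch of what such a checking-only analysis could look like, assuming a toy representation:

- A unit is a dict from base dimension to integer exponent.
- The "jaxpr" is a list of `(primitive, inputs, output)` tuples that reuse the names of JAX's `add`, `sub`, `mul`, `div` and `exp` primitives.
- Literals are passed in as dimensionless inputs.

A real pass would iterate over `jax.make_jaxpr(f)(*args).jaxpr.eqns` instead. Because it only compares dimensions, metres and kilometres both count as plain length here.

```python
from collections import Counter
from typing import Dict, List, Tuple

Unit = Dict[str, int]  # base dimension -> exponent, e.g. {"length": 1}
Eqn = Tuple[str, List[str], str]  # (primitive name, input vars, output var)


def combine(a: Unit, b: Unit, sign: int = 1) -> Unit:
    """Multiply (sign=1) or divide (sign=-1) two units."""
    out = Counter(a)
    for dim, power in b.items():
        out[dim] += sign * power
    return {dim: p for dim, p in out.items() if p != 0}


def check_units(eqns: List[Eqn], env: Dict[str, Unit]) -> Dict[str, Unit]:
    """Propagate units through eqns without rewriting them; raise on errors."""
    env = dict(env)
    for prim, ins, out in eqns:
        args = [env[v] for v in ins]
        if prim in ("add", "sub"):
            if args[0] != args[1]:
                raise TypeError(f"{prim}: {args[0]} vs {args[1]}")
            env[out] = args[0]
        elif prim == "mul":
            env[out] = combine(args[0], args[1])
        elif prim == "div":
            env[out] = combine(args[0], args[1], sign=-1)
        elif prim == "exp":
            if args[0]:
                raise TypeError(f"exp of non-dimensionless {args[0]}")
            env[out] = {}
        else:
            raise NotImplementedError(prim)
    return env


# exp(x / (0.5 * y) + 2.3) with x and y both lengths; literals are inputs.
eqns = [
    ("mul", ["half", "y"], "a"),
    ("div", ["x", "a"], "b"),
    ("add", ["b", "c"], "d"),
    ("exp", ["d"], "out"),
]
length = {"length": 1}
```

With `{"x": length, "y": length, "half": {}, "c": {}}` the output is dimensionless (`{}`). Giving `y` the unit `{"time": 1}` instead makes the `add` raise `TypeError`.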

For another, I don't think this will scale well: if I call @units on a function that itself calls @units, then units_jaxpr will be evaluated on the inner one multiple times. That isn't necessary, and it will probably explode in a heavily @units-decorated codebase. Nested @units calls should probably join up seamlessly into a single level of checking (somehow). Each time they encounter each other, they should make sure that the units propagated by the outer level agree with the units annotated on the inner level.
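One possible shape for the "join up into a single level" idea, sketched in plain Python, uses a context variable to record whether an outer unit-aware transformation is already running:

- A nested decorated call skips its own transformation, since the outer trace already covers its body.
- It still checks the units arriving from the outer level against its own annotations.

The decorator name `unit_aware`, the `Q` quantity type and the `transformed` list are all hypothetical. In a real JAX version, the top-level branch is presumably where the function would be traced and units_jaxpr run.

```python
import contextvars
import functools
from typing import Any, List, NamedTuple, Optional, Sequence


class Q(NamedTuple):
    """Hypothetical quantity: a magnitude plus a unit label."""
    magnitude: Any
    unit: str


# True while an outer unit-aware transformation is running.
_INSIDE = contextvars.ContextVar("inside_unit_aware", default=False)
transformed: List[str] = []  # records which functions ran the full transform


def unit_aware(input_units: Optional[Sequence[str]] = None):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args):
            # At every level, check incoming units against the annotations.
            if input_units is not None:
                for arg, expected in zip(args, input_units):
                    if isinstance(arg, Q) and arg.unit != expected:
                        raise TypeError(
                            f"{fn.__name__}: got {arg.unit!r}, expected {expected!r}"
                        )
            if _INSIDE.get():
                # Nested call: the outer transformation already sees fn's
                # operations, so don't transform it a second time.
                return fn(*args)
            token = _INSIDE.set(True)
            try:
                transformed.append(fn.__name__)  # stand-in for tracing + units_jaxpr
                return fn(*args)
            finally:
                _INSIDE.reset(token)
        return wrapper
    return decorator


@unit_aware(input_units=["m"])
def inner(x):
    return Q(2 * x.magnitude, x.unit)


@unit_aware()
def outer(x):
    return inner(x)
```

Calling `outer(Q(1.0, "m"))` returns `Q(2.0, "m")` and records only `"outer"` in `transformed`: `inner` is checked but not transformed again. `outer(Q(1.0, "s"))` raises `TypeError` from `inner`'s annotation check.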

(+various nits: **kwargs isn't passed to units_jaxpr; what if an argument has a magnitude attribute?; for an input_units API I'd suggest copying the API Equinox uses for e.g. filter_vmap, which uses inspect.signature and signature.bind_partial etc.)
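The inspect.signature / bind_partial suggestion could be sketched as follows:

- At decoration time, `bind_partial` maps the unit specs onto parameter names. They can be given positionally, like the PR's `input_units=["m", "km"]`, or by keyword. It raises `TypeError` if they don't fit the function's signature.
- At call time, `sig.bind` is used so that keyword arguments are forwarded too.

The decorator name `with_input_units` and the `Tagged` type are hypothetical, and this is not Equinox's actual code. It also assumes the function has no `*args` or `**kwargs` parameters.

```python
import functools
import inspect
from typing import Any, NamedTuple


class Tagged(NamedTuple):
    """Hypothetical stand-in for a quantity: magnitude plus unit string."""
    magnitude: Any
    unit: str


def with_input_units(*units: str, **named_units: str):
    def decorator(fn):
        sig = inspect.signature(fn)
        # Map the unit specs onto parameter names; raises TypeError now,
        # at decoration time, if they don't match fn's signature.
        unit_for = dict(sig.bind_partial(*units, **named_units).arguments)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)  # handles positional and keyword calls
            for name, unit in unit_for.items():
                if name in bound.arguments:
                    bound.arguments[name] = Tagged(bound.arguments[name], unit)
            return fn(*bound.args, **bound.kwargs)

        return wrapper

    return decorator


@with_input_units("m", "km")
def func(x, y):
    return x, y
```

Both `func(1.0, 5.6)` and `func(y=5.6, x=1.0)` return `(Tagged(1.0, 'm'), Tagged(5.6, 'km'))`. A spec that names a missing parameter, such as `with_input_units(z="m")`, fails as soon as it is applied.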

dfm (Owner, Author) commented Jun 1, 2022

> For one thing I don't think it really needs a registry of operation-to-operation mappings. At the moment it's essentially a layer of checking for "is this jaxpr correct". This can be done without operation substitution, by simply analysing the jaxpr and raising an error if e.g. two disparate units are added together.

Not really! It's actually also doing unit conversions in cases where that's necessary. It would be possible to provide a more minimal API that just checks, but that's not what I need!
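Checking and converting diverge as soon as two operands' units differ by a scale factor. A pure check accepts adding metres to kilometres, since both are lengths, but adding the raw magnitudes would give the wrong number. Below is a minimal sketch of a converting rule for add, built on a toy unit table (an assumption, not jpu's unit handling). It rescales the second operand into the first operand's units before adding.

```python
from typing import Dict, Tuple

# Toy unit table (assumption): unit name -> (dimension, factor to base unit).
UNITS: Dict[str, Tuple[str, float]] = {
    "m": ("length", 1.0),
    "km": ("length", 1000.0),
    "s": ("time", 1.0),
}


def conversion_factor(from_unit: str, to_unit: str) -> float:
    """Factor that turns a magnitude in from_unit into one in to_unit."""
    dim_from, scale_from = UNITS[from_unit]
    dim_to, scale_to = UNITS[to_unit]
    if dim_from != dim_to:
        raise TypeError(f"cannot convert {from_unit} to {to_unit}")
    return scale_from / scale_to


def add_rule(x, x_unit: str, y, y_unit: str):
    """Unit-aware add: rescale y into x's units, then add the magnitudes."""
    return x + y * conversion_factor(y_unit, x_unit), x_unit
```

`add_rule(1.0, "m", 2.0, "km")` returns `(2001.0, "m")`, whereas a check-only pass would leave the magnitudes to be added as 1.0 + 2.0. In a jaxpr-rewriting implementation, one way to express this is to insert a `mul` by the conversion factor before the `add`.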

> For another I don't think this will scale well: if I call @units on a function that itself calls @units then the inner one will have units_jaxpr evaluated on it multiple times, which isn't necessary and will probably explode if one has a heavily @units-decorated codebase. Probably nested @units calls should seamlessly join up into a single level of checking (somehow), and each time they encounter each other they should make sure that the units being propagated by the outer level agree with the units annotated on the inner level.

Totally agreed - I'd love to hear suggestions for approaches to this!

2 participants