Tags: genjax-community/genjax
Add support for `sample_shape` to primitive distributions (#1576) This PR passes any `sample_shape` argument provided to a TFP distribution through to its `sample` method, instead of passing it to the constructor. @femtomc one issue I hit here: to use this, I have to wrap the argument in `Const`, because it seems like our code tries to trace non-jit-compiled functions, or something like that. I'll add the error I saw as a reply. Co-authored-by: Mathieu Huot <MathieuHuot@users.noreply.github.com>
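The split between constructor arguments and `sample_shape` can be sketched in plain JAX. This is a minimal sketch, not genjax's or TFP's code: `ToyNormal` and `draw` are hypothetical names, and the shape rule (result shape = `sample_shape` + batch shape) is assumed to mirror TFP's convention. It also shows one plausible reason a wrapper like `Const` is needed: `sample_shape` fixes the shape of the output, so under `jax.jit` it has to be static (here via `static_argnames`). Whether `Const` plays exactly this role inside genjax is an assumption.

```python
import jax
import jax.numpy as jnp
from functools import partial


# Hypothetical toy distribution standing in for a TFP distribution:
# parameters go to the constructor, sample_shape goes to sample().
class ToyNormal:
    def __init__(self, loc, scale):
        self.loc = jnp.asarray(loc)
        self.scale = jnp.asarray(scale)

    def sample(self, key, sample_shape=()):
        # Assumed TFP-style rule: output shape = sample_shape + batch_shape.
        batch_shape = jnp.broadcast_shapes(self.loc.shape, self.scale.shape)
        eps = jax.random.normal(key, tuple(sample_shape) + batch_shape)
        return self.loc + self.scale * eps


# sample_shape determines the output array's shape, so under jit it must be
# a static (hashable, non-traced) argument rather than a traced one.
@partial(jax.jit, static_argnames="sample_shape")
def draw(key, loc, scale, sample_shape):
    return ToyNormal(loc, scale).sample(key, sample_shape)


key = jax.random.PRNGKey(0)
x = draw(key, jnp.zeros(2), 1.0, sample_shape=(3,))
print(x.shape)  # (3, 2)
```

If `sample_shape` were instead passed as an ordinary jitted argument, JAX would turn the tuple's integers into tracers, and `jax.random.normal` would reject the resulting non-concrete shape.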
Update usage of JAX `partial_eval` / `wrap_init` to accept debug info (#1563) JAX claims this will improve the error messages coming from inside interpreters / metaprogramming. Excited to see it! Edit: I had to bump the JAX version and update the lockfile, hence the large LoC change. Edit2: I also went through and fixed a number of deprecation warnings from JAX (and from our own `MixtureCombinator`).
Fix `assess` in vmap (GEN-903) (#1464) This PR modifies `assess` in vmap to vmap over indices and query the choicemap at each index, rather than vmapping over the `ChoiceMap` itself. The error @georgematheos was hitting came from a choicemap containing masks in which a vmapped value was masked by a scalar bool. Vmapping over such a choicemap is impossible because the scalar bool has no leading dimension to map over.
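The failure mode and the index-based workaround can be reproduced with a plain dict standing in for the choicemap. This is a hedged sketch, not the genjax implementation: `batched`, `score_one`, and `score_at` are hypothetical names, and the "score" is a toy function. It shows that `jax.vmap` over the whole structure fails when one leaf is a scalar bool, while vmapping over an index array and querying the structure at each index works.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a choicemap: the value is batched along axis 0,
# but the mask is a single scalar bool shared by every batch element.
batched = {
    "value": jnp.array([0.5, -1.0, 2.0]),  # leading axis of size 3
    "mask": jnp.array(True),               # scalar: no leading axis
}


def score_one(entry):
    # Toy per-element "score": zero out the contribution when masked off.
    return jnp.where(entry["mask"], entry["value"] ** 2, 0.0)


# Vmapping over the structure itself fails: in_axes=0 requires every leaf
# to have a leading axis, and the scalar mask does not.
try:
    jax.vmap(score_one)(batched)
except ValueError as err:
    print("vmap over the structure failed:", type(err).__name__)


# Workaround: vmap over indices and query the structure at each index.
def score_at(i):
    entry = {"value": batched["value"][i], "mask": batched["mask"]}
    return score_one(entry)


n = batched["value"].shape[0]
scores = jax.vmap(score_at)(jnp.arange(n))
print(scores)
```

Indexing sidesteps the problem because only the index array is mapped; each leaf is then read at whatever rank it actually has, so the scalar mask never needs a leading dimension.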