> As much as people like to use NumPyro and sometimes even PyMC to generate JAX code, I think it may be easier in the end to just write JAX directly. That way, nothing gets between you and JAX and you don’t have to figure out how to filter JAX through middleware. When you do that, the models can be organized very much like in Stan.
^ Much truth in this. Nascent libraries like distreqx make it much easier to work at a lower level while keeping some of the log-density affordances that PPLs provide.
https://github.com/lockwo/distreqx
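For illustration, here is a minimal sketch of what "writing JAX directly, organized like Stan" might look like. The model, priors, data, and the `log_density` name are all hypothetical. To avoid guessing at distreqx's exact API, it uses only `jax.scipy.stats.norm`. The sampler sees an unconstrained parameter (`log_sigma`), and the log-Jacobian of the `exp` transform is added by hand, much as Stan does for a `<lower=0>` parameter.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical data block: y_i ~ normal(mu, sigma)
y = jnp.array([1.2, 0.8, 1.9, 1.4, 0.6])

def log_density(params, y):
    """Unnormalized log posterior over unconstrained parameters (hypothetical model)."""
    # Parameters / transformed parameters: sigma > 0 via exp
    mu = params["mu"]
    log_sigma = params["log_sigma"]
    sigma = jnp.exp(log_sigma)

    # Model block: priors (constants such as log 2 for the half-normal are dropped)
    lp = norm.logpdf(mu, 0.0, 10.0)        # mu ~ normal(0, 10)
    lp += norm.logpdf(sigma, 0.0, 1.0)     # sigma ~ half-normal(0, 1)
    lp += log_sigma                        # log |d sigma / d log_sigma|, the Jacobian

    # Model block: likelihood
    lp += jnp.sum(norm.logpdf(y, mu, sigma))
    return lp

# Compiled log density plus gradient, which is what HMC/NUTS needs
logp_and_grad = jax.jit(jax.value_and_grad(log_density))

params = {"mu": jnp.array(0.0), "log_sigma": jnp.array(0.0)}
lp, grads = logp_and_grad(params, y)
```

Because `log_density` is an ordinary JAX function over a pytree of parameters, it can be handed to any JAX-based sampler or optimizer without a PPL in between. A library like distreqx would replace the `norm.logpdf` calls with distribution objects.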