Are you able to share a bit, enough to explain to others doing similar work that this "Jax > numpy" aspect applies to their work (and thus that they'd do well to learn enough Jax to make use of it themselves)?
A lot of it really is a drop-in replacement for numpy that runs insanely fast on the GPU.
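To make the "drop-in" part concrete, here is a minimal sketch, not taken from any real project: the hypothetical helper `normalize_rows` is written once against a module argument `xp` and then run with both `numpy` and `jax.numpy`. It says nothing about speed, which depends on your hardware and array sizes.

```python
import numpy as np
import jax
import jax.numpy as jnp

def normalize_rows(xp, x):
    # Hypothetical helper: the same body works for numpy and jax.numpy,
    # because jax.numpy mirrors the NumPy function names and arguments.
    mean = xp.mean(x, axis=1, keepdims=True)
    std = xp.std(x, axis=1, keepdims=True)
    return (x - mean) / std

x_np = np.arange(12, dtype=np.float32).reshape(3, 4)

# Plain NumPy on the CPU.
out_np = normalize_rows(np, x_np)

# Same code through jax.numpy, compiled with jit; runs on a GPU if one is available.
normalize_rows_jax = jax.jit(lambda x: normalize_rows(jnp, x))
out_jax = normalize_rows_jax(jnp.asarray(x_np))
```

Both calls should agree up to float32 rounding, which is usually enough to port a function and check it against the NumPy version.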
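That said, you do need to adapt to its constraints somewhat. Some things you can't do inside jitted functions (for example, ordinary Python branching on array values, or operations whose output shape depends on the data), and some things need to be done differently.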
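One common instance of such a constraint, as a sketch: boolean-mask indexing produces an array whose size depends on the data, which jit cannot trace, so the usual workaround is a fixed-shape `jnp.where`. The function names here are made up for illustration.

```python
import jax
import jax.numpy as jnp

def positive_sum_eager(x):
    # Boolean-mask indexing gives a result whose length depends on the data.
    # This is fine outside jit, but wrapping it in jax.jit raises an error.
    return x[x > 0].sum()

@jax.jit
def positive_sum_jit(x):
    # Fixed-shape alternative: keep every element and zero out the unwanted ones.
    return jnp.where(x > 0, x, 0.0).sum()

x = jnp.array([-1.0, 2.0, -3.0, 4.0])
eager = positive_sum_eager(x)   # 2.0 + 4.0
jitted = positive_sum_jit(x)    # same value, but jit-compatible
```

The pattern generalizes: keep shapes static and express "select some elements" as masking rather than filtering.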
For example, finding the most common value along some dimension in a matrix on the GPU is often best done by sorting along that dimension and taking a cumulative sum, which sort of blew my mind when I first learnt it.
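Here is one way to read "sort, then cumulative sum" as a sketch. It is an assumption about the approach, not necessarily the exact method meant above. After sorting, equal values sit next to each other, so a cumulative sum over "value changed here" flags gives each run of equal values an id, and the longest run is the mode. The names `mode_1d` and `mode_along_rows` are hypothetical.

```python
import jax
import jax.numpy as jnp

def mode_1d(x):
    # Sort so that equal values become contiguous runs.
    s = jnp.sort(x)
    n = s.shape[0]
    # Flag the first element of each run (the first element always starts one).
    is_new = jnp.concatenate([jnp.ones((1,), dtype=bool), s[1:] != s[:-1]])
    # Cumulative sum of the flags labels each run: 0, 0, 1, 1, 1, 2, ...
    run_id = jnp.cumsum(is_new.astype(jnp.int32)) - 1
    # Count elements per run with a fixed-size scatter-add (at most n runs).
    counts = jnp.zeros(n, dtype=jnp.int32).at[run_id].add(1)
    best_run = jnp.argmax(counts)
    # First sorted position belonging to the longest run holds the mode.
    return s[jnp.argmax(run_id == best_run)]

# Apply per row of a matrix; vmap batches the 1-D function, jit compiles it.
mode_along_rows = jax.jit(jax.vmap(mode_1d))

m = jnp.array([[3, 1, 3, 2],
               [5, 5, 7, 7],
               [0, 9, 9, 9]])
modes = mode_along_rows(m)  # one mode per row; ties go to the smaller value
```

Every intermediate array has a shape known at compile time, which is why this formulation suits jit and the GPU better than a hash-map count would.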