
Are you able to share a bit more, enough to help others doing similar work judge whether this "Jax > numpy" aspect applies to their work (and thus whether they'd be well-off learning enough Jax to make use of it themselves)?


This should be a good starting point:

https://docs.jax.dev/en/latest/jax.numpy.html

A lot of this really is a drop-in replacement for numpy that runs insanely fast on the GPU.
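As a rough illustration of the drop-in claim, here is a minimal sketch where the same function body runs against either module. The `normalize` helper and its `xp` parameter are made up for this example, and whether the JAX version actually runs on a GPU depends on how JAX is installed.

```python
import numpy as np
import jax
import jax.numpy as jnp

def normalize(xp, x):
    # Hypothetical helper: xp is either the numpy or the jax.numpy module.
    # Standardise each column to zero mean and (roughly) unit variance.
    return (x - xp.mean(x, axis=0)) / (xp.std(x, axis=0) + 1e-8)

data = np.arange(12, dtype=np.float32).reshape(4, 3)

out_np = normalize(np, data)                           # plain numpy on the CPU
out_jax = jax.jit(lambda x: normalize(jnp, x))(data)   # compiled by XLA for the default device
```

The two results should agree up to floating point tolerance; only the module passed in changes.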

That said, you do need to adapt to its constraints somewhat. Some things you can't do inside jitted functions, and some things need to be done differently.
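Two constraints that commonly come up, sketched with made-up function names: inside a jitted function you cannot branch in Python on a traced value, and JAX arrays are immutable, so in-place assignment is replaced by the `.at[...]` update syntax.

```python
import jax
import jax.numpy as jnp

@jax.jit
def clip_negative(x):
    # A Python `if` on a value derived from x would raise an error here,
    # because x is an abstract tracer while the function is being compiled.
    # jnp.where expresses the same choice element by element.
    return jnp.where(x < 0, 0.0, x)

@jax.jit
def set_first(x, value):
    # `x[0] = value` is not allowed on JAX arrays; .at[].set returns
    # an updated copy instead of mutating x.
    return x.at[0].set(value)

clipped = clip_negative(jnp.array([-1.0, 2.0, -3.0]))
updated = set_first(jnp.zeros(3), 5.0)
```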

For example, finding the most common value along some dimension in a matrix on the GPU is often best done by sorting along that dimension and taking a cumulative sum, which sort of blew my mind when I first learnt it.
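One way the sort-plus-cumulative-sum idea could look in practice is sketched below. This is my own guess at the approach rather than the commenter's actual code, and the names `mode_1d` and `mode_rows` are invented for the example. Sorting groups equal values into runs, a cumulative sum over the run starts gives every element a run index, and a scatter-add counts the size of each run.

```python
import jax
import jax.numpy as jnp

def mode_1d(x):
    """Most common value in a 1-D array (the smallest value wins ties)."""
    n = x.shape[0]
    s = jnp.sort(x)
    # True wherever a new run of equal values begins in the sorted array.
    is_start = jnp.concatenate([jnp.array([True]), s[1:] != s[:-1]])
    # The cumulative sum of run starts labels each element with its run index.
    run_id = jnp.cumsum(is_start) - 1
    # Count the elements in each run. There are at most n runs,
    # so the output shape is static, which jit requires.
    counts = jnp.zeros(n, dtype=jnp.int32).at[run_id].add(1)
    # Store the value of each run at its run index.
    run_values = jnp.zeros(n, dtype=s.dtype).at[run_id].set(s)
    return run_values[jnp.argmax(counts)]

# vmap applies the 1-D function to every row; jit compiles the whole thing.
mode_rows = jax.jit(jax.vmap(mode_1d))

m = jnp.array([[1, 2, 2, 3],
               [4, 4, 5, 5],
               [7, 6, 6, 6]])
row_modes = mode_rows(m)
```

Everything here has a fixed shape and no data-dependent Python control flow, which is what makes it friendly to `jit` and to the GPU, unlike a dictionary-based counter.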



