As someone who has just started learning machine learning, I can say that, in my experience, PyTorch is way more beginner-friendly than TF. I had a hard time setting up TF to work with my GPU. PyTorch, on the other hand, was a breeze.
JAX feels like the new kid on the block. Every time I see it mentioned, it is because it has made a project a lot faster (e.g. WhisperJAX or Stable Diffusion in JAX).
Can anyone explain to me why it is so fast? Is it because of parallelism?