
>MaxText was just very difficult to work with. We felt like we were fighting against it every time we needed to change something because we would be digging through numerous needless layers of abstraction. My favorite was after one long day of debugging, I found a function whose only purpose was to pass its arguments to another function untouched; this function's only purpose was to pass its arguments untouched to a new, third function, that then slightly changed them and passed them to a fourth function that did the work

Some of this complexity may be necessary for achieving optimal performance in Jax. E.g. extra indirection to avoid the compiler making some bad fusion decision, or multiple calls so something can be marked as static for the jit in the outer call. As far as I'm aware, MaxText is the only public Jax codebase that has demonstrated scaling to models with hundreds of billions of weights. I've just started evaluating it, and it seems to scale better than the Torch implementation I was using previously (even on GPU). Most of the abstraction seems to have a reason behind it, at least for my purposes: I'm modifying the vanilla model, which is easier when the components are less tightly coupled.
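A minimal sketch of the "multiple calls so something can be static" pattern, assuming a recent JAX and a hypothetical `config` dict with a `num_layers` key (this is not MaxText's actual code). A plain dict is unhashable, so it cannot be passed to `jax.jit` as a static argument; a thin outer wrapper pulls out the hashable value and forwards it.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _layer(x, w):
    # Inner worker: plain array math, traced as part of the caller's jit.
    return jnp.tanh(x @ w)


@partial(jax.jit, static_argnames=("num_layers",))
def forward(x, w, num_layers):
    # `num_layers` is static: it is a Python int at trace time, so this
    # loop is unrolled, and each distinct value triggers a recompile.
    for _ in range(num_layers):
        x = _layer(x, w)
    return x


def model_apply(x, w, config):
    # Thin pass-through wrapper: the (hypothetical) config dict is
    # unhashable, so only the hashable int is handed to the jitted function.
    return forward(x, w, num_layers=config["num_layers"])


x = jnp.ones((2, 4))
w = jnp.eye(4) * 0.5
y = model_apply(x, w, {"num_layers": 3})
```

Read in isolation, `model_apply` looks like exactly the kind of needless forwarding function described in the parent comment, but without it the config would hit `jax.jit` directly and fail.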




> Some of this complexity may be necessary for achieving optimal performance in Jax. E.g. extra indirection to avoid the compiler making some bad fusion decision, or multiple calls so something can be marked as static for the jit in the outer call

certainly some of it is, but not the lion's share. I have a much simpler (private) codebase which scales pretty similarly afaict.

the complexity of MaxText feels more Serious Engineering™-flavored, following Best Practices.



