How does it compare to JAX? After TensorFlow and PyTorch, JAX seems very simple: basically an accelerated numpy with a few useful extras like automatic differentiation, vectorization and jit-compilation. In terms of API, I don't see how it could get any simpler.
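For readers who haven't used JAX, here is a minimal sketch of those three features working together. It assumes a reasonably recent `jax` install; `loss`, `w` and `xs` are made-up names for illustration.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # An ordinary numpy-style scalar function.
    return jnp.sum((w * x) ** 2)


# Automatic differentiation: gradient with respect to the first argument (w).
grad_loss = jax.grad(loss)

# Vectorization: map the gradient over axis 0 of x, keeping w fixed.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

# JIT compilation: trace once, compile with XLA, reuse the compiled version.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 1.0], [2.0, 0.0]])
print(fast_batched_grad(w, xs))  # expected values: [[2, 4], [8, 0]]
```

The gradient of sum((w*x)**2) with respect to w is 2*w*x**2, so each row of the result is that expression evaluated for one row of `xs`.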
I just tried writing a loop inside a jit-compiled function, and it worked:
>>> import jax
>>> def a(y):
...     x = 0
...     for i in range(5):
...         x += y
...     return x
...
>>> a(5)
25
>>> a_jit = jax.jit(a)
>>> a_jit(5)
DeviceArray(25, dtype=int32, weak_type=True)
It definitely works, because JAX traces the Python code and only sees the unrolled loop:
x = 0
x += y
x += y
x += y
x += y
x += y
return x
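To see why the Python loop disappears, here is a hypothetical toy tracer in plain Python. It is not how JAX is implemented (JAX builds a jaxpr from abstract tracer objects), but it follows the same basic idea: call the function with a placeholder that records each `+` instead of computing it. The names `Tracer` and `trace` are made up for this sketch.

```python
def _name(v):
    # Tracers print as their variable name; Python constants as their value.
    return v.name if isinstance(v, Tracer) else repr(v)


class Tracer:
    """Hypothetical toy tracer: records each + as a graph line instead of computing it."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph

    def _emit(self, lhs, rhs):
        out = Tracer(f"t{len(self.graph)}", self.graph)
        self.graph.append(f"{out.name} = {_name(lhs)} + {_name(rhs)}")
        return out

    def __add__(self, other):
        return self._emit(self, other)

    def __radd__(self, other):
        # Handles `0 + tracer`, i.e. the first `x += y` while x is still the int 0.
        return self._emit(other, self)


def trace(fn):
    graph = []
    result = fn(Tracer("y", graph))
    return graph, result.name


def a(y):
    x = 0
    for i in range(5):
        x += y
    return x


graph, out = trace(a)
print("\n".join(graph))
print("return", out)
```

Running it prints `t0 = 0 + y` through `t4 = t3 + y` followed by `return t4`: the `for` loop ran at trace time, and only its five unrolled additions were recorded.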
You might need `jax.lax.fori_loop` (or a similar control-flow primitive) when you have a long loop with a complex body. Replicating a complex body many times gives you a huge computation graph and slow compilation.
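A sketch of the difference, assuming a recent JAX version where `jax.make_jaxpr` can be used to inspect the traced program. `unrolled`, `rolled` and the trip count of 100 are illustrative choices, not anything from the thread.

```python
import jax
import jax.numpy as jnp


def unrolled(y, n=100):
    # The Python loop runs at trace time, so the graph gets n copies of the body.
    x = jnp.zeros_like(y)
    for _ in range(n):
        x = x + jnp.sin(y)
    return x


def rolled(y, n=100):
    # fori_loop keeps a single copy of the body inside one loop primitive.
    body = lambda i, x: x + jnp.sin(y)
    return jax.lax.fori_loop(0, n, body, jnp.zeros_like(y))


y = jnp.float32(0.5)
n_unrolled = len(jax.make_jaxpr(unrolled)(y).jaxpr.eqns)
n_rolled = len(jax.make_jaxpr(rolled)(y).jaxpr.eqns)
print(n_unrolled, n_rolled)  # hundreds of equations vs. a handful
```

With Python-int bounds, `fori_loop` may be lowered to a `scan` rather than a `while` primitive, but either way the body appears once in the jaxpr instead of once per iteration, which is what keeps compile time from growing with the trip count.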
It gets fused into one operation, since the Tensor isn't resolved until I call .numpy():
kafka@tubby:/tmp$ cat fuse.py
from tinygrad.tensor import Tensor
x = Tensor.zeros(1)
for i in range(5):
    x += i
print(x.numpy())
kafka@tubby:/tmp$ OPT=2 GPU=1 DEBUG=2 python3 fuse.py
using [<pyopencl.Device 'Apple M1 Max' on 'Apple' at 0x1027f00>]
**CL** 0 elementwise_0 args 1 kernels [1, 1, 1] None OPs 0.0M/ 0.00G mem 0.00 GB tm 0.15us/ 0.00ms ( 0.03 GFLOPS)
**CL** copy OUT (1,)
[10.]
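The laziness in that example can be mimicked with a hypothetical toy class in plain Python. This is not tinygrad's actual implementation; it only illustrates the idea that `+=` records pending elementwise ops and a single fused pass runs them when the result is requested (here a made-up `realize()` method stands in for `.numpy()`).

```python
class LazyTensor:
    """Hypothetical toy, not tinygrad's real code: elementwise ops are recorded
    and only executed, fused into one pass, when realize() is called."""

    kernels_launched = 0  # counts how many "kernels" (passes over the data) ran

    def __init__(self, data, pending=()):
        self.data = list(data)  # realized base buffer
        self.pending = pending  # tuple of (op, scalar) pairs still to apply

    def _lazy(self, op, scalar):
        # Build a new graph node; no arithmetic happens here.
        return LazyTensor(self.data, self.pending + ((op, scalar),))

    def __add__(self, scalar):
        return self._lazy("add", scalar)

    def __mul__(self, scalar):
        return self._lazy("mul", scalar)

    def realize(self):
        # One fused "kernel": a single loop over the data applies every pending op.
        LazyTensor.kernels_launched += 1
        out = []
        for v in self.data:
            for op, s in self.pending:
                v = v + s if op == "add" else v * s
            out.append(v)
        return out


x = LazyTensor([0.0])
for i in range(5):
    x += i  # no __iadd__, so Python falls back to __add__ and builds a new lazy node
result = x.realize()
print(result)  # [10.0]
print(LazyTensor.kernels_launched)  # 1
```

All five additions end up in one pass over the buffer, which mirrors the single `elementwise_0` kernel in the debug output.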
He mentioned in a recent stream that he dislikes the complexity of the XLA instruction set used by JAX. So the objection is less about the user-facing API and more about the inner workings of the library.