
How does it compare to JAX? After TensorFlow and PyTorch, JAX seems very simple: basically an accelerated NumPy with a few useful extras like automatic differentiation, vectorization and JIT compilation. In terms of API, I don't see how you could get any simpler.



JAX is a DSL on top of XLA: instead of writing plain Python, you write JAX primitives. For example, a for loop in JAX looks like this:

   import jax
   def summ(i, v): return i + v
   x = jax.lax.fori_loop(0, 100, summ, 5)
A for loop in TinyGrad or PyTorch looks like regular Python:

   x = 5
   for i in range(0, 100):
      x += i
By the way, PyTorch also has JIT.
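A runnable version of the two loops above, as a sketch assuming a current JAX install. `summ` follows the `body_fun(i, val)` convention of `jax.lax.fori_loop`, so both versions add every index from 0 to 99 to the starting value 5:

```python
import jax

def summ(i, v):
    # fori_loop calls body_fun(i, val) and carries the return value forward
    return i + v

# Loop expressed as a JAX primitive: i runs over [0, 100), initial carry is 5
x_lax = jax.lax.fori_loop(0, 100, summ, 5)

# The same loop written as plain Python
x_py = 5
for i in range(0, 100):
    x_py += i

print(int(x_lax), x_py)  # 4955 4955
```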


I just tried writing a loop inside a jit-compiled function, and it worked fine:

    >>> import jax
    >>> def a(y):
    ...   x = 0
    ...   for i in range(5):
    ...     x += y
    ...   return x
    ...
    >>> a(5)
    25
    >>> a_jit = jax.jit(a)
    >>> a_jit(5)
    DeviceArray(25, dtype=int32, weak_type=True)


It definitely works, but JAX only sees the unrolled loop:

  x = 0
  x += y
  x += y
  x += y
  x += y
  x += y
  return x
You need `jax.lax.fori_loop` (or a similar primitive such as `jax.lax.scan` or `jax.lax.while_loop`) when you have a long loop with a complex body: replicating that body once per iteration produces a huge computation graph and slow compilation.
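One way to see the difference is to print the jaxpr (JAX's traced intermediate representation) with `jax.make_jaxpr`. This is a sketch assuming a recent JAX release; the function names `unrolled` and `rolled` are made up for illustration. The Python loop should show up as five separate `add` equations, while the `fori_loop` version appears as a single loop primitive (`scan` or `while`, depending on the JAX version and whether the bounds are static) whose body is traced only once:

```python
import jax
import jax.numpy as jnp

def unrolled(y):
    x = jnp.zeros_like(y)
    for _ in range(5):   # plain Python loop: runs at trace time
        x = x + y        # each iteration records a new add in the graph
    return x

def rolled(y):
    # lax.fori_loop keeps the loop as one primitive; the body is traced once
    return jax.lax.fori_loop(0, 5, lambda i, x: x + y, jnp.zeros_like(y))

y = jnp.float32(5.0)
print(jax.make_jaxpr(unrolled)(y))  # five add equations
print(jax.make_jaxpr(rolled)(y))    # one scan/while equation
print(jax.jit(unrolled)(y), jax.jit(rolled)(y))  # both compute 25.0
```

With five iterations the difference is harmless, but the unrolled graph grows linearly with the trip count and the body size, which is what drives compile times up.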


And how does TinyGrad solve this?


It gets fused into one operation, since the Tensor isn't resolved until I call `.numpy()`:

  kafka@tubby:/tmp$ cat fuse.py 
  from tinygrad.tensor import Tensor
  x = Tensor.zeros(1)
  for i in range(5):
    x += i
  print(x.numpy())

  kafka@tubby:/tmp$ OPT=2 GPU=1 DEBUG=2 python3 fuse.py 
  using [<pyopencl.Device 'Apple M1 Max' on 'Apple' at 0x1027f00>]
  **CL**      0 elementwise_0        args     1  kernels [1, 1, 1]          None         OPs     0.0M/   0.00G  mem  0.00 GB tm      0.15us/     0.00ms (    0.03 GFLOPS)
  **CL**        copy OUT (1,)
  [10.]
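To illustrate the idea (this is not tinygrad's actual implementation), here is a toy lazy scalar in plain Python. The class `LazyScalar` and its methods are hypothetical: `+` only records the operation, and nothing is computed until `realize()` is called, at which point the recorded chain of additions is collapsed into a single function and run once:

```python
class LazyScalar:
    """Hypothetical toy lazy value, NOT tinygrad's API."""

    def __init__(self, base, pending=()):
        self.base = base        # the starting value
        self.pending = pending  # constants added so far, not yet applied

    def __add__(self, other):
        # Record the op and return a new lazy value; no arithmetic happens here
        return LazyScalar(self.base, self.pending + (other,))

    def fuse(self):
        # Collapse all pending adds into one "kernel": a single function call
        total = sum(self.pending)
        return lambda v: v + total

    def realize(self):
        # Only now is anything evaluated, analogous to calling .numpy()
        return self.fuse()(self.base)


x = LazyScalar(0.0)
for i in range(5):
    x += i              # falls back to __add__, extending the pending list
print(len(x.pending))   # 5
print(x.realize())      # 10.0
```

The loop still runs in Python, so the graph is built from the unrolled operations either way; the laziness is what gives the library a chance to merge them before anything executes.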


How does this differ from XLA? Would tinygrad's lazy approach also just see the same unrolled loop right before compilation?


He mentioned in a recent stream that he dislikes the complexity of the XLA instruction set used by JAX. So the difference is less about the user-facing API and more about the inner workings of the library.



