Keras Core: Keras for TensorFlow, Jax, and PyTorch (keras.io)
191 points by dewitt on July 11, 2023 | 69 comments



I think that is pretty cool. It literally made me scream "Yes!" when I saw it, and I don't do that for your everyday framework.

I think the beauty of Keras was its perfect balance between simplicity/abstraction and flexibility. I eventually moved to PyTorch, but this is the one thing I always missed. And now, having it leapfrog the current fragmentation and achieve what seems to be true multi-backend support is pretty awesome.

Looking forward to the next steps!


Keras was already that some years ago. It supported TensorFlow, Theano, and MXNet, if my memory is right. Then they ditched everything for TensorFlow. At the time it was really hard to use Keras without calling the backend directly for lots of optimizations, features unsupported by the API, etc. This made using Keras not agnostic at all. What's different now?


> What's different now?

PyTorch adoption: back when Keras went hard into TensorFlow in 2018, TF and PyTorch adoption were about the same, with TF being a bit more popular. Now, most of the papers and models released are PyTorch-first.


Yes, I understand why they made the move (they want to attract PyTorch users). But what's the benefit for the user over, for example, using PyTorch directly? I see we could maybe use TPUs by switching to JAX, etc.

PS: Sorry, I'm a bit salty from my experience with Keras.


PyTorch is an animal of its own when you try to put it into production. They have started addressing this with torch 2.0, but it still has a long way to go. With this, you can switch to TF Serving if you have a standard architecture.


You can just use Triton, which is basically TF Serving for TensorFlow, PyTorch, ONNX and more.


Can you explain that?

My understanding of Triton is rather that it is an alternative to CUDA, except that you write kernels directly in Python at a slightly higher level, and it does a lot of optimizations automatically. So basically: Python -> Triton-IR -> LLVM-IR -> PTX.

https://openai.com/research/triton


It's confusing: there's OpenAI Triton (what you're thinking of) and NVIDIA Triton Inference Server (a different thing).


The original comment is referring to NVIDIA Triton Inference Server.


TensorFlow has some advantages, like being able to use TF Lite for embedded devices. JAX is amazing on TPUs, which AFAIK PyTorch doesn't have great support for.

I assume most people will still do research in PyTorch, but then move models over to Keras for production if they need multi-platform support.


Keras has a cleaner API compared to base PyTorch, especially if you want to use the Sequential construction as demoed in the post.


How so? You can use torch.nn.Sequential pretty much equivalently?

https://pytorch.org/docs/stable/generated/torch.nn.Sequentia...
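Both APIs rest on the same idea: a container that feeds each layer's output into the next layer. A framework-free sketch of that idea follows; the `Sequential` class and the toy `scale`/`relu` layers are hypothetical stand-ins, not either library's implementation.

```python
from typing import Callable, List

Layer = Callable[[List[float]], List[float]]


class Sequential:
    """Minimal container: calls each layer on the previous layer's output."""

    def __init__(self, *layers: Layer) -> None:
        self.layers = list(layers)

    def __call__(self, x: List[float]) -> List[float]:
        for layer in self.layers:
            x = layer(x)
        return x


def scale(factor: float) -> Layer:
    # Toy "layer": multiplies every element by a constant.
    return lambda xs: [factor * v for v in xs]


def relu(xs: List[float]) -> List[float]:
    # Element-wise ReLU.
    return [max(0.0, v) for v in xs]


model = Sequential(scale(2.0), relu)
print(model([-1.0, 0.5, 3.0]))  # [0.0, 1.0, 6.0]
```

Real containers also track each layer's parameters so that an optimizer can find them, which this sketch omits.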


Huh, didn't realize base PyTorch had an equivalent Sequential API.

The point about the better API overall still stands (notably including the actual training part, as base PyTorch requires you to implement your own loop).


Some people would say that this is an advantage of PyTorch, that it is very easy to write and play around with your own loop.

But if you don't want that, if you want to go a bit higher level w.r.t. training, there is e.g. PyTorch Lightning on top of PyTorch.

If you say the API is better in Keras, can you give any examples? They look kind of similar to me. Keras looks a bit more complicated, especially when looking at the internals (which I tend to do often, to better understand what's actually happening), which is a negative point.


Adding Keras abstractions on top of PyTorch seems like negative added value to me. It barely abstracts away more than the easy/mundane stuff, while creating yet another layer of indirection between ideas and machine code in the current Rube Goldberg mess that people call modern ML.

Keras has always been "lipstick on a pig for tensorflow". Its value beyond that seems tenuous at best.


I worked on the project, happy to answer any questions!


How does the future look for TFLite and Edge AI/TinyML in general?

Will Keras Core support direct deployment to edge devices like RPi or Arduino?

Will the experience of defining and training a model in JAX/PyTorch and then deploying to edge devices be seamless?

Anything related on the roadmap?


Great to see this, but I’m curious, does this mean we’ll get fewer fchollet tweets that talk up TF and down PyTorch? Is the rivalry done?


Hi, first off, thank you for your contributions, and this goes for the entire team. Keras is a wonderful tool and this was definitely the right move. No other package nails the “progressive disclosure” philosophy like Keras.

This caught my eye:

> “Right now, we use tf.nest (a Python data structure processing utility) extensively across the codebase, which requires the TensorFlow package. In the near future, we intend to turn tf.nest into a standalone package, so that you could use Keras Core without installing TensorFlow.”

I recently migrated a TF project to PyTorch (would have been great to have keras_core at the time) and used torch.nested. Could this not be an option?
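For context, tf.nest's two core operations are flattening an arbitrarily nested structure of lists, tuples and dicts into a flat list of leaves, and packing a flat list back into that same structure. A minimal pure-Python sketch of both follows. It only illustrates the idea and is not tf.nest's implementation, which also handles namedtuples, attrs classes and more. (torch.nested, for comparison, is PyTorch's API for nested, i.e. ragged, tensors, which is a different kind of utility.)

```python
from typing import Any, Iterator, List


def flatten(structure: Any) -> List[Any]:
    """Return the leaves of a nested list/tuple/dict structure, depth-first.

    Dict values are visited in sorted-key order, as tf.nest does.
    """
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure) for leaf in flatten(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]  # anything else is treated as a leaf


def pack_sequence_as(structure: Any, flat: List[Any]) -> Any:
    """Rebuild the shape of `structure`, filling its leaves from `flat` in order."""
    if len(flat) != len(flatten(structure)):
        raise ValueError("flat must have exactly one element per leaf of structure")
    leaves: Iterator[Any] = iter(flat)

    def build(s: Any) -> Any:
        if isinstance(s, dict):
            # Consume leaves in sorted-key order, then restore the original key order.
            rebuilt = {key: build(s[key]) for key in sorted(s)}
            return {key: rebuilt[key] for key in s}
        if isinstance(s, (list, tuple)):
            return type(s)(build(item) for item in s)
        return next(leaves)

    return build(structure)


# Typical use: apply an operation to every leaf while keeping the structure.
structure = {"b": (1, 2), "a": [3, {"c": 4}]}
flat = flatten(structure)  # [3, 4, 1, 2]
doubled = pack_sequence_as(structure, [2 * x for x in flat])
print(doubled)  # {'b': (2, 4), 'a': [6, {'c': 8}]}
```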

A second question, regarding “customizing what happens in fit()”: must this be written in TF/PyTorch/JAX only, or can it be done with keras_core.ops, similar to the example shown for custom components? The idea would be that you could reuse the same training loop logic across frameworks, as you can for custom components.


At this time, there are no backend-agnostic APIs to implement training steps/training loops, because each backend handles training very differently, so no shared abstraction can exist (especially for JAX). So when customizing fit() you have to use backend-native APIs.

If you want to make a model with a custom train_step that is cross-backend, you can do something like:

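  def train_step(self, *args, **kwargs):
    backend = keras.config.backend()  # "tensorflow", "jax", or "torch"
    if backend == "tensorflow":
      return self._tf_train_step(*args, **kwargs)
    elif backend == "jax":
      return self._jax_train_step(*args, **kwargs)  # hypothetical helper
    elif backend == "torch":
      return self._torch_train_step(*args, **kwargs)  # hypothetical helper

BTW, it looks like the previous account is being rate-limited to less than one post per hour (maybe even locked for the day), so I will be very slow to answer questions.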


Just want to say I love Keras. Thank you for your work!


This looks awesome; I was a big fan of Keras back when it had pluggable backends and a much cleaner API than Tensorflow.

Fast forward to now, and my biggest pain point is that all the new models are released on PyTorch, but the PyTorch serving story is still far behind TF Serving. Can this help convert a PyTorch model into a servable SavedModel?


For a Keras Core model to be usable with the TF Serving ecosystem, it must be implemented either via Keras APIs (Keras layers and Keras ops) or via TF APIs.

To use pretrained models, you can take a look at KerasCV and KerasNLP, they have all the classics, like BERT, T5, OPT, Whisper, StableDiffusion, EfficientNet, YOLOv8, etc. They're adding new models regularly.


Congrats on the launch! I learned Keras back when I first got into ML, so I'm really happy to see it making a comeback. Are there some example architectures available/planned that are somewhat complex, and not just a couple of layers (BERT, ResNet, etc.)?


Yes, you can check out KerasCV and KerasNLP, which host pretrained models like ResNet, BERT, and many more. They run on all backends as of the latest releases (today), and converting them to be backend-agnostic was pretty smooth! It took a couple of weeks to convert both packages entirely.

https://github.com/keras-team/keras-nlp/tree/master/keras_nl... https://github.com/keras-team/keras-cv/tree/master/keras_cv/...


This is an amazing contribution to the NN world. Thank you to all the team members.


Firstly, thanks to the whole team for everything you have done, and congrats on this. It must have been a ton of work, and I am excited to get my hands on it.


Do you foresee any compatibility or integration issues with higher-level frameworks, e.g. Lightning, Transformers, etc.?


Thanks for helping to de-fragment the AI ecosystem! I’ll look to get involved, test, and collaborate with patches!


As someone who has dealt with countless breaking changes in keras and wasted days of my life attempting to upgrade, no thank you.

My pytorch code from years ago still works with no issues, my old keras code would break all the time even in minor releases.


Agreed. This will only break things, especially research code.


From the announcement:

"We're excited to share with you a new library called Keras Core, a preview version of the future of Keras. In Fall 2023, this library will become Keras 3.0. Keras Core is a full rewrite of the Keras codebase that rebases it on top of a modular backend architecture. It makes it possible to run Keras workflows on top of arbitrary frameworks — starting with TensorFlow, JAX, and PyTorch."

Excited about this one. Please let us know if you have any questions.


That looks very interesting.

I have actually developed (and am still developing) something very similar, which we call the RETURNN frontend: a new frontend plus new backends for our RETURNN framework. The new frontend supports Python model-definition code very similar to what you see in PyTorch or Keras, i.e. a core Tensor class, a base Module class you can derive from, a Parameter class, and a core functional API to perform all the computations. It supports multiple backends, currently mostly TensorFlow (graph-based) and PyTorch; JAX is something I also planned. Some details here: https://github.com/rwth-i6/returnn/issues/1120

(Note that we went a bit further ahead and made named dimensions a core principle of the framework.)

(Example beam search implementation: https://github.com/rwth-i6/i6_experiments/blob/14b66c4dc74c0...)

One difficulty I found was how to design the API so that it works well both for eager-mode frameworks (PyTorch, TF eager mode) and graph-based frameworks (TF graph mode, JAX). That mostly concerns anything involving state, or code that should not just execute in the inner training loop but, e.g., only at initialization, or after each epoch, etc. For example:

- Parameter initialization.

- Anything involving buffers, e.g. batch normalization.

- Other custom training loops? Or, e.g., an outer and an inner loop (like in GAN training)?

- How to implement something like weight normalization? In PyTorch, module.param is renamed, and then a pre-forward hook recomputes module.param on the fly for each forward call. So should the same logic simply be followed for both eager mode and graph mode?

- How to deal with control-flow contexts, e.g. accessing values outside a loop that were computed inside it? Those things are naturally possible in eager mode, where you simply get the most recent value and there is no real control-flow context.

- Device logic: Should the device be defined explicitly for each tensor (like PyTorch), or should tensors be moved to the GPU automatically and eagerly (like TensorFlow)? Should moving from one device to another (or to the CPU) be automatic, or must it be explicit?

- How do you allow easy interop, e.g. mixing torch.nn.Module and Keras layers?

I see that you have keras_core.callbacks.LambdaCallback, which may be similar, but can you effectively update the logic of the module in there?
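To make the state problem from the list above concrete, here is a framework-free sketch (an illustration, not RETURNN or Keras code) of batch normalization over a single feature, omitting the learnable scale and offset. In training mode it normalizes with batch statistics and produces updated running averages; in inference mode it reads the stored averages instead. The core function is pure and returns the new state alongside the outputs, which is the form a functional backend such as JAX needs, while the small wrapper class hides that state by overwriting an attribute, as eager frameworks typically do. The `momentum` and `eps` defaults are assumptions; `momentum` weights the old value, following the Keras `BatchNormalization` convention.

```python
import math
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class BNState:
    """Running statistics ("buffers") for a single feature."""
    mean: float = 0.0
    var: float = 1.0


def batch_norm(
    x: List[float],
    state: BNState,
    training: bool,
    momentum: float = 0.9,
    eps: float = 1e-5,
) -> Tuple[List[float], BNState]:
    """Pure function: returns the outputs AND the new state; mutates nothing."""
    if training:
        mean = sum(x) / len(x)
        var = sum((v - mean) ** 2 for v in x) / len(x)  # biased batch variance
        new_state = BNState(
            mean=momentum * state.mean + (1 - momentum) * mean,
            var=momentum * state.var + (1 - momentum) * var,
        )
    else:
        # Inference: use the stored running statistics; the state is unchanged.
        mean, var, new_state = state.mean, state.var, state
    y = [(v - mean) / math.sqrt(var + eps) for v in x]
    return y, new_state


class StatefulBatchNorm:
    """Eager-style wrapper: hides the state by overwriting an attribute."""

    def __init__(self) -> None:
        self.state = BNState()

    def __call__(self, x: List[float], training: bool) -> List[float]:
        y, self.state = batch_norm(x, self.state, training)
        return y


bn = StatefulBatchNorm()
y_train = bn([1.0, 3.0], training=True)  # batch mean 2.0, batch variance 1.0
y_infer, unchanged = batch_norm([0.2], bn.state, training=False)
```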


Keras and PyTorch! I thought I'd never see the day! Glad to see the two communities bury the hatchet.


I don’t get it - why would you want Keras if you already use Pytorch?


The same reason why you might want to use Keras if you use any of the other backends. They operate at different levels.

Keras is a higher-level API. It means that you can prototype architectures quickly and you don't have to write a training loop. It's also really easy to extend.

I currently use PyTorch Lightning to avoid having to write tonnes of boilerplate code, but I've been looking for a way to leave it for ages as I'm not a huge fan of the direction of the product. Keras seems like it might be the answer for me.


What’s so hard about writing a training loop?


Nothing is. But writing a basic training loop with proper logging etc. from scratch every time you want to train a basic neural-net classifier seems inefficient to me. There should be a framework where you can just plug in your model and your data and it trains it in a supervised fashion. That's what fast.ai or Keras are doing.


You only need to write a training loop function once. Then you can just pass it a model, a dataloader, etc., just like you would if you used a training loop written by someone else in Keras. The only difference is that there it would be hidden from you behind layers of wrappers and abstraction, making it harder to modify and debug.
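A minimal sketch of that write-once approach, assuming a model can be expressed as a flat list of parameters plus a function returning the loss and gradients for a batch. `fit`, `linear_mse` and the toy data are hypothetical illustrations rather than Keras or PyTorch APIs, and a production loop would also need validation, checkpointing, device placement and the like.

```python
from typing import Callable, Iterable, List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]  # (inputs, targets)
LossAndGrad = Callable[[List[float], Batch], Tuple[float, List[float]]]


def fit(
    params: List[float],
    loss_and_grad: LossAndGrad,
    make_batches: Callable[[], Iterable[Batch]],
    epochs: int,
    lr: float,
    log: Callable[[str], None] = print,
) -> List[float]:
    """Plain SGD loop, reusable for any model given as (params, loss_and_grad)."""
    params = list(params)
    for epoch in range(epochs):
        total_loss, total_examples = 0.0, 0
        for batch in make_batches():
            loss, grads = loss_and_grad(params, batch)
            params = [p - lr * g for p, g in zip(params, grads)]
            n = len(batch[0])
            total_loss += loss * n  # weight by batch size for an exact epoch mean
            total_examples += n
        log(f"epoch {epoch + 1}: loss={total_loss / total_examples:.4f}")
    return params


def linear_mse(params: List[float], batch: Batch) -> Tuple[float, List[float]]:
    """Toy model y = w * x + b with mean squared error and its gradient."""
    w, b = params
    xs, ys = batch
    errs = [w * x + b - y for x, y in zip(xs, ys)]
    n = len(xs)
    loss = sum(e * e for e in errs) / n
    grad_w = sum(2 * e * x for e, x in zip(errs, xs)) / n
    grad_b = sum(2 * e for e in errs) / n
    return loss, [grad_w, grad_b]


def batches() -> List[Batch]:
    # Data drawn from y = 3x + 1, split into two batches.
    return [([0.0, 1.0], [1.0, 4.0]), ([2.0, 3.0], [7.0, 10.0])]


w, b = fit([0.0, 0.0], linear_mse, batches, epochs=200, lr=0.05, log=lambda _: None)
```

`make_batches` is a callable rather than an iterable so that each epoch gets a fresh pass over the data; a one-shot generator would be exhausted after the first epoch.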


It sounds like you've found something that works best for you, and that the large Keras user base has found something that works best for them.


The large Keras userbase exists largely because Tensorflow sucked.


We'll have to agree to disagree. I find great value in the Keras API.

It's also a bit ahistorical, in that Keras was very popular with the Theano backend before that project wound down.


Keras is great for painful libraries like Theano, which was similar to early TF. Btw, many Theano users already used a higher level library called Lasagne, which was similar to Keras.

When I switched to TF in 2016, Keras was still in its infancy, so I wrote a lot of low-level TF code (e.g. my own batchnorm layer), but many new DL researchers struggled with TF because they didn’t have Theano experience. That steep learning curve led to the rise of Keras as the friendly TF interface.

Pytorch is a different story.


It's extremely easy to get wrong in subtle ways.
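One example of the kind of subtle mistake meant here (an illustration added for concreteness, with made-up numbers): reporting an epoch's loss as the plain mean of per-batch mean losses silently over-weights a smaller final batch.

```python
from typing import List, Sequence


def naive_epoch_loss(batch_losses: Sequence[List[float]]) -> float:
    """Mean of per-batch means: wrong when batch sizes differ."""
    batch_means = [sum(b) / len(b) for b in batch_losses]
    return sum(batch_means) / len(batch_means)


def epoch_loss(batch_losses: Sequence[List[float]]) -> float:
    """Mean over all examples: every example counts equally."""
    total = sum(sum(b) for b in batch_losses)
    count = sum(len(b) for b in batch_losses)
    return total / count


# Per-example losses for an epoch of 5 examples with batch_size=4:
# one full batch, then a final batch holding a single example.
losses = [[1.0, 1.0, 1.0, 1.0], [6.0]]
print(naive_epoch_loss(losses))  # 3.5
print(epoch_loss(losses))        # 2.0
```

With equal-sized batches the two functions agree, which is why the bug can go unnoticed until the dataset size stops being a multiple of the batch size.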


Do you implement the sort function and the hash map from scratch every time you need them? If not, what's so hard about implementing a sorting algorithm?


Because sometimes you don’t want to write your own training loops, you just want a working method to train a model.


There are a lot of libraries for that. For example, PyTorch Lightning and Accelerate are very mature.


Sure, and Keras is another, very mature library which allows you to do this...


Keras + Pytorch is not mature.


If I want to use this brave new keras with torch.compile, what does that look like?


We are still working on this feature. We are aiming to have it in model.compile(jit_compile=True). https://github.com/keras-team/keras-core/blob/v0.1.0/keras_c...


Wait, so what happens if I use a model with the torch backend now and call .compile()? Does it just return, and then do normal torch JIT when .fit() (or whatever the Keras notation is; I have forgotten most of it) is called?


A lot of this seems like abstraction for abstractions sake. When would someone actually use this?


Same question for the tf.function decorator?


Does this mean the weights output can be backend agnostic?

Also, are there any examples using this for the Coral TPU?


Yes, model weights saved with Keras Core are backend-agnostic. You can train a model in one backend and reload it in another.

Coral TPU could be used with Keras Core, but via the TensorFlow backend only.


Super cool! Does that mean that if someone trains something using the PyTorch backend, I can still use it with the Coral if I load the weights using the TensorFlow backend?


That's right, if the model is backend-agnostic you can train it with a PyTorch training loop and then reload it and use it with TF ecosystem tools, like serve it with TF-Serving or export it to Coral TPU.


Can someone ELI5 the relationship between Keras and TensorFlow/Jax/PyTorch/etc? I kinda get the idea the Keras is the "frontend" and TF/Jax/PyTorch are the "backend" but I'm looking to solidify my understanding of the relationship. It might help to also comment on the key differences between TF/Jax/PyTorch/etc. Thank you.


Keras was a high-level wrapper around Theano or Tensorflow.

Google then employed the creator of Keras to work on it, and everyone was promised that Keras would remain backend-agnostic.

Keras became part of TensorFlow as a high-level API and did not remain backend-agnostic. There was a lot of questionable Twitter beef about PyTorch from the Keras creator.

Keras is now once again backend-agnostic, as a high-level API for TensorFlow/PyTorch/JAX. Likely because they see TensorFlow losing traction.
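A toy model of that frontend/backend split, assuming nothing about Keras internals: the frontend code (a `Dense` layer here) is written once against a small `ops` interface, and a backend is anything that implements that interface. `Ops`, `PurePythonOps`, `Dense` and `BACKENDS` are hypothetical names; in Keras Core the analogous pieces are the `keras_core.ops` namespace and a backend chosen through the `KERAS_BACKEND` environment variable, with TensorFlow, JAX or PyTorch doing the actual tensor math.

```python
from typing import Dict, List, Protocol

Matrix = List[List[float]]


class Ops(Protocol):
    """The small set of operations the frontend is allowed to call."""

    def matmul(self, a: Matrix, b: Matrix) -> Matrix: ...

    def add(self, a: Matrix, b: Matrix) -> Matrix: ...

    def relu(self, a: Matrix) -> Matrix: ...


class PurePythonOps:
    """One toy 'backend' on nested lists; a real one would wrap TF, JAX or torch."""

    def matmul(self, a: Matrix, b: Matrix) -> Matrix:
        return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

    def add(self, a: Matrix, b: Matrix) -> Matrix:
        # Broadcast a single-row `b` (e.g. a bias) across the rows of `a`.
        rows = b if len(b) == len(a) else b * len(a)
        return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, rows)]

    def relu(self, a: Matrix) -> Matrix:
        return [[max(0.0, x) for x in row] for row in a]


# Hypothetical registry; Keras Core instead reads the KERAS_BACKEND setting.
BACKENDS: Dict[str, Ops] = {"python": PurePythonOps()}
ops: Ops = BACKENDS["python"]


class Dense:
    """Frontend layer: talks only to `ops`, never to a specific framework."""

    def __init__(self, kernel: Matrix, bias: List[float]) -> None:
        self.kernel = kernel
        self.bias = [bias]  # stored as a 1-row matrix

    def __call__(self, x: Matrix) -> Matrix:
        return ops.relu(ops.add(ops.matmul(x, self.kernel), self.bias))


layer = Dense(kernel=[[1.0, -1.0], [2.0, 0.5]], bias=[0.0, 1.0])
print(layer([[1.0, 1.0]]))  # [[3.0, 0.5]]
```

Swapping frameworks then means registering another `Ops` implementation; `Dense` itself does not change, which is the property a multi-backend design relies on.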


Supporting multiple backends (especially JAX) is nice! It makes experimenting with and migrating between them so much more approachable. Any timeline for when we can expect support for distributed JAX training? The docs currently seem to indicate that only TF is supported for distributed training.


Support for distributed JAX training is demoed here: bit.ly/keras-on-jax-demo (you have to write a custom training loop for now, but it works).
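For readers unfamiliar with it: synchronous data-parallel training, a common form of distributed training, splits each global batch across devices, computes gradients locally, averages them across devices (an all-reduce), and applies the same update on every replica. A framework-free sketch of that arithmetic follows; the one-parameter model and all names are hypothetical and say nothing about how the linked demo is implemented. With equal-sized shards, the averaged gradient equals the full-batch gradient, so the update matches single-device training.

```python
from typing import List, Sequence, Tuple

Example = Tuple[float, float]  # (x, y)


def grad_mse(w: float, batch: Sequence[Example]) -> float:
    """Gradient of mean squared error for the toy model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def shard(batch: Sequence[Example], num_devices: int) -> List[List[Example]]:
    """Split a global batch into equal contiguous per-device shards."""
    if len(batch) % num_devices:
        raise ValueError("global batch size must be divisible by num_devices")
    size = len(batch) // num_devices
    return [list(batch[i * size:(i + 1) * size]) for i in range(num_devices)]


def data_parallel_step(w: float, batch: Sequence[Example], num_devices: int, lr: float) -> float:
    # Each "device" computes a gradient on its own shard (in parallel, in reality).
    local_grads = [grad_mse(w, s) for s in shard(batch, num_devices)]
    # All-reduce: every device ends up with the mean of the local gradients.
    mean_grad = sum(local_grads) / num_devices
    # The identical update on every replica keeps the parameters in sync.
    return w - lr * mean_grad


batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.5), (4.0, 8.0)]
w_full = 0.0 - 0.01 * grad_mse(0.0, batch)  # single-device step
w_dp = data_parallel_step(0.0, batch, num_devices=2, lr=0.01)
print(w_full, w_dp)  # the two updates agree
```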


Thanks!


IIRC, Keras was officially added to TensorFlow as part of the 2.0 release. With Keras reverting to its backend-agnostic state, will it be removed from TensorFlow? Is this a divorce, or are TF & Keras just opening up their relationship?


Wouldn't a multi-framework wrapper be limited to the subset of features common to all supported frameworks?

Additionally, would it always be at least a step behind each framework, depending on the wrapper's release cycle?


I must admit I've never actually used Keras, but it is interesting to see how they are implementing it with JAX; definitely worth a reminder to dig into one of these days.



Like everything keras, this will promise a lot and only deliver on conops deemed worthy by the keras team.


Will Keras be backward compatible, or, as always, will the Google/TF ecosystem now have three generations of frameworks: TF1, TF2 + Keras, and Keras Core (3.0)?



