Is there anything similar to this for something like TensorFlow, Keras, or PyTorch? I haven't used them recently, but in the past I had to do all of the things you just described in painful-to-debug ways.
Same. I’ve been using einops and jaxtyping together pretty extensively recently, and the combination helps a lot for reading and writing multidimensional array code. Also array_api_compat: the API coverage isn’t perfect, but it’s pretty satisfying to write code that works for both PyTorch tensors and NumPy arrays.