JAX preallocates 90% of available GPU memory when the first operation is run, to minimize allocation overhead. Does PyTorch grab VRAM for a similar reason?
Yes, PyTorch uses what they call a caching memory allocator[0]; basically it seems like they allocate large chunks of GPU memory and implement a heap on top of them. If needed, they expose some knobs and functions that let you control the allocator and observe memory usage.
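For example, here is a minimal sketch of the observability side (torch.cuda.memory_allocated, memory_reserved, empty_cache) plus one of the control knobs (set_per_process_memory_fraction), assuming a CUDA-capable GPU is available:

    import torch

    device = torch.device("cuda:0")

    # Allocate a ~400 MB tensor; the caching allocator requests a block
    # from the driver and keeps it around for reuse.
    x = torch.empty(100 * 1024 * 1024, dtype=torch.float32, device=device)

    print(torch.cuda.memory_allocated(device))  # bytes held by live tensors
    print(torch.cuda.memory_reserved(device))   # bytes reserved from the driver

    # Freeing the tensor returns its block to the allocator's cache,
    # not to the driver: "allocated" drops, "reserved" stays the same.
    del x
    print(torch.cuda.memory_allocated(device))
    print(torch.cuda.memory_reserved(device))

    # Release unused cached blocks back to the driver, e.g. so another
    # process (or JAX in the same process) can grab the VRAM.
    torch.cuda.empty_cache()
    print(torch.cuda.memory_reserved(device))

    # One of the knobs: cap this process at a fraction of total GPU memory.
    torch.cuda.set_per_process_memory_fraction(0.5, device)

There is also the PYTORCH_CUDA_ALLOC_CONF environment variable for tuning the allocator's behavior, and torch.cuda.memory_summary() for a human-readable dump of its current state.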