Interesting perspective on possible JAX optimizations. Assuming these models are trained and deployed on non-TPU hardware, are there any real advantages to using JAX for deployment on GPU? I’d have assumed that inference is largely a solved optimization problem for large transformer-based models (with any low-hanging fruit already picked by custom CUDA code), and that the details are shifting towards infrastructure tradeoffs and the availability of efficient GPUs. But I may be out of the loop on the latest gossip. Or do you simply mean that there may be cases where TPU inference makes sense financially, and there using JAX makes a difference?