JAX Performance: From 63 Minutes to 2 Minutes
30x Speedup with vmap: Why JAX is Built for RL Research
I continue to be impressed with the JAX library as I learn how to use it. In my last post, I mentioned switching from PyTorch to JAX for the alberta-framework. I had seen about a 2.78x speedup just …