DeepMind on JAX
AI lab tooling long read #1: DeepMind published a blog post about using JAX to accelerate their research. JAX is a modern take on the NumPy API that “includes an extensible system of composable transformations that help support machine learning research” by taking care of differentiation, vectorization (like abstracting batching away from the researcher), and JIT-compilation (for GPUs and TPUs). The Python library now underpins many of DeepMind’s recent publications, and they’ve also open-sourced several components of their internal ecosystem on top of JAX: Haiku, Optax, RLax, Chex, and Jraph (“it’s pronounced gif”).
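To make “composable transformations” concrete, here is a minimal sketch (not DeepMind’s code) that stacks all three on one function. It uses the public `jax.grad`, `jax.vmap`, and `jax.jit` transformations. The toy loss, the weights `w`, and the data `xs`/`ys` are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss: squared error of a linear model, written with the NumPy-style API.
def loss(w, x, y):
    pred = jnp.dot(x, w)  # scalar prediction for a single example
    return (pred - y) ** 2

# Differentiation: gradient of the loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# Vectorization: map over a batch of (x, y) pairs while sharing one w (in_axes=None).
batched_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# JIT-compilation: compile the composed function with XLA for CPU, GPU, or TPU.
fast_batched_grads = jax.jit(batched_grads)

# Hypothetical example data: 3 examples with 2 features each.
w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

print(fast_batched_grads(w, xs, ys))  # one per-example gradient row, shape (3, 2)
```

Note that `loss` is written for a single example and never mentions a batch dimension: `vmap` adds batching afterwards, which is the “abstracting batching away from the researcher” part, and the three transformations nest in any order.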