Matthew Johnson, Google
GTC 2020
JAX is a system for high-performance machine-learning research. It offers the familiarity of Python+NumPy together with hardware acceleration, and it enables the definition and composition of user-wielded function transformations useful for machine-learning programs. These transformations include automatic differentiation, automatic batching, end-to-end compilation (via XLA), parallelization across multiple accelerators, and more. Composing these transformations is the key to JAX's power and simplicity. JAX had its initial open-source release in December 2018. Researchers use it for a wide range of advanced applications, from studying the training dynamics of neural networks to probabilistic programming and scientific applications in physics and biology. We'll introduce JAX and its core function transformations with a live demo. You'll learn about JAX's core design, how it's powering new research, and how you can start using it, too.
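As a minimal sketch of what composing transformations looks like, the snippet chains three of the transformations named above: `jax.grad` (automatic differentiation), `jax.vmap` (automatic batching), and `jax.jit` (XLA compilation). Together they compute per-example gradients. It assumes a standard JAX installation. The squared-error `loss` function and the toy data are hypothetical illustrations and are not taken from the talk. Multi-accelerator parallelism is left out because it needs more than one device.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss: squared error of a linear model on a single example.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad: differentiate loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# vmap: map grad_loss over a batch of (x, y) pairs, sharing the same w.
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the composed function end-to-end with XLA.
fast_per_example_grads = jax.jit(per_example_grads)

# Hypothetical toy data: 3 examples with 2 features each.
w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

grads = fast_per_example_grads(w, xs, ys)
print(grads.shape)  # (3, 2): one gradient vector per example
```

Each transformation takes a function and returns a new function, so they stack in any order. The batching logic never has to be written by hand inside `loss`.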