GTC 2020: JAX: Accelerating Machine-Learning Research with Composable Function Transformations in Python
Matthew Johnson, Google
JAX is a system for high-performance machine-learning research. It offers the familiarity of Python+NumPy together with hardware acceleration, and it enables the definition and composition of user-wielded function transformations useful for machine-learning programs. These transformations include automatic differentiation, automatic batching, end-to-end compilation (via XLA), parallelization over multiple accelerators, and more. Composing these transformations is the key to JAX's power and simplicity. JAX had its initial open-source release in December 2018. Researchers use it for a wide range of advanced applications, from studying the training dynamics of neural networks, to probabilistic programming, to scientific applications in physics and biology. We'll introduce JAX and its core function transformations with a live demo. You'll learn about JAX's core design, how it's powering new research, and how you can start using it, too.
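As a minimal sketch of what "composable function transformations" means in practice (an illustrative example, not taken from the talk itself), the snippet below defines an ordinary Python+NumPy-style loss function and then composes `jax.grad` (automatic differentiation), `jax.jit` (end-to-end XLA compilation), and `jax.vmap` (automatic batching) on top of it. The model and shapes here are made up for illustration:

```python
import jax
import jax.numpy as jnp

# An ordinary Python+NumPy-style function: a tiny linear model with tanh.
# (This model is a hypothetical example, not from the talk.)
def predict(w, x):
    return jnp.tanh(jnp.dot(x, w))

def loss(w, x, y):
    return jnp.mean((predict(w, x) - y) ** 2)

# Compose transformations: differentiate with respect to the first argument,
# then compile the whole gradient computation end-to-end via XLA.
grad_loss = jax.jit(jax.grad(loss))

# Automatic batching: turn a per-example function into a batched one
# by mapping over axis 0 of x while broadcasting w.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

w = jnp.ones(3)
x = jnp.ones((4, 3))   # batch of 4 examples, 3 features each
y = jnp.zeros(4)

g = grad_loss(w, x, y)        # g.shape == (3,), same shape as w
p = batched_predict(w, x)     # p.shape == (4,), one prediction per example
```

Because each transformation takes a function and returns a function, they nest freely; for example, `jax.jit(jax.vmap(jax.grad(loss)))` is equally valid, which is the composability the talk highlights.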