JAX Container Early Access

JAX is a library for high-performance numerical computing and machine learning research.

To help developers get up and running quickly with JAX, we’ve partnered with the JAX team to provide a container that includes JAX, Flax (a neural network library for JAX), and a set of dependencies tested for performance. Our early release includes examples for scaling large language model training. With JAX's simple scaling primitives, you can train large models based on GPT and T5 across multiple GPUs and multiple nodes. JAX also has applications in drug discovery, physics ML, reinforcement learning, and neural graphics. We will continue to provide examples across different use cases in JAX.
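
As a rough illustration of what a JAX scaling primitive looks like, here is a minimal data-parallel training sketch using `jax.pmap` and `jax.lax.pmean`. It is not taken from the container's examples. The toy linear-regression model, the learning rate, the axis name `"batch"`, and the helper names `loss_fn` and `train_step` are all assumptions chosen for illustration. Each device computes gradients on its own shard of the batch, and `pmean` averages them so every parameter replica receives the same update. Spanning multiple nodes needs additional setup that is not shown here.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value for this toy example

# Toy model (assumption): linear regression with a mean-squared-error loss.
def loss_fn(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# One data-parallel SGD step. Under pmap, each device sees its own shard
# of x and y; lax.pmean averages the gradients across the "batch" axis.
def train_step(w, x, y):
    grads = jax.grad(loss_fn)(w, x, y)
    grads = jax.lax.pmean(grads, axis_name="batch")
    return w - LEARNING_RATE * grads

# Compile the step once and replicate it over all local devices.
p_train_step = jax.pmap(train_step, axis_name="batch")

n_dev = jax.local_device_count()
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (n_dev, 32, 4))  # leading axis = one shard per device
true_w = jnp.array([1.0, -2.0, 0.5, 3.0])
y = x @ true_w                              # noiseless targets, shape (n_dev, 32)
w = jnp.zeros((n_dev, 4))                   # one parameter replica per device

for _ in range(200):
    w = p_train_step(w, x, y)
```

On a single-device machine this still runs, with a leading axis of size 1. Adding GPUs increases the number of shards without changing the step function, which is the appeal of this style of scaling.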

The JAX Container is now available for private early access.

