
JAX

A Google library for fast number-crunching that auto-computes gradients and compiles code for speed.

Definition

JAX is a Google library for high-performance numerical computing. It pairs NumPy-style array math with automatic differentiation (computing the gradients, or rates of change, that training needs) and XLA, a compiler that turns the code into fast machine instructions for CPUs, GPUs, and TPUs (specialized AI chips). Composable function transforms let the same Python function be differentiated (grad), compiled (jit), vectorized over a batch (vmap), or spread across many GPUs and TPUs. Its functional style, built around pure functions and immutable arrays, suits large-scale training, and it underlies several frontier model stacks.
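A minimal sketch of three of those transforms applied to one toy loss function. The linear model, the data, and the names loss and per_example_loss are illustrative assumptions, not part of any particular model stack. jax.grad differentiates, jax.jit compiles through XLA, and jax.vmap maps a per-example function over a batch axis. Multi-device parallelism (for example jax.pmap or sharding) needs more than one accelerator and is not shown.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss: mean squared error of a linear model, pred = x @ w.
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# grad: automatic differentiation with respect to the first argument (w).
grad_loss = jax.grad(loss)

# jit: compile the gradient function with XLA; the first call traces and compiles.
fast_grad = jax.jit(grad_loss)

# Hypothetical per-example loss, written for a single input row.
def per_example_loss(w, xi, yi):
    return (xi @ w - yi) ** 2

# vmap: vectorize over axis 0 of x and y while sharing the same w (in_axes=None).
batched_loss = jax.vmap(per_example_loss, in_axes=(None, 0, 0))

# Illustrative data: 3 examples, 2 features.
w = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 1.0, 4.0])

g = fast_grad(w, x, y)          # gradient of the mean loss w.r.t. w
per_ex = batched_loss(w, x, y)  # one squared error per example

# One plain gradient-descent step (learning rate chosen arbitrarily).
w_next = w - 0.1 * g
```

Here the predictions are [1, 2, 3] against targets [1, 1, 4], so the per-example errors are [0, 1, 1] and the gradient is (2/3)·xᵀ(pred − y) = [−2/3, 0]. Because each transform takes a function and returns a function, they can be stacked, for instance jax.jit(jax.vmap(...)).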