
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more



jax-ml/jax



Transformable numerical computing at scale


Transformations | Scaling | Install guide | Change logs | Reference docs

What is JAX?

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via jax.grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
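To make the forward/reverse distinction concrete, here is a minimal sketch; the function `f` is an arbitrary example, not taken from the JAX docs. `jax.grad` uses reverse mode, `jax.jvp` computes a forward-mode directional derivative, and nesting `jax.jacfwd` around `jax.grad` composes the two modes to get a Hessian.

```python
import jax
import jax.numpy as jnp

def f(x):
  # A scalar-valued function of a vector: f(x) = sum(x * sin(x))
  return jnp.sum(jnp.sin(x) * x)

x = jnp.array([0.5, 1.0, 2.0])

# Reverse mode: gradient of the scalar output w.r.t. every input
g = jax.grad(f)(x)

# Forward mode: Jacobian-vector product along direction v
v = jnp.array([1.0, 0.0, 0.0])
f_val, f_dir = jax.jvp(f, (x,), (v,))  # f_dir equals g[0]

# Forward-over-reverse composition: the Hessian of f
hess = jax.jacfwd(jax.grad(f))(x)  # shape (3, 3); diagonal for this f
```

`jax.hessian` packages a similar forward-over-reverse composition, but spelling it out shows that the two modes nest freely.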

JAX uses XLA to compile and scale your NumPy programs on TPUs, GPUs, and other hardware accelerators. You can compile your own pure functions with jax.jit. Compilation and automatic differentiation can be composed arbitrarily.
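`jax.jit` is meant for pure functions. One way to see why: Python side effects inside a jitted function run while JAX traces it, not on every call, and JAX reuses the compiled version for later calls with the same input shapes and dtypes. In this minimal sketch, `trace_log` and `scaled_sum` are hypothetical names used only to observe tracing.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side effect, used only to observe when tracing happens

@jax.jit
def scaled_sum(x):
  # Runs during tracing (x is an abstract tracer here), not on every call
  trace_log.append(x.shape)
  return jnp.sum(2.0 * x)

a = scaled_sum(jnp.ones(3))   # first call: trace and compile for float32[3]
b = scaled_sum(jnp.zeros(3))  # same shape/dtype: compiled version reused
c = scaled_sum(jnp.ones(4))   # new shape: traced and compiled again
print(trace_log)
```

Only two entries end up in `trace_log`, even though the function was called three times.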

Dig a little deeper, and you'll see that JAX is really an extensible system for composable function transformations at scale.

This is a research project, not an official Google product. Expect sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!

```python
import jax
import jax.numpy as jnp

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # inputs to the next layer
  return outputs                # no activation on last layer

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jax.jit(jax.grad(loss))  # compiled gradient evaluation function
perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
```


Transformations

At its core, JAX is an extensible system for transforming numerical functions. Here are three: jax.grad, jax.jit, and jax.vmap.

Automatic differentiation with grad

Use jax.grad to efficiently compute reverse-mode gradients:

```python
import jax
import jax.numpy as jnp

def tanh(x):
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = jax.grad(tanh)
print(grad_tanh(1.0))
# prints 0.4199743
```

You can differentiate to any order with grad:

```python
print(jax.grad(jax.grad(jax.grad(tanh)))(1.0))
# prints 0.62162673
```

You're free to use differentiation with Python control flow:

```python
def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = jax.grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)
```
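The same holds for loops and recursion, as claimed at the top of this README. In this minimal sketch, `cube_by_loop` and `power` are hypothetical examples. Because `jax.grad` differentiates only the first argument by default, `n` stays an ordinary Python integer and the recursion terminates normally.

```python
import jax

def cube_by_loop(x):
  # Plain Python for-loop; JAX unrolls it while tracing
  result = 1.0
  for _ in range(3):
    result = result * x
  return result

def power(x, n):
  # Plain Python recursion on a concrete integer n
  return 1.0 if n == 0 else x * power(x, n - 1)

d_loop = jax.grad(cube_by_loop)(2.0)  # d/dx x**3 at x=2 is 12
d_rec = jax.grad(power)(2.0, 4)       # d/dx x**4 at x=2 is 32
print(d_loop, d_rec)
```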

See the JAX Autodiff Cookbook and the reference docs on automatic differentiation for more.

Compilation with jit

Use XLA to compile your functions end-to-end with jit, used either as an @jit decorator or as a higher-order function.

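```python
import jax
import jax.numpy as jnp

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jax.jit(slow_f)
fast_f(x).block_until_ready()  # warm-up call so compilation isn't timed

# IPython/Jupyter %timeit magics; block_until_ready() waits for JAX's
# asynchronous dispatch so the timing covers the actual computation
%timeit -n10 -r3 fast_f(x).block_until_ready()
%timeit -n10 -r3 slow_f(x).block_until_ready()
```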
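The same function can be compiled with the decorator form instead. Outside IPython, where `%timeit` is unavailable, a rough timing sketch with the standard `time` module might look like this; the helper `bench` and the `(1000, 1000)` input are illustrative choices, and measured times depend entirely on your hardware.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit  # decorator form; equivalent to fast_f = jax.jit(slow_f)
def fast_f(x):
  return x * x + x * 2.0

def bench(fn, x, n=10):
  # Hypothetical helper: average wall-clock seconds per call
  fn(x).block_until_ready()  # warm up (triggers compilation for jitted fns)
  start = time.perf_counter()
  for _ in range(n):
    fn(x).block_until_ready()  # JAX dispatches asynchronously; wait for results
  return (time.perf_counter() - start) / n

x = jnp.ones((1000, 1000))
elapsed = bench(fast_f, x)
print(f"{elapsed * 1e3:.3f} ms per call")
```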

Using jax.jit constrains the kind of Python control flow the function can use; see the tutorial on Control Flow and Logical Operators with JIT for more.
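For example, `jax.jit(abs_val)` fails because `if x > 0` needs a concrete value that does not exist at trace time. A minimal sketch of three common workarounds, with hypothetical function names: a select with `jnp.where`, structured control flow with `jax.lax.cond`, and marking a Python-level flag as static with `static_argnums`.

```python
from functools import partial
import jax
import jax.numpy as jnp

@jax.jit
def abs_where(x):
  # Computes both sides and selects element-wise; fine for traced values
  return jnp.where(x > 0, x, -x)

@jax.jit
def abs_cond(x):
  # lax.cond chooses between two branch functions at run time
  return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)

@partial(jax.jit, static_argnums=1)
def scale(x, double):
  # `double` is a compile-time Python bool, so a plain `if` works;
  # each distinct value of `double` triggers its own compilation
  if double:
    return 2.0 * x
  return x

print(abs_where(-3.0), abs_cond(-3.0), scale(1.5, True))
```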

Auto-vectorization with vmap

vmap maps a function along array axes. But instead of just looping over function applications, it pushes the loop down onto the function’s primitive operations, e.g. turning matrix-vector multiplies into matrix-matrix multiplies for better performance.
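A minimal sketch of the matrix-vector case, with a hypothetical `matvec` written for a single vector: vmapping it over a batch of vectors gives the same result as one matrix-matrix product. Printing `jax.make_jaxpr` of the vmapped function should show the batch handled by a single `dot_general` (possibly followed by a transpose) rather than a loop.

```python
import jax
import jax.numpy as jnp

def matvec(M, v):
  # Written for one vector v of shape (n,)
  return jnp.dot(M, v)

M = jnp.arange(6.0).reshape(2, 3)
vs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 vectors

# Map over axis 0 of vs; M is shared (in_axes=None)
batched_matvec = jax.vmap(matvec, in_axes=(None, 0))
batched = batched_matvec(M, vs)  # shape (4, 2)

# Equivalent to a single matrix-matrix product
print(jnp.allclose(batched, vs @ M.T))
print(jax.make_jaxpr(batched_matvec)(M, vs))
```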

Using vmap can save you from having to carry around batch dimensions in your code:

```python
import jax
import jax.numpy as jnp

def l1_distance(x, y):
  assert x.ndim == y.ndim == 1  # only works on 1D inputs
  return jnp.sum(jnp.abs(x - y))

def pairwise_distances(dist1D, xs):
  return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)

xs = jax.random.normal(jax.random.key(0), (100, 3))
dists = pairwise_distances(l1_distance, xs)
dists.shape  # (100, 100)
```

By composing jax.vmap with jax.grad and jax.jit, we can get efficient Jacobian matrices, or per-example gradients:

```python
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))
```
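A self-contained sketch of both uses, with a hypothetical toy function `f` and a hypothetical linear loss `example_loss` (standing in for `loss` above so the block runs on its own). It also builds the Jacobian by hand by vmapping one `jax.vjp` over the rows of an identity matrix, which is roughly the idea behind `jax.jacrev`.

```python
import jax
import jax.numpy as jnp

def f(x):
  # Maps R^3 -> R^2
  return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])

J_rev = jax.jacrev(f)(x)  # reverse mode, shape (2, 3)
J_fwd = jax.jacfwd(f)(x)  # forward mode, same Jacobian

# Jacobian by hand: one VJP per output row, batched with vmap
y, f_vjp = jax.vjp(f, x)
J_manual = jax.vmap(lambda ct: f_vjp(ct)[0])(jnp.eye(2))

# Per-example gradients for a toy least-squares loss
def example_loss(w, xi, yi):
  return (jnp.dot(xi, w) - yi) ** 2

w = jnp.array([0.5, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
ys = jnp.array([1.0, 0.0, -1.0])
per_ex = jax.jit(jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0)))(w, xs, ys)
print(per_ex.shape)  # one gradient per example: (3, 2)
```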

Scaling

To scale your computations across thousands of devices, you can use any composition of these:

| Mode | View? | Explicit sharding? | Explicit collectives? |
|---|---|---|---|
| Auto | Global | no | no |
| Explicit | Global | yes | no |
| Manual | Per-device | yes | yes |
```python
import jax
from jax.sharding import set_mesh, AxisType, PartitionSpec as P

# Assumes 8 devices, plus `params` (already sharded), `inputs`, `targets`,
# and `loss` from the examples above.
mesh = jax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,))
set_mesh(mesh)

# parameters are sharded for FSDP:
for W, b in params:
  print(f'{jax.typeof(W)}')  # f32[512@data,512]
  print(f'{jax.typeof(b)}')  # f32[512]

# shard data for batch parallelism:
inputs, targets = jax.device_put((inputs, targets), P('data'))

# evaluate gradients, automatically parallelized!
gradfun = jax.jit(jax.grad(loss))
param_grads = gradfun(params, inputs, targets)
```
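For contrast with the Explicit-mode example, here is a minimal sketch of Auto mode using the `Mesh`/`NamedSharding` APIs: the input carries a sharding, and `jax.jit` leaves the partitioning of the computation to the compiler. It assumes the mesh axes default to automatic, and it builds a 1-D mesh from whatever devices are present, so it also runs on a single CPU (where the sharding is trivial).

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all available devices
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('data',))

# Leading dimension must divide evenly across the 'data' axis
n = 4 * len(devices)
x = jnp.arange(n * 2, dtype=jnp.float32).reshape(n, 2)
x_sharded = jax.device_put(x, NamedSharding(mesh, P('data')))

@jax.jit
def f(x):
  # No sharding annotations here: the compiler propagates the input sharding
  return jnp.sin(x) * 2.0

y = f(x_sharded)
print(y.sharding)
```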

See the tutorial and advanced guides for more.

Gotchas and sharp bits

See the Gotchas Notebook.

Installation

Supported platforms

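| | Linux x86_64 | Linux aarch64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 |
|---|---|---|---|---|---|
| CPU | yes | yes | yes | yes | yes |
| NVIDIA GPU | yes | yes | n/a | no | experimental |
| Google TPU | yes | n/a | n/a | n/a | n/a |
| AMD GPU | yes | no | n/a | no | experimental |
| Apple GPU | n/a | no | experimental | n/a | n/a |
| Intel GPU | experimental | n/a | n/a | no | no |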

Instructions

| Platform | Instructions |
|---|---|
| CPU | `pip install -U jax` |
| NVIDIA GPU | `pip install -U "jax[cuda13]"` |
| Google TPU | `pip install -U "jax[tpu]"` |
| AMD GPU (Linux) | Follow AMD's instructions. |
| Mac GPU | Follow Apple's instructions. |
| Intel GPU | Follow Intel's instructions. |
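After installing, a quick sanity check (a sketch, not an official verification step) is to ask JAX which backend and devices it found:

```python
import jax

# Which backend did JAX select, and which devices can it see?
print(jax.__version__)
print(jax.default_backend())  # e.g. 'cpu', 'gpu', or 'tpu'
print(jax.devices())
```

If a GPU or TPU install is working, the backend should not fall back to `'cpu'`.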

See the documentation for information on alternative installation strategies. These include compiling from source, installing with Docker, using other versions of CUDA, a community-supported conda build, and answers to some frequently-asked questions.

Citing JAX

To cite this repository:

```
@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/jax-ml/jax},
  version = {0.3.13},
  year = {2018},
}
```

In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from jax/version.py, and the year corresponds to the project's open-source release.

A nascent version of JAX, supporting only automatic differentiation and compilation to XLA, was described in a paper that appeared at SysML 2018. We're currently working on covering JAX's ideas and capabilities in a more comprehensive and up-to-date paper.

Reference documentation

For details about the JAX API, see the reference documentation.

For getting started as a JAX developer, see the developer documentation.