Google JAX explained

JAX
Developer: Google
Latest Preview Version: v0.3.13
Programming Language: Python, C++
Operating System: Linux, macOS, Windows
Platform: Python, NumPy
Size: 9.0 MB
Genre: Machine learning
License: Apache 2.0

Google JAX is a machine learning framework for transforming numerical functions.[1] [2] It is described as bringing together a modified version of autograd (automatic differentiation, i.e. automatically obtaining the gradient function of a Python function) and TensorFlow's XLA (Accelerated Linear Algebra) compiler. It is designed to follow the structure and workflow of NumPy as closely as possible and works with existing frameworks such as TensorFlow and PyTorch.[3] [4] The primary functions of JAX, which can be freely composed (see the sketch after the list), are:

  1. grad: automatic differentiation
  2. jit: compilation
  3. vmap: auto-vectorization
  4. pmap: SPMD programming
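
These transformations can be nested and combined. The short sketch below is illustrative and not taken from the original article; the function f and its inputs are assumptions. It compiles a vectorized gradient of a scalar function:

from jax import grad, jit, vmap
import jax.numpy as jnp

# a scalar toy function (an assumption for illustration)
def f(x):
    return jnp.tanh(x) ** 2

# compose the transformations: differentiate f, vectorize the result over a
# batch of inputs, and JIT-compile the whole pipeline
fast_batched_grad = jit(vmap(grad(f)))
print(fast_batched_grad(jnp.arange(4.0)))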

grad

See main article: Automatic differentiation. The code below demonstrates automatic differentiation with the grad function.

# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)

The final line should output:

0.19661194
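
As a sanity check (not part of the original example), the value can be compared against the closed-form derivative of the logistic function, logistic(x) * (1 - logistic(x)):

import jax.numpy as jnp

def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# the closed-form derivative of the logistic function is
# logistic(x) * (1 - logistic(x)); at x = 1 it also evaluates to
# approximately 0.19661194
print(logistic(1.0) * (1 - logistic(1.0)))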

jit

The code below demonstrates the jit function's optimization through operation fusion.

# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)

The computation time for jit_cube(x) should be noticeably shorter than that for cube(x), and increasing the size of the array created by jnp.ones widens the difference further.
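
One way to measure this is sketched below (illustrative, not from the original article). The jitted function is called once before timing so that compilation is not counted, and block_until_ready() is used because JAX dispatches work asynchronously:

# timing sketch: same setup as the example above
import timeit
from jax import jit
import jax.numpy as jnp

def cube(x):
    return x * x * x

x = jnp.ones((10000, 10000))
jit_cube = jit(cube)

# the first call to the jitted function triggers compilation, so make it
# before timing; block_until_ready() waits for the asynchronous dispatch
jit_cube(x).block_until_ready()

print(timeit.timeit(lambda: cube(x).block_until_ready(), number=3))
print(timeit.timeit(lambda: jit_cube(x).block_until_ready(), number=3))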

vmap

See main article: Array programming. The code below demonstrates vectorization with the vmap function.

# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp

# define the function; self._net_grads, self._net_params and
# self._flatten_batch are assumed to be defined on the enclosing class
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = jnp.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads

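Because the example above relies on members of an unspecified class, a smaller self-contained sketch may be clearer; the add function and inputs below are illustrative assumptions, not part of the original article. It shows vmap turning a scalar addition into a vectorized, element-wise addition:

from jax import vmap
import jax.numpy as jnp

# a scalar addition (an illustrative assumption)
def add(a, b):
    return a + b

# vmap maps add over the leading axis of both inputs, yielding an
# element-wise, vectorized addition
vectorized_add = vmap(add)
print(vectorized_add(jnp.arange(4.0), jnp.arange(4.0)))  # [0. 2. 4. 6.]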

pmap

The code below demonstrates parallelization of matrix multiplication with the pmap function.

# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)

The final line should print the values:

[1.1566595 1.1805978]
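
pmap maps the computation across separate XLA devices, so the example above assumes at least two are available. The sketch below is illustrative and not from the original article; it checks the device count and uses a commonly used XLA flag to split the host CPU into two virtual devices, which must be set before jax is imported:

import os

# split the host CPU into two virtual XLA devices so the pmap example can
# run on a CPU-only machine; the flag must be set before importing jax
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
print(jax.device_count())  # should report 2, enough for the example above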


Notes and References

  1. Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018-02-02). "Compiling machine learning programs via high-level tracing". MLSys: 1–3. Archived from the original on 2022-06-21. https://web.archive.org/web/20220621153349/https://mlsys.org/Conferences/doc/2018/146.pdf
  2. "Using JAX to accelerate our research". www.deepmind.com. Archived from the original on 2022-06-18. Retrieved 2022-06-18. https://web.archive.org/web/20220618205746/https://www.deepmind.com/blog/using-jax-to-accelerate-our-research
  3. Lynley, Matthew. "Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta". Business Insider. Archived from the original on 2022-06-21. Retrieved 2022-06-21. https://web.archive.org/web/20220621143905/https://www.businessinsider.com/facebook-pytorch-beat-google-tensorflow-jax-meta-ai-2022-6
  4. "Why is Google's JAX so popular?" (2022-04-25). Analytics India Magazine. Archived from the original on 2022-06-18. Retrieved 2022-06-18. https://web.archive.org/web/20220618210503/https://analyticsindiamag.com/why-is-googles-jax-so-popular/