JAX
Developer: Google
Latest preview version: v0.3.13
Programming language: Python, C++
Operating system: Linux, macOS, Windows
Platform: Python, NumPy
Size: 9.0 MB
Genre: Machine learning
License: Apache 2.0
Google JAX is a machine learning framework for transforming numerical functions.[1] [2] It is described as bringing together a modified version of autograd (automatic differentiation, i.e. automatically obtaining the gradient function of a numerical function) and TensorFlow's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch.[3] [4] The primary functions of JAX are grad (automatic differentiation), jit (just-in-time compilation), vmap (automatic vectorization), and pmap (parallelization across multiple devices), each of which is demonstrated below.
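These transformations are composable. As a minimal sketch (sum_of_squares is an illustrative function, not part of JAX), the gradient function produced by grad can itself be compiled with jit:

from jax import grad, jit
import jax.numpy as jnp

def sum_of_squares(x):
    return jnp.sum(x ** 2)

# Transformations compose: the gradient function returned by grad is itself jit-compiled
fast_grad = jit(grad(sum_of_squares))

x = jnp.arange(4.0)
print(fast_grad(x))  # gradient of sum(x^2) is 2*x: [0. 2. 4. 6.]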
See main article: Automatic differentiation. The below code demonstrates the grad function's automatic differentiation.
from jax import grad
import jax.numpy as jnp

def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# grad returns a new function that computes the gradient of logistic
grad_logistic = grad(logistic)

# Evaluate the gradient of the logistic function at x = 1.0
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
The final line should output the gradient of the logistic function at 1.0, approximately 0.19661194.
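As a check (not part of the original example), the result can be compared against the closed-form derivative of the logistic function, σ'(x) = σ(x)(1 − σ(x)):

import jax.numpy as jnp
from jax import grad

def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

def logistic_prime(x):
    # Closed-form derivative: sigma'(x) = sigma(x) * (1 - sigma(x))
    s = logistic(x)
    return s * (1 - s)

print(grad(logistic)(1.0))   # automatic differentiation
print(logistic_prime(1.0))   # analytic derivative; both values should agree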
The below code demonstrates the jit function's optimization through fusion.
from jax import jit
import jax.numpy as jnp

def cube(x):
    return x * x * x

x = jnp.ones((10000, 10000))

jit_cube = jit(cube)

cube(x)
jit_cube(x)
The computation time for jit_cube(x) should be noticeably shorter than that for cube(x), since the jit-compiled version fuses the multiplications into a single kernel. Increasing the dimensions of the array x will further widen the difference.
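One way to measure this difference (a sketch, not part of the original example) is to time both versions; because JAX dispatches work asynchronously, block_until_ready is used so timing waits for the computation to finish:

import timeit
from jax import jit
import jax.numpy as jnp

def cube(x):
    return x * x * x

jit_cube = jit(cube)
x = jnp.ones((10000, 10000))

# Run once so compilation time is not included in the measurement
jit_cube(x).block_until_ready()

# block_until_ready() forces the asynchronous computation to finish before timing stops
print(timeit.timeit(lambda: cube(x).block_until_ready(), number=10))
print(timeit.timeit(lambda: jit_cube(x).block_until_ready(), number=10))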
See main article: Array programming. The below code demonstrates the vmap function's vectorization.
from functools import partial
from jax import vmap
import jax.numpy as jnp
import numpy as np

def grads(self, inputs):
    # functools.partial is used here; older JAX versions re-exported it as jax.partial
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads
The GIF on the right of this section illustrates the notion of vectorized addition.
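Because that snippet depends on methods of a surrounding class, a self-contained sketch (with made-up inputs, not part of the original example) may be clearer; it uses vmap to vectorize a scalar addition over a batch dimension, in the spirit of the vectorized addition mentioned above:

from jax import vmap
import jax.numpy as jnp

def add(a, b):
    return a + b

# vmap maps add over the leading axis of both arguments
batched_add = vmap(add)

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([10.0, 20.0, 30.0])
print(batched_add(xs, ys))  # [11. 22. 33.]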
The below code demonstrates the pmap function's parallelization for matrix multiplication.
from jax import pmap, random
import jax.numpy as jnp

# Generate two random keys and, on each device, a random 5000 x 6000 matrix
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# Multiply each matrix by its transpose locally on each device, in parallel
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# Take the mean of each product matrix, again in parallel
means = pmap(jnp.mean)(outputs)
print(means)
The final line should print the mean of each of the two product matrices.
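Note that pmap maps over the leading axis with one element per device, so the example above assumes at least two accelerator devices are visible to JAX. A sketch (not part of the original example) that adapts the same computation to however many devices are available, using jax.device_count:

import jax
from jax import pmap, random
import jax.numpy as jnp

n_devices = jax.device_count()

# One random key, and hence one 5000 x 6000 matrix, per available device
keys = random.split(random.PRNGKey(0), n_devices)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)
print(pmap(jnp.mean)(outputs))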