The JAX documentation explains vmap with an example that sets both
parameters, in_axes and out_axes, to the same value. Users would
understand vmap better with more examples that use different values
of in_axes and out_axes.

in_axes
The name vmap is derived from the well-known
map function. The invocation map(f, l) applies
the function f to each element in the list l.
In contrast, the invocation vmap(f)(v) requires elements to
be packed into a tensor v.
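
To make the contrast concrete, here is a minimal sketch that runs the same reduction through Python's built-in map and through jax.vmap. The list l and its values are made up for illustration. The vectors are packed with jnp.stack, which places them as rows, and vmap is called with its default in_axes=0.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x.sum()  # reduce a vector x to a scalar

# A plain Python list of three vectors (illustrative data).
l = [jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]), jnp.array([5.0, 6.0])]

# map calls f once per list element.
mapped = [float(m) for m in map(f, l)]

# vmap needs the elements packed into one tensor; here they become rows.
v = jnp.stack(l)           # shape (3, 2)
vmapped = jax.vmap(f)(v)   # one result per row, shape (3,)

print(mapped)
print(vmapped)
```

Both calls compute the same three sums. The difference is that vmap takes a single tensor and returns a single tensor instead of iterating over a list.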
Suppose that f processes a vector. We could pack the
list of vectors as rows or columns of v. If they are rows,
we should call vmap(f, in_axes=0)(v), where
in_axes=0 indicates that the first dimension of
v indexes the list. Alternatively, we call
vmap(f, in_axes=1)(v) to specify that the columns of
v are the elements, as illustrated in the following
example:
import jax
import jax.numpy as jnp
def f(x):
    return x.sum() # reduce a vector x
v = jnp.ones((3, 2))
print(f"v=\n{v}\n")
print(jax.vmap(f, in_axes=0)(v)) # batch=3
print(jax.vmap(f, in_axes=1)(v)) # batch=2

out_axes
JAX is pure-functional, which means its tensors are immutable.
Therefore, when you call jax.vmap(f)(v), it returns a new
tensor. The out_axes parameter instructs the function
jax.vmap(f) on how to construct the resulting tensor. If
out_axes=0, jax.vmap(f) packs the outputs from
each call to f as rows in the result tensor. Alternatively,
if out_axes=1, it packs them as columns, as illustrated in
the following tensor transpose example:
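def g(x):
    return x # identity: each output equals its input row
print(jax.vmap(g, in_axes=0, out_axes=0)(v)) # shape (3, 2), same as v
print(jax.vmap(g, in_axes=0, out_axes=1)(v)) # shape (2, 3), the transpose of v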
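
To see how the two parameters interact, the sketch below tries all four combinations of in_axes and out_axes on a small matrix. The function h and the input values are hypothetical, chosen only so that rows and columns are easy to tell apart. h maps a vector to a vector of the same length, so each call contributes one row or one column to the result.

```python
import jax
import jax.numpy as jnp

def h(x):
    return 2.0 * x  # hypothetical vector-to-vector function

v = jnp.arange(6.0).reshape(3, 2)  # rows: [0, 1], [2, 3], [4, 5]

results = {}
for in_ax in (0, 1):
    for out_ax in (0, 1):
        # in_ax picks the axis of v that indexes the batch;
        # out_ax picks the axis of the result that indexes the batch.
        out = jax.vmap(h, in_axes=in_ax, out_axes=out_ax)(v)
        results[(in_ax, out_ax)] = out
        print(f"in_axes={in_ax}, out_axes={out_ax}: shape={out.shape}")
```

When in_axes and out_axes match, the result has the same layout as v. When they differ, the result is transposed, because the batch is read from one axis and written to the other.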