For distributed computing, JAX organizes GPUs or TPUs into a mesh. It then shards a tensor across devices by overlaying its elements onto this mesh. The following sample program illustrates the concept of a device mesh and tensor sharding.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
assert len(jax.devices()) == 8  # this example assumes a host with 8 devices
dv = mesh_utils.create_device_mesh((4, 2))  # arrange the 8 devices in a 4x2 grid
print(type(dv))
print(dv)
When you run the program on a host with 8 GPUs, the output is:
<class 'numpy.ndarray'>
[[CudaDevice(id=0) CudaDevice(id=1)]
 [CudaDevice(id=2) CudaDevice(id=3)]
 [CudaDevice(id=4) CudaDevice(id=5)]
 [CudaDevice(id=6) CudaDevice(id=7)]]
This output shows that a device mesh is simply a NumPy array whose elements are device objects (here of type CudaDevice).
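Because the mesh is a plain NumPy array of device objects, it supports ordinary array indexing. A quick sketch (assuming the same 8-GPU host as above):
print(dv[0, 1])  # CudaDevice(id=1)
print(dv[3])     # the last row: [CudaDevice(id=6) CudaDevice(id=7)]
print(dv.shape)  # (4, 2)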
If we change the shape of the mesh from 4×2 to 2×4, as shown below:
dv = mesh_utils.create_device_mesh((2, 4))  # the same devices, now in a 2x4 grid
print(type(dv))
print(dv)
The output becomes:
<class 'numpy.ndarray'>
[[CudaDevice(id=0) CudaDevice(id=1) CudaDevice(id=2) CudaDevice(id=3)]
 [CudaDevice(id=4) CudaDevice(id=5) CudaDevice(id=6) CudaDevice(id=7)]]
This demonstrates that the mesh is created by placing the devices into the array in row-major order.
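One quick way to confirm this (a sketch; on this 8-GPU host create_device_mesh keeps the default device order, as the printed ids show) is to flatten the mesh and compare it with jax.devices():
# Flattening the 2x4 mesh row by row should recover jax.devices() in order.
print(list(dv.flatten()) == jax.devices())  # True on this host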
Before using the array, we need to convert it into a jax.sharding.Mesh. This conversion lets us assign a name to each axis of the array. In the example below, axis 0 is named “model” and axis 1 is named “data”.
mesh = Mesh(dv, axis_names=("model", "data"))  # name axis 0 "model", axis 1 "data"
print(mesh)  # Mesh('model': 2, 'data': 4)
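The Mesh object exposes its axis names and sizes. A small sketch of attributes from the jax.sharding.Mesh API (the printed form of mesh.shape may differ slightly across JAX versions):
print(mesh.axis_names)  # ('model', 'data')
print(mesh.shape)       # OrderedDict([('model', 2), ('data', 4)])
print(mesh.size)        # 8 devices in total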
With the mesh defined, we can now shard a tensor by mapping it onto the mesh. The following example overlays an 8×8 tensor onto the 2×4 mesh.
x = jnp.ones((8, 8))
x_sharded = jax.device_put(
    x,
    device=NamedSharding(
        mesh,
        PartitionSpec("model", "data"),  # axis 0 -> "model", axis 1 -> "data"
    ),
)
jax.debug.visualize_array_sharding(x_sharded)
Calling jax.debug.visualize_array_sharding displays how the tensor’s elements are distributed across the devices:
GPU 0 GPU 1 GPU 2 GPU 3
GPU 4 GPU 5 GPU 6 GPU 7
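Beyond the visualization, we can inspect the shards programmatically through the array’s addressable_shards attribute. Since the 8×8 tensor is split into 2 row blocks and 4 column blocks, each device should hold a 4×2 block. A minimal sketch:
for shard in x_sharded.addressable_shards:
    # Each device holds a 4x2 block of the 8x8 tensor; shard.index
    # gives the slices of x that this block covers.
    print(shard.device, shard.data.shape, shard.index)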
The names specified in the jax.sharding.PartitionSpec determine how the tensor’s axes are mapped to the mesh’s named axes. In this example, the first axis (axis 0) of tensor x is mapped to the mesh’s “model” axis, and the second (axis 1) is mapped to the “data” axis. Consequently, axis 0 of x corresponds to the two rows of the mesh, and axis 1 corresponds to the four columns.
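A tensor axis can also be left unsharded by putting None in its slot of the PartitionSpec; that axis is then replicated across the unused mesh axis. A sketch (the variable name x_replicated is our own):
x_replicated = jax.device_put(
    x,
    device=NamedSharding(
        mesh,
        # Axis 0 is split across "model"; axis 1 is not sharded, so each
        # 4x8 block is replicated across the 4 devices of the "data" axis.
        PartitionSpec("model", None),
    ),
)
jax.debug.visualize_array_sharding(x_replicated)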
We can also try switching the mapping:
x = jnp.ones((8, 8))
x_sharded = jax.device_put(
    x,
    device=NamedSharding(
        mesh,
        PartitionSpec("data", "model"),  # swapped: axis 0 -> "data", axis 1 -> "model"
    ),
)
jax.debug.visualize_array_sharding(x_sharded)
Now the sharding output is as follows, because axis 0 of x is mapped to the mesh’s “data” axis (size 4) and axis 1 to its “model” axis (size 2):
GPU 0 GPU 4
GPU 1 GPU 5
GPU 2 GPU 6
GPU 3 GPU 7
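With the swapped spec, axis 0 is split four ways and axis 1 two ways, so each device now holds a 2×4 block. As a quick check, and to see that a computation on the sharded array keeps its layout (in JAX, output shardings follow the input shardings), a sketch:
print(x_sharded.addressable_shards[0].data.shape)  # (2, 4)

# Elementwise ops run on the local shards; the result keeps the sharding.
y = jnp.sin(x_sharded)
jax.debug.visualize_array_sharding(y)  # same layout as x_sharded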