JAX provides JIT compilation through the jax.jit
API, as
shown in the following example:
import jax
import jax.numpy as jnp

fsin = jax.jit(jnp.sin)
x = jnp.linspace(0, 1.0, 5)
print(jnp.array_equal(jnp.sin(x), fsin(x)))
This example may lead users to believe that
jax.jit(jnp.sin)
compiles and returns a “faster” version of
jnp.sin
. However, in reality, the first call to
fsin
triggers the actual compilation.
This misconception can lead to further confusion about the code’s
behavior. For example, users might assume that jax.jit(jnp.sin) is
time-consuming due to compilation. However, it is the call to
fsin(x)
that initiates the compilation and thus takes
significant time.
More importantly, this misconception may prevent users from
understanding JAX’s requirement for fixed-shape input arrays during
compilation. The call jax.jit(jnp.sin)
alone does not
involve any input arrays, which is why the actual compilation happens
only when fsin
is called with an input array.
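We can verify this shape dependence with a small experiment: repeated
calls with the same input shape reuse the cached executable, while a new
shape triggers a fresh compilation. (The helper timed below is my own;
absolute timings will vary by machine.)
import time
import jax
import jax.numpy as jnp

fsin = jax.jit(jnp.sin)

def timed(fn, *args):
    start = time.perf_counter()
    fn(*args).block_until_ready()
    return time.perf_counter() - start

x = jnp.linspace(0., 1., 1000)
print(timed(fsin, x))  # slow: compiles for shape (1000,)
print(timed(fsin, x))  # fast: reuses the cached executable
y = jnp.linspace(0., 1., 2000)
print(timed(fsin, y))  # slow again: a new shape requires a new compilation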
The following example demonstrates that the initial call to a
function decorated with jax.jit
triggers the time-consuming
compilation process, while subsequent calls execute much faster due to
caching.
import jax
import jax.numpy as jnp

x = jnp.linspace(0., 1., 1000)
%time jax.jit(jnp.sin)(x).block_until_ready()  # 33.6 ms
%time jax.jit(jnp.sin)(x).block_until_ready()  # 852 µs
%time jax.jit(jnp.sin)(x).block_until_ready()  # 910 µs
%time jax.jit(jnp.sin)(x).block_until_ready()  # 891 µs
In the example above, calls to block_until_ready
ensure that the results are fully computed. According to the JAX
documentation on Asynchronous
Dispatch, jax.Array
is a future – a placeholder for a
value that will be computed on an accelerator device but may not be
available immediately. Calling block_until_ready
forces the
program to wait for the execution of jax.jit(jnp.sin)
to
complete and return the result.
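We can see asynchronous dispatch in action with another small
experiment: enqueuing a large matrix multiplication returns almost
immediately, while block_until_ready waits for the device to finish.
(Timings will vary by machine.)
import time
import jax.numpy as jnp

x = jnp.ones((4000, 4000))

start = time.perf_counter()
y = jnp.dot(x, x)  # dispatch: returns a future almost immediately
print("dispatch:", time.perf_counter() - start)

start = time.perf_counter()
y.block_until_ready()  # wait for the device to finish computing y
print("blocked:", time.perf_counter() - start)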
Asynchronous dispatch is useful because it enables Python programs to
enqueue substantial amounts of work for the accelerator. MLX adopts a
similar design. To ensure an array x
is ready in MLX, you
can call mx.eval(x)
.
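For comparison, the MLX counterpart might look like the following.
(This assumes MLX is installed; MLX records operations lazily, and
mx.eval forces their evaluation.)
import mlx.core as mx

x = mx.ones((1000, 1000))
y = x @ x    # recorded lazily, not yet computed
mx.eval(y)   # forces the computation of y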
jax.make_jaxpr
The JAX
documentation mentions that JIT uses jax.make_jaxpr
to
“trace” Python code and produce an intermediate representation called
JAXPR. However, it does not reveal details about
make_jaxpr
. So I crafted the following example, which allows a peek under the
hood.
import jax
import jax.numpy as jnp
def f(x, y) -> jax.Array:
    print("type(x):", type(x))
    print("x:", x)
    print("type(y):", type(y))
    print("y:", y)
    return jnp.dot(x + 1, y + 1)
The normal way to call this function is to pass in two arrays:
x = jnp.array([1.0, 2.0])
y = jnp.array([1.0, 2.0])
print(f(x, y))
The function f
prints the type and value of
x
and y
, as well as the final result
13.0
.
type(x): <class 'jaxlib.xla_extension.ArrayImpl'>
x: [1. 2.]
type(y): <class 'jaxlib.xla_extension.ArrayImpl'>
y: [1. 2.]
13.0
Now let us check what jax.make_jaxpr
returns:
ff = jax.make_jaxpr(f)
print("type(ff):", type(ff))
print("ff:", ff)
It returns a function:
type(ff): <class 'function'>
ff: <function make_jaxpr(f) at 0x10a06af80>
Let us try calling the function:
z = ff(x, y)
print("type(z):", type(z))
print("z:", z)
This prints the type and value of x
and y
as well as the returned value:
type(x): <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
x: Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>
type(y): <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
y: Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>
type(z): <class 'jax._src.core.ClosedJaxpr'>
z: { lambda ; a:f32[2] b:f32[2]. let
    c:f32[2] = add a 1.0
    d:f32[2] = add b 1.0
    e:f32[] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }
We passed in two arrays to ff
. However, the calls to
print
by f
show that x
and
y
are of type DynamicJaxprTracer
, not arrays.
Obviously, the function ff
, created by
jax.make_jaxpr
, calls f
, which is why the
print
calls in f
work. But, before calling
f
, ff
converts the input arrays into
DynamicJaxprTracer
instances.
The DynamicJaxprTracer
contains only the
ShapedArray
with float32
dtype and shape
[2]
; the actual data is missing. That is the purpose of
tracing: capturing the dtype and shape of arrays but not the
content.
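This shape-and-dtype-only view is what jax.eval_shape exposes: it runs
the same tracing machinery without any real data. The sketch below
reuses the f defined above.
spec = jax.ShapeDtypeStruct((2,), jnp.float32)
out = jax.eval_shape(f, spec, spec)  # traces f; no array data involved
print(out)  # ShapeDtypeStruct(shape=(), dtype=float32)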
As expected, the return value is not an array but a representation of
the operations within f
. For this example, it is a short
program that calls the XLA
operation dot_general
.
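Notably, this short program is itself runnable: jax.core.eval_jaxpr
interprets a jaxpr against concrete inputs. (eval_jaxpr is an internal
API and may move between JAX versions; the snippet reuses z, x, and y
from above.)
from jax import core

result = core.eval_jaxpr(z.jaxpr, z.consts, x, y)  # interpret the jaxpr
print(result)  # [Array(13., dtype=float32)]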
From this, we can infer how jax.make_jaxpr
works. It
returns a function that calls f
with arguments converted to
DynamicJaxprTracer
instances, capturing dtype and shape
while allowing functions like jnp.dot to treat them like arrays. Thanks
to Python’s support for duck typing, JAX functions can operate on
tracers as if they were operating on arrays.
def make_jaxpr(f):
    def ff(*args, **kwargs):
        # convert_arrays_to_tracer is a hypothetical helper that wraps
        # each input array in a tracer carrying only dtype and shape.
        args, kwargs = convert_arrays_to_tracer(args, kwargs)
        return f(*args, **kwargs)
    return ff
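One way to see this duck typing in action: in recent JAX versions,
tracers even pass isinstance checks against jax.Array, so code written
for arrays accepts them unchanged. (A small experiment of mine, not
from the JAX documentation.)
import jax
import jax.numpy as jnp

def g(x):
    print(isinstance(x, jax.Array))  # True: the tracer quacks like an array
    return x * 2

jax.make_jaxpr(g)(jnp.ones(3))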
Now, let us hypothesize how jax.jit
works. According to
the initial example, jax.jit
takes a function
f
as input. If f
has already been compiled,
jax.jit
should return the cached result. If not,
jax.jit
should return a function ff
that runs
the identical operations as f
. When called with arguments
like x
and y
, ff
would: (1) trace f given x and y to produce a jaxpr, (2) compile the
jaxpr using XLA and cache the result, and (3) run the compiled program
with x and y.
The source code of jax.jit
may look something like the
following:
def jit(f):
    if ff := cache.of(f):  # hypothetical cache lookup
        return ff
    def trigger(*args, **kwargs):
        ff = jax.make_jaxpr(f)
        trace = ff(*args, **kwargs)          # trace f into a jaxpr
        compiled = compile_using_xla(trace)  # hypothetical XLA compilation
        cache.add(f, compiled)
        return compiled(*args, **kwargs)
    return trigger
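Interestingly, JAX’s ahead-of-time (AOT) API exposes these stages
explicitly, which lends some support to this mental model: jit-wrapped
functions provide lower() and compile() methods. The snippet below
reuses the f, x, and y from earlier.
lowered = jax.jit(f).lower(x, y)  # tracing and lowering to an IR
print(lowered.as_text()[:200])    # textual form of the lowered program
compiled = lowered.compile()      # XLA compilation
print(compiled(x, y))             # 13.0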
I haven’t read the JAX codebase to verify if my hypothesis is correct. But I plan to. :-)