In our recent work on a novel speculative decoding approach tailored for Apple Intelligence, we prototyped the idea using PyTorch. Programming the vectorized algorithm that samples token sequences from a draft model posed a slight challenge. We had to consider that for each prompt, we would sample a beam of multiple token sequences, and the input would be a batch of multiple prompts.
This inspired me to try writing a non-vectorized, easy-to-read version of the code in JAX and to let JAX’s `vmap` add the batch and beam dimensions automatically. The attempt succeeded, as the following code illustrates:
from typing import Tuple
import jax
import jax.numpy as jnp
d = 8  # hidden dim of llm (length of llm_state)
h = 4  # hidden dim of rnn (length of rnn_state)
l = 3  # number of tokens drawn from rnn after the one from llm
w = 5  # beam width
b = 2  # batch size
def rnn_drafter(
    llm_state: jax.Array, rnn_state: jax.Array, input_token: jax.Array
) -> Tuple[jax.Array, jax.Array]:
    """Toy RNN drafter step: [d], [h], [] -> [h], [].

    Adds ``input_token`` to every element of ``rnn_state`` and emits the next
    token as ``input_token + sum(new rnn_state)``.

    Args:
        llm_state: LLM hidden state; not used by this toy implementation.
        rnn_state: current RNN hidden state.
        input_token: the previous token.

    Returns:
        ``(new_rnn_state, next_token)``.
    """
    # Debug print. When called under jax.vmap this prints a tracer rather
    # than concrete values.
    print(f"\nrnn_state=\n{rnn_state}")
    rnn_state = rnn_state + input_token
    return rnn_state, input_token + rnn_state.sum()
def sample_seq(llm_state: jax.Array, llm_token: jax.Array) -> jax.Array:
    """Draft one token sequence for a single prompt and beam slot: [d], [] -> [l + 1].

    Written without batch or beam dimensions on purpose; those are added by
    wrapping this function in ``jax.vmap``.

    Args:
        llm_state: LLM hidden state, forwarded unchanged to ``rnn_drafter``
            at every step.
        llm_token: the token produced by the LLM.

    Returns:
        An ``int32`` array holding ``llm_token`` followed by the ``l`` tokens
        produced by ``rnn_drafter``.
    """
    rnn_state = jnp.zeros((h,))
    token = llm_token
    seq = jnp.array([token], dtype=jnp.int32)
    for _ in range(l):
        rnn_state, token = rnn_drafter(llm_state, rnn_state, token)
        # rnn_state starts as float zeros, so the drafted token is float.
        # Cast before appending; otherwise jnp.append promotes the whole
        # sequence to float and the int32 dtype chosen above is lost.
        seq = jnp.append(seq, token.astype(seq.dtype))
    return seq
def test_sample_seq():
    """sample_seq returns the LLM token followed by l drafted tokens.

    With the toy rnn_drafter, rnn_state stays a constant vector and each
    step does s <- s + t, t <- t + h * s. Starting from s = 0, t = x with
    h == 4 and l == 3, this gives t = x * [1, 5, 29, 169].
    """
    llm_state, llm_token = jnp.zeros((d,)), jnp.array(100, dtype=jnp.int32)
    seq = sample_seq(llm_state, llm_token)
    assert seq.shape == (l + 1,)
    assert jnp.all(seq == 100 * jnp.array([1, 5, 29, 169]))
def test_sample_beam():
    """vmap over llm_token adds the beam dimension: output is [w, l + 1].

    Each drafted sequence is x * [1, 5, 29, 169] for first token x
    (toy rnn_drafter with h == 4, l == 3), so the expected value is derived
    from llm_token and stays correct for any beam width w.
    """
    llm_state, llm_token = jnp.zeros((d,)), jnp.arange(0, w * 100, 100)
    beam = jax.vmap(sample_seq, in_axes=(None, 0), out_axes=0)(llm_state, llm_token)
    expected = llm_token[:, None] * jnp.array([1, 5, 29, 169])
    assert beam.shape == (w, l + 1)
    assert jnp.all(beam == expected)
def test_sample_beam_batched():
    """Nesting two vmaps adds both dimensions: output is [b, w, l + 1].

    The inner vmap maps over the beam (one llm_state shared by all beam
    slots). The outer vmap maps over the batch (one llm_state and one row
    of tokens per prompt). Each sequence is x * [1, 5, 29, 169] for first
    token x (toy rnn_drafter with h == 4, l == 3).
    """
    llm_state = jnp.zeros((b, d))
    llm_token = jnp.tile(jnp.arange(0, w * 100, 100), (b, 1))
    sample_beam = jax.vmap(sample_seq, in_axes=(None, 0), out_axes=0)
    sample_batch = jax.vmap(sample_beam, in_axes=(0, 0), out_axes=0)
    out = sample_batch(llm_state, llm_token)
    expected = llm_token[..., None] * jnp.array([1, 5, 29, 169])
    assert out.shape == (b, w, l + 1)
    assert jnp.all(out == expected)
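A further improvement is to replace the Python loop in `sample_seq` with `jax.lax.fori_loop`, which can be JIT-compiled. Because `fori_loop` requires a loop carry of fixed shape, `seq` would then be preallocated (for example with `jnp.zeros((l + 1,), dtype=jnp.int32)`) and filled with `seq.at[i].set(token)` instead of being grown with `jnp.append`.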