JAX’s documentation uses type annotation conventions from type theory. Consider this example:
lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> CThe notation body: (int -> C -> C) indicates that
body is a function taking two parameters: one of type
int and another of type C, and returns a value
of type C. You might wonder why the first arrow
-> separates parameters while the second appears to
annotate the return type. This relates to the concept of currying.
To understand currying, let’s look at a simple example:
def regular_add(x: int, y: float) -> str:
return f"{x + y}"
def curried_add(x):
def inner(y):
return f"{x + y}"
return inner
assert regular_add(1, 2) == curried_add(1)(2)Here, regular_add takes two parameters (int
and float) and returns a str. In contrast,
curried_add takes one parameter (int) and
returns a function that takes a float and returns a
str. Though structured differently, these functions are
functionally equivalent. Therefore, the type annotation
int -> (float -> str) could appropriately describe
both functions.