<aside> 💡 This is a brief introduction to some key ideas in JAX and Equinox (mostly on the latter). As a use case, I am implementing an activation function from one of my papers (kernel activation function, KAF [1]), which is a non-parametric AF that allows each unit in a network to learn its own non-linearity. The specific example is not crucial for the purposes of the tutorial.
</aside>
<aside> ☝ The code is also available as a Colab notebook: https://colab.research.google.com/drive/1Al9BdqOVMP3nRJm7nj2FtW4lRBjxe3bq?usp=sharing
</aside>
I wrote my first JAX tutorial in 2018 when the library came out — it extended another library I liked (Autograd) with a more efficient back-end, and it seemed natural to play with it a bit. Today, everyone loves JAX for its speed, but quite a few people struggle with its functional approach, especially when moving to high-level frameworks such as Haiku or Equinox. To ease the process, this is a short guided tour of some of the key concepts in JAX and Equinox, following a realistic implementation of an activation function.
Putting aside efficiency, the simplest way to understand JAX is to consider it as an extension of NumPy with additional functional transformations, i.e., higher-order functions that consume functions and return other functions. The prototypical example is grad, which returns another function that computes the gradient of its input function with respect to the first argument (I am using jaxtyping to annotate input parameters):
import jax.numpy as jnp
from jax import grad
from jaxtyping import Array, Float

def f(x: Float[Array, "n"]):
    return (x**2).sum()

print(grad(f)(jnp.ones(3)))
# [Out]: Array([2., 2., 2.], dtype=float32)
Note the two sets of brackets: grad(f) consumes a function and returns another function, which is then called on some input parameters. The two other interesting transformations provided by JAX are jit (to trace and compile a function to an optimized static version) and vmap (vectorized map, to vectorize a function over a new axis). Let us see some examples of their application.
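As a quick taste before the KAF example, here is a minimal sketch of both transformations applied to the toy function f above (the shapes and values are arbitrary):

from jax import jit, vmap

# jit traces f on the first call and compiles it to an optimized XLA program
f_fast = jit(f)
print(f_fast(jnp.ones(3)))          # same result as f, but compiled

# vmap vectorizes f over a new leading batch axis: one output per row
f_batched = vmap(f)
print(f_batched(jnp.ones((4, 3))))  # shape (4,), each entry is f of one row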
A KAF is a trainable activation function. For the purposes of this tutorial we only need a very brief introduction — if you have been following the KAN paper [2], the core idea is similar. Each KAF is a scalar function described as a weighted combination of $k$ basis functions:
$$ \phi(x)=\sum_{i=1}^k\alpha_i\kappa(x,c_i) \tag{1} $$
where $x$ is a scalar, $k$ is a hyper-parameter, the coefficients $\alpha_i$ are trainable, $\kappa$ is a generic kernel function, and for simplicity the centers $c_i$ are chosen equispaced over the $x$-axis. You can think of (1) equivalently as a two-layer MLP with one input and one output. In particular, suppose we choose $\kappa$ as a Gaussian kernel centered at $c$, with a hyper-parameter $\sigma^2$ controlling its width:
$$ \kappa(x,c)=\exp\left(\frac{-(x-c)^2}{2\sigma^2}\right) $$
Then, (1) is a weighted sum of $k$ univariate Gaussians placed at equispaced points on the $x$-axis (i.e., a radial basis function network). Intuitively, by varying the coefficients $\alpha_i$ we can reasonably approximate any (univariate) function, including standard activation functions such as the ReLU, as shown next.
Let us start with a simple implementation of the KAF layer using only JAX. First, we write down the Gaussian kernel:
def gauss_kernel(x: Float, center: Float, sigma: Float):
    return jnp.exp(-(x - center)**2 / (2 * sigma**2))
Note that everything is scalar in this definition, but thanks to broadcasting it also works if the center parameter is a vector, which we exploit below. We now write another function to implement (1):
def kaf(x: Float[Array, ""],
        alpha: Float[Array, "k"],
        centers: Float[Array, "k"],
        sigma: Float):
    # Eq. (1): weighted combination of k Gaussian kernels evaluated at x,
    # broadcasting the scalar x against the vector of centers
    K = gauss_kernel(x, centers, sigma)
    return (K * alpha).sum()
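Since kaf operates on a single scalar pre-activation, a layer-level version can be obtained with vmap; the shapes and values below are purely illustrative (each unit gets its own coefficients, while the centers and sigma are shared):

from jax import vmap

# illustrative sizes: a layer of 16 units, k = 8 basis functions per unit
k, units = 8, 16
centers = jnp.linspace(-2.0, 2.0, k)   # shared, equispaced centers
alpha = jnp.ones((units, k)) / k       # one set of (trainable) coefficients per unit
x = jnp.ones(units)                    # pre-activations of the layer

# map over the units: each one gets its own x and its own alpha
kaf_layer = vmap(kaf, in_axes=(0, 0, None, None))
print(kaf_layer(x, alpha, centers, 1.0).shape)  # (16,)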