<aside> 💡 This is a brief introduction to some key ideas in JAX and Equinox (mostly on the latter). As a use case, I am implementing an activation function from one of my papers (kernel activation function, KAF [1]), which is a non-parametric AF that allows each unit in a network to learn its own non-linearity. The specific example is not crucial for the purposes of the tutorial.

</aside>

<aside> ☝ The code is also available as a Colab notebook: https://colab.research.google.com/drive/1Al9BdqOVMP3nRJm7nj2FtW4lRBjxe3bq?usp=sharing

</aside>

I wrote my first JAX tutorial in 2018, when the library came out: it extended another library I liked (Autograd) with a more efficient back-end implementation, and it seemed natural to play with it a bit. Today, everyone loves JAX for its speed, but quite a few people struggle with its functional approach, especially when working with high-level frameworks such as Haiku or Equinox. To ease the process, this is a short guided tour of some of the key concepts in JAX and Equinox, following a realistic implementation of an activation function.

What is JAX?

Putting aside efficiency, the simplest way to understand JAX is to consider it as an extension of NumPy with additional functional transformations, i.e., higher-order functions that consume functions and return other functions. The prototypical example is grad, which returns another function that computes the gradient of its input function with respect to the first argument (I am using jaxtyping to annotate input parameters):

import jax.numpy as jnp
from jax import grad
from jaxtyping import Array, Float

def f(x: Float[Array, "n"]):
  return (x**2).sum()

print(grad(f)(jnp.ones(3)))
# [Out]: Array([2., 2., 2.], dtype=float32)

Note the brackets: grad(f) consumes a function and returns another function, which is then called on some input parameters. Two other key transformations provided by JAX are jit, which traces a function and compiles it into an optimized static version, and vmap (vectorized map), which vectorizes a function over a new axis. Let us see some examples of their application.
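As a quick, minimal sketch of these two transformations applied to the same f (the names f_fast and f_batched and the input shapes are my own choices for illustration):

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def f(x):
  # Sum of squares of a vector: returns a scalar.
  return (x**2).sum()

# jit: trace f for a given input shape/dtype and compile it with XLA.
f_fast = jit(f)

# vmap: map f over a new leading (batch) axis without writing a loop.
f_batched = vmap(f)

x = jnp.ones(3)
xs = jnp.ones((4, 3))  # a batch of 4 vectors

value = f_fast(x)                 # scalar, same result as f(x)
batch_values = f_batched(xs)      # shape (4,): one value per row
batch_grads = vmap(grad(f))(xs)   # shape (4, 3): transformations compose
```

The last line is the important one: since every transformation returns a plain function, they can be nested freely (here, a per-example gradient).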

An interlude - Kernel activation functions

A KAF is a trainable activation function. For the purposes of this tutorial we only need a very brief introduction — if you have been following the KAN paper [2], the core idea is similar. Each KAF is a scalar function described as a weighted combination of $k$ basis functions:

$$ \phi(x)=\sum_{i=1}^k\alpha_i\kappa(x,c_i) \tag{1} $$

where $x$ is a scalar, $k$ is a hyper-parameter, the coefficients $\alpha_i$ are trainable, $\kappa$ is a generic kernel function, and for simplicity the centers $c_i$ are chosen equispaced over the $x$-axis. You can think of (1) equivalently as a two-layer MLP with one input and one output. In particular, suppose we choose $\kappa$ as a Gaussian kernel centered at $c$ with some hyper-parameter $\sigma^2$ controlling the width:

$$ \kappa(x,c)=\exp\left(\frac{-(x-c)^2}{2\sigma^2}\right) $$

Then, (1) is a weighted sum of $k$ univariate Gaussians placed at equispaced points on the $x$-axis (i.e., a radial basis function network). Intuitively, by varying the coefficients $\alpha_i$ we can reasonably approximate any (univariate) function, including standard activation functions such as the ReLU, as shown next.
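To make the approximation claim concrete, here is a rough sketch that fits the coefficients $\alpha_i$ to a ReLU with ordinary least squares. This is only an illustration: the values of k and sigma, the interval [-3, 3], and the use of least squares (instead of training by gradient descent inside a network) are all assumptions of mine.

```python
import jax.numpy as jnp

def gauss_kernel(x, center, sigma):
  # Gaussian kernel from the equation above; broadcasts over x and center.
  return jnp.exp(-(x - center)**2 / (2 * sigma**2))

k = 20                                  # number of basis functions (assumed)
sigma = 0.4                             # kernel width (assumed)
centers = jnp.linspace(-3.0, 3.0, k)    # equispaced centers c_i

xs = jnp.linspace(-3.0, 3.0, 200)       # points where we fit the target
target = jnp.maximum(xs, 0.0)           # ReLU

# K[j, i] = kappa(xs[j], centers[i]), shape (200, k).
K = gauss_kernel(xs[:, None], centers[None, :], sigma)

# Least-squares coefficients alpha minimizing ||K @ alpha - target||^2.
alpha, *_ = jnp.linalg.lstsq(K, target)

approx = K @ alpha                      # phi(x) from (1) at every point in xs
mse = jnp.mean((approx - target)**2)
print(mse)
```

Swapping target for any other univariate function (e.g., jnp.tanh(xs)) only changes the fitted coefficients, not the basis.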

Implementing the KAF layer (pure JAX)

Let us start with a simple implementation of the KAF layer using only JAX. First, we write down the Gaussian kernel:

def gauss_kernel(x: Float, center: Float, sigma: Float):
  return jnp.exp(- (x - center)**2 / (2 * sigma**2))

Note that everything is scalar in this definition, but thanks to broadcasting it also works if the center parameter is a vector, which we exploit below. We now write another function to implement (1):

def kaf(x: Float[Array, ""],
        alpha: Float[Array, "k"],
        centers: Float[Array, "k"],
        sigma: Float):
  K = gauss_kernel(x, centers, sigma)
  return (K * alpha).sum()
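As a small usage sketch (the values of k, alpha, centers, and sigma below are arbitrary, and kaf_batched is a name I am introducing here), the scalar function can be evaluated on a single input, vectorized over a batch of inputs with vmap, and differentiated with respect to its coefficients with grad:

```python
import jax
import jax.numpy as jnp

def gauss_kernel(x, center, sigma):
  return jnp.exp(-(x - center)**2 / (2 * sigma**2))

def kaf(x, alpha, centers, sigma):
  K = gauss_kernel(x, centers, sigma)   # shape (k,) thanks to broadcasting
  return (K * alpha).sum()

k = 5
centers = jnp.linspace(-2.0, 2.0, k)    # equispaced centers (arbitrary range)
alpha = jnp.ones(k)                     # arbitrary coefficients
sigma = 1.0

# Single scalar input -> scalar output.
y = kaf(jnp.array(0.0), alpha, centers, sigma)

# Map over x only (axis 0); alpha, centers and sigma are shared (None).
kaf_batched = jax.vmap(kaf, in_axes=(0, None, None, None))
xs = jnp.linspace(-1.0, 1.0, 8)
ys = kaf_batched(xs, alpha, centers, sigma)   # shape (8,)

# Gradient with respect to alpha (argument index 1): equals the kernel values.
dalpha = jax.grad(kaf, argnums=1)(jnp.array(0.0), alpha, centers, sigma)
```

Since (1) is linear in $\alpha$, the gradient with respect to the coefficients is just the vector of kernel evaluations $\kappa(x, c_i)$, which gives an easy sanity check.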