In this tutorial, we will explore how to develop a Neural Network (NN) with JAX. And what better model to choose than the Transformer? As JAX grows in popularity, more and more developer teams are starting to experiment with it and incorporate it into their projects. Although it lacks the maturity of TensorFlow or PyTorch, it provides some great features for building and training Deep Learning models.

For a solid understanding of JAX basics, check my previous article if you haven’t already. You can also find the full code in our GitHub repository.

One of the common problems people have when starting with JAX is the choice of a framework. The people at DeepMind and Google seem to be very busy and have already released a plethora of frameworks on top of JAX. Here is a list of the most famous ones:

Haiku: Haiku is the go-to framework for Deep Learning and it’s used by many Google and DeepMind internal teams. It provides some simple, composable abstractions for machine learning research as well as ready-to-use modules and layers.

Optax: Optax is a gradient processing and optimization library that contains out-of-the-box optimizers and related mathematical operations.

RLax: RLax is a reinforcement learning framework with many RL subcomponents and operations.

Chex: Chex is a library of utilities for testing and debugging JAX code.

Jraph: Jraph is a Graph Neural Networks library in JAX.

Flax: Flax is another neural network library with a variety of ready-to-use modules, optimizers, and utilities. It’s most likely the closest thing we have to an all-in-one JAX framework.

Objax: Objax is a third ML library that focuses on object-oriented programming and code readability. Once again, it contains the most popular modules, activation functions, losses, and optimizers, as well as a handful of **pre-trained models**.

Trax: Trax is an end-to-end library for deep learning that focuses on Transformers.

JAXline: JAXline is a supervised-learning library that is used for **distributed JAX training** and evaluation.

ACME: ACME is another research framework for reinforcement learning.

JAX-MD: JAX-MD is a niche framework that deals with molecular dynamics.

JAXChem: JAXChem is another niche library that focuses on chemical modeling.

Of course, the question is which one do I choose?

To be honest I’m not sure.

But if I were you and I wanted to learn JAX, I’d start with the most popular ones. Haiku and Flax seem to be used a lot inside Google/DeepMind and have the most active GitHub communities. For this article, I will start with the first one and see if I need another one down the road.

So are you ready to build a Transformer with JAX and Haiku? By the way, I assume that you have a solid understanding of transformers. If you don’t, please consult our articles on attention and transformers.

Let’s start with the self-attention block.

## The self-attention block

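First, we need to import JAX and Haiku, along with the other packages used in the snippets below:

```python
import functools
from typing import Any, Mapping, Optional  # type hints used in later snippets

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax  # used later for the training loop
```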

Luckily for us, Haiku has a built-in `MultiHeadAttention` block that can be extended to build a masked self-attention block. Our block accepts the query, key, and value as well as the mask, and returns the output as a JAX array. You can see that the code looks very similar to standard PyTorch or TensorFlow code. All we do is build the causal mask using `np.tril()`, which zeroes all elements of the array above the main diagonal, multiply it with our mask, and pass everything into the `hk.MultiHeadAttention` module.

```python
class SelfAttention(hk.MultiHeadAttention):
    """Self attention with a causal mask applied."""

    def __call__(
        self,
        query: jnp.ndarray,
        key: Optional[jnp.ndarray] = None,
        value: Optional[jnp.ndarray] = None,
        mask: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        # Default to self-attention: keys and values come from the query.
        key = key if key is not None else query
        value = value if value is not None else query

        # Lower-triangular mask: each position attends only to itself
        # and to earlier positions.
        seq_len = query.shape[1]
        causal_mask = np.tril(np.ones((seq_len, seq_len)))
        mask = mask * causal_mask if mask is not None else causal_mask

        return super().__call__(query, key, value, mask)
```
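To make the masking step concrete, here is a small sketch of what `np.tril` produces and how a padding mask combines with it. The sequence length of 4 and the `padding_mask` values are made up for illustration; they are not taken from the model above.

```python
import numpy as np

seq_len = 4  # hypothetical sequence length

# Lower-triangular matrix: position i may attend to positions 0..i only.
causal_mask = np.tril(np.ones((seq_len, seq_len)))

# Hypothetical padding mask for one sequence whose last token is padding.
padding_mask = np.array([1.0, 1.0, 1.0, 0.0])

# Broadcasting over the rows zeroes the column of the padded position.
combined_mask = padding_mask[None, :] * causal_mask

print(causal_mask)
print(combined_mask)
```

Every row of `combined_mask` keeps the causal structure, while the last column is zero everywhere, so no position can attend to the padded token.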

This snippet allows me to introduce the first key principle of Haiku. All modules should be a subclass of `hk.Module`. This means that they should implement `__init__` and `__call__`, alongside any other method. In a sense, it’s the same architecture as PyTorch modules, where we implement an `__init__` and a `forward`.

To make that crystal clear, let’s build a simple 2-layer Multilayer Perceptron (MLP) as an `hk.Module`, which will conveniently be used in the Transformer below.

## The linear layer

A simple 2-layer MLP will look like this. Once again, you can notice how familiar it looks.

```python
class DenseBlock(hk.Module):
    """A 2-layer MLP"""

    def __init__(self,
                 init_scale: float,
                 widening_factor: int = 4,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._init_scale = init_scale
        self._widening_factor = widening_factor

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        hiddens = x.shape[-1]
        initializer = hk.initializers.VarianceScaling(self._init_scale)
        # Expand to widening_factor * hiddens, apply GELU, then project back.
        x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
        x = jax.nn.gelu(x)
        return hk.Linear(hiddens, w_init=initializer)(x)
```

A few things to notice here:

- Haiku provides us with a set of weight initializers under `hk.initializers`, where we can find the most common approaches.
- It also has many popular layers and modules built in, such as `hk.Linear`. For the complete list, take a peek at the official documentation.
- Activation functions are not provided because JAX already has a subpackage called `jax.nn`, where we can find activation functions such as `relu` or `softmax`.

## The normalization layer

Layer normalization is another integral block of the transformer architecture, which we can also find in the common modules inside Haiku.

```python
def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
    """Apply a unique LayerNorm to x with default settings."""
    return hk.LayerNorm(axis=-1,
                        create_scale=True,
                        create_offset=True,
                        name=name)(x)
```
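To see what the normalization itself computes, here is a rough reference sketch in plain `jax.numpy`. It is not Haiku’s implementation: it leaves out the learned scale and offset that `create_scale=True` and `create_offset=True` add, and the `eps` value and the function name `layer_norm_reference` are my own assumptions.

```python
import jax.numpy as jnp


def layer_norm_reference(x, eps=1e-5):
    """Normalize the last axis to zero mean and (roughly) unit variance."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


# Two rows with very different scales.
x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [10.0, 20.0, 30.0, 40.0]])
y = layer_norm_reference(x)

print(y.mean(axis=-1))  # close to 0 for every row
print(y.std(axis=-1))   # close to 1 for every row
```

Both rows end up with almost identical normalized values, which is the point: the statistics are computed per example over the feature axis, so the scale of each input row is removed.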

## The transformer

And now for the good stuff. Below you can find a very simplistic Transformer, which makes use of our predefined modules. Inside `__init__`, we define the basic variables such as the number of layers, attention heads, and the dropout rate. Inside `__call__`, we compose a list of blocks using a `for` loop.

As you can see, each block includes:

- Two normalization layers
- A self-attention block
- Two dropout layers
- Two skip connections (`h = h + h_attn` and `h = h + h_dense`)
- A 2-layer dense block

In the end, we also add a final normalization layer.

```python
class Transformer(hk.Module):
    """A transformer stack."""

    def __init__(self,
                 num_heads: int,
                 num_layers: int,
                 dropout_rate: float,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._num_layers = num_layers
        self._num_heads = num_heads
        self._dropout_rate = dropout_rate

    def __call__(self,
                 h: jnp.ndarray,
                 mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """Connects the transformer.

        Args:
          h: Inputs, [B, T, H].
          mask: Padding mask, [B, T].
          is_training: Whether we're training or not.

        Returns:
          Array of shape [B, T, H].
        """
        init_scale = 2. / self._num_layers
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            # Add head and query axes so the padding mask broadcasts.
            mask = mask[:, None, None, :]

        for i in range(self._num_layers):
            # Attention sub-block: norm, attention, dropout, skip connection.
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = SelfAttention(
                num_heads=self._num_heads,
                key_size=64,
                w_init_scale=init_scale,
                name=f'h{i}_attn')(h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn

            # Feed-forward sub-block: norm, MLP, dropout, skip connection.
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense

        h = layer_norm(h, name='ln_f')
        return h
```

I think that by now you have realized that building a Neural Network with JAX is dead simple.

## The embeddings layer

For completeness, let’s also include the embeddings layer. It is good to know that Haiku also provides an embedding layer, which maps the tokens of our input sentence to embedding vectors. The token embeddings are then added to the positional embeddings, which produces the final input.

```python
def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int):
    # NOTE: d_model is not an argument here; this snippet assumes it is
    # defined at module level (for example as a global hyperparameter).
    tokens = data['obs']
    input_mask = jnp.greater(tokens, 0)
    seq_length = tokens.shape[1]

    # Embed the input tokens and positions.
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)
    positional_embeddings = hk.get_parameter(
        'pos_embs', [seq_length, d_model], init=embed_init)
    input_embeddings = token_embs + positional_embeddings
    return input_embeddings, input_mask
```

`hk.get_parameter(param_name, ...)` is used to create and access the trainable parameters of a module. But you may ask, why not just use object attributes as we do in PyTorch? This is where the second key principle of Haiku comes into play. **We use this API so that we can convert the code into a pure function using** `hk.transform`. This is not very simple to grasp, but I will try to make it as clear as possible.

## Why pure functions?

The power of JAX lies in its function transformations: the ability to vectorize a function with `vmap`, automatic parallelization with `pmap`, and just-in-time compilation with `jit`. The caveat here is that in order to transform a function, it needs to be pure.

A **pure function** is a function that has the following properties:

The function return values are identical for identical arguments (no variation with local static variables, non-local variables, mutable reference arguments, or input streams).

The function application has no side effects (no mutation of local static variables, non-local variables, mutable reference arguments, or input/output streams).

*Source: Scala pure functions by O'Reilly*

This practically means that a pure function will always:

- **return the same result if invoked with the same inputs**
- **take all its input data through its arguments and output all its results through its return values**
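A quick way to see why purity matters is to `jit` a function that has a side effect. The sketch below uses a made-up `impure_double` function and a global `call_log` list; neither is part of the Transformer code.

```python
import jax
import jax.numpy as jnp

call_log = []  # global state that the function mutates as a side effect


def impure_double(x):
    # Side effect: this line runs only while JAX traces the function.
    call_log.append("called")
    return x * 2.0


jitted_double = jax.jit(impure_double)

first = jitted_double(jnp.ones(3))
second = jitted_double(jnp.ones(3))  # same shape and dtype: cached code is reused

print(len(call_log))
```

The function was invoked twice, yet the log contains a single entry, because the second call reuses the compiled version and never re-executes the Python body. The numerical result is still correct, but anything that depends on the side effect silently breaks.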

Haiku provides a function transformation, called `hk.transform`, that turns functions with object-oriented, functionally “impure” modules into pure functions that can be used with JAX. To see that in practice, let’s continue with the training of our Transformer model.

## The forward pass

A typical forward pass includes:

- Take the input and compute the input embeddings
- Run them through the Transformer’s blocks
- Return the output

The aforementioned steps can be easily composed with JAX as follows:

```python
def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
                     num_layers: int, dropout_rate: float):
    """Create the model's forward pass."""

    def forward_fn(data: Mapping[str, jnp.ndarray],
                   is_training: bool = True) -> jnp.ndarray:
        """Forward pass."""
        input_embeddings, input_mask = embeddings(data, vocab_size)

        # Run the transformer over the inputs.
        transformer = Transformer(
            num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
        output_embeddings = transformer(input_embeddings, input_mask, is_training)

        # Reverse the embeddings (untied).
        return hk.Linear(vocab_size)(output_embeddings)

    return forward_fn
```

Although the code is straightforward, its structure might seem a bit odd. The actual forward pass is executed through the `forward_fn` function. However, we wrap this with the `build_forward_fn` function, which returns the `forward_fn`. What the heck?

Down the road, we will need to transform the `forward_fn` function into a pure function using `hk.transform` so that we can take advantage of automatic differentiation, parallelization, etc.

This will be accomplished by:

```python
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                              num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
```

That’s why, instead of simply defining a function, we wrap it in a builder and return the function itself, or a callable to be more precise. The builder fixes the hyperparameters, so the returned callable takes only the data and the `is_training` flag, and it can be passed into `hk.transform` to become a pure function. If this is clear, let’s continue with our loss function.

## The loss function

The loss function is our well-known cross-entropy function, with the difference that we are also taking the mask into consideration. Once again, JAX provides the `one_hot` and `log_softmax` functions.

```python
def lm_loss_fn(forward_fn,
               vocab_size: int,
               params,
               rng,
               data: Mapping[str, jnp.ndarray],
               is_training: bool = True) -> jnp.ndarray:
    """Compute the loss on data wrt params."""
    logits = forward_fn(params, rng, data, is_training)
    targets = jax.nn.one_hot(data['target'], vocab_size)
    assert logits.shape == targets.shape

    # Ignore padded positions (token id 0) when averaging the loss.
    mask = jnp.greater(data['obs'], 0)
    loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    loss = jnp.sum(loss * mask) / jnp.sum(mask)

    return loss
```
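To check the effect of the mask, we can run the same arithmetic on hand-made numbers. Everything below (the tiny vocabulary, the logits, the token ids) is invented for illustration; only the formula mirrors `lm_loss_fn`.

```python
import math

import jax
import jax.numpy as jnp

vocab_size = 4

# One sequence of three tokens; token id 0 marks padding.
obs = jnp.array([[2, 3, 0]])
target = jnp.array([[3, 1, 0]])

# Uniform logits everywhere, except a very wrong prediction at the padded position.
logits = jnp.zeros((1, 3, vocab_size)).at[0, 2, 0].set(-10.0)

targets = jax.nn.one_hot(target, vocab_size)
mask = jnp.greater(obs, 0)

per_token_loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
masked_loss = jnp.sum(per_token_loss * mask) / jnp.sum(mask)
unmasked_loss = jnp.mean(per_token_loss)

print(masked_loss, math.log(vocab_size))
print(unmasked_loss)
```

The two real tokens have uniform predictions, so the masked loss equals `log(4)`. Without the mask, the bad prediction at the padded position would inflate the average even though that position carries no information.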

If you are still with me, take a sip of coffee because things are going to get serious from now on. It’s time to build our training loop.

## The training loop

Because neither JAX nor Haiku has optimization functionality built in, we will make use of another framework, called Optax. As mentioned in the beginning, Optax is the go-to package for gradient processing.

First, here are some things you need to know about Optax:

The key abstraction of Optax is the `GradientTransformation`. A transformation is defined by two functions, `init` and `update`. `init` initializes the state, and `update` transforms the gradients with respect to the state and the current value of the parameters.

```python
state = init(params)
grads, state = update(grads, state, params=None)
```

One more thing to know before we see the code is Python’s `functools.partial` function. The `functools` module of the standard library deals with higher-order functions and operations on callable objects.

A function is called a higher-order function if it takes other functions as parameters or returns a function as an output.

`partial`, which can also be used as a decorator, returns a new function based on an original one, but with some of its arguments fixed. If, for example, `f` multiplies two values `x` and `y`, the partial will create a new function where `x` is fixed and equal to 2:

```python
from functools import partial


def f(x, y):
    return x * y


# Creates a new function that multiplies by 2 (x is fixed and equal to 2).
g = partial(f, 2)
print(g(4))  # prints 8
```

After this short detour, let’s proceed. To decongest our `main` function, we will extract the gradient update into its own class.

First of all, the `GradientUpdater` accepts the model, the loss function, and an optimizer.

- The model will be a pure `forward_fn` function transformed by `hk.transform`:

```python
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                              num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
```

- The loss function will be the result of a partial with a fixed `forward_fn` and `vocab_size`:

```python
loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
```

- The optimizer is a set of optimization transformations that will run sequentially (operations can be combined using `optax.chain`):

```python
optimizer = optax.chain(
    optax.clip_by_global_norm(grad_clip_value),
    optax.adam(learning_rate, b1=0.9, b2=0.99))
```

The `GradientUpdater` will be initialized as follows:

```python
updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)
```

and will look like this:

```python
class GradientUpdater:
    """A stateless abstraction around an init_fn/update_fn pair.

    This extracts some common boilerplate from the training loop.
    """

    def __init__(self, net_init, loss_fn,
                 optimizer: optax.GradientTransformation):
        self._net_init = net_init
        self._loss_fn = loss_fn
        self._opt = optimizer

    @functools.partial(jax.jit, static_argnums=0)
    def init(self, master_rng, data):
        """Initializes state of the updater."""
        out_rng, init_rng = jax.random.split(master_rng)
        params = self._net_init(init_rng, data)
        opt_state = self._opt.init(params)
        out = dict(
            step=np.array(0),
            rng=out_rng,
            opt_state=opt_state,
            params=params,
        )
        return out

    @functools.partial(jax.jit, static_argnums=0)
    def update(self, state: Mapping[str, Any], data: Mapping[str, jnp.ndarray]):
        """Updates the state using some data and returns metrics."""
        rng, new_rng = jax.random.split(state['rng'])
        params = state['params']
        loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)

        updates, opt_state = self._opt.update(g, state['opt_state'])
        params = optax.apply_updates(params, updates)

        new_state = {
            'step': state['step'] + 1,
            'rng': new_rng,
            'opt_state': opt_state,
            'params': params,
        }

        metrics = {
            'step': state['step'],
            'loss': loss,
        }
        return new_state, metrics
```

Inside `init`, we initialize our optimizer with `self._opt.init(params)` and we declare the state of the optimization. The state will be a dictionary with:

- The current step
- The optimizer state
- The trainable parameters
- A random generator key to pass into `jax.random.split`
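The random key is part of the state because JAX’s random functions are pure as well: the same key always produces the same numbers, so we have to split off a fresh key for every step. A small sketch with arbitrary variable names and shapes:

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(428)

# Split once: keep `rng` for future steps, consume `step_rng` now.
rng, step_rng = jax.random.split(rng)
a = jax.random.normal(step_rng, (3,))
b = jax.random.normal(step_rng, (3,))  # reusing a key reproduces the same numbers

# Splitting again yields a fresh key and therefore different numbers.
rng, next_step_rng = jax.random.split(rng)
c = jax.random.normal(next_step_rng, (3,))

print(a, b, c)
```

This is the role of the `rng` entry in the state dictionary: every update splits it, uses one half (for dropout, for example), and stores the other half for the next step.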

The `update` function will update both the state of the optimizer and the trainable parameters. In the end, it will return the new state.

```python
updates, opt_state = self._opt.update(g, state['opt_state'])
params = optax.apply_updates(params, updates)
```

Two more things to notice here:

- `jax.value_and_grad()` is a special function that returns a new function, which evaluates both the value of the original function and its gradient.
- Both `init` and `update` are decorated with `@functools.partial(jax.jit, static_argnums=0)`, which will trigger the just-in-time compiler and compile them with XLA at runtime. Note that if we hadn’t transformed `forward_fn` into a pure function, this wouldn’t be possible.
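Both points can be tried in isolation. In the sketch below, `squared_error` and the `Scaler` class are hypothetical stand-ins for the loss function and the `GradientUpdater`; the numbers are picked so that the result is easy to verify by hand.

```python
import functools

import jax
import jax.numpy as jnp


def squared_error(w, x, y):
    """Toy loss: sum of squared residuals of the model w * x."""
    return jnp.sum((w * x - y) ** 2)


x = jnp.array([1.0, 2.0])
y = jnp.array([1.0, 2.0])

# One call returns the loss and its gradient w.r.t. the first argument.
# Residuals are [1, 2], so loss = 1 + 4 = 5 and grad = 2 * (1*1 + 2*2) = 10.
loss, grad = jax.value_and_grad(squared_error)(2.0, x, y)


class Scaler:
    """Hypothetical class used only to illustrate jitting a method."""

    def __init__(self, factor):
        self._factor = factor

    # static_argnums=0 marks `self` as a compile-time constant, so JAX does
    # not try to trace the Python object itself.
    @functools.partial(jax.jit, static_argnums=0)
    def scale(self, x):
        return self._factor * x


scaled = Scaler(3.0).scale(jnp.ones(2))
print(loss, grad, scaled)
```

Here `functools.partial` simply pre-fills the `static_argnums` argument of `jax.jit`, which gives us a decorator that can be applied to a method.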

Finally, we are ready to build the entire training loop, which combines all the ideas and code mentioned so far.

```python
def main():
    # NOTE: this snippet assumes `import logging` and `import time`, a `load`
    # helper, and module-level hyperparameters (batch_size, d_model, ...).

    # Create the dataset.
    train_dataset, vocab_size = load(batch_size, sequence_length)

    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                                  num_layers, dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(
        optax.clip_by_global_norm(grad_clip_value),
        optax.adam(learning_rate, b1=0.9, b2=0.99))

    updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
```

Notice how we incorporate the `GradientUpdater`. It’s just two lines of code:

`state = updater.init(rng, data)`

`state, metrics = updater.update(state, data)`

And that’s it. I hope that by now you have a clearer understanding of JAX and its capabilities.

## Acknowledgments

The code presented is heavily inspired by the official examples of the Haiku framework. It has been modified to fit the needs of this article. For the complete list of examples, check the official repository.

## Conclusion

In this article, we saw how one can develop and train a vanilla Transformer in JAX using Haiku. Although the code isn’t necessarily hard to grasp, it still lacks the readability of PyTorch or TensorFlow. I highly recommend playing around with it, discovering the strengths and weaknesses of JAX, and seeing if it’d be a good fit for your next project. In my experience, JAX is very strong for research applications that require high performance but quite immature for real-life projects. Let us know what you think in our Discord channel.
