In this tutorial, we will explore how to develop a Neural Network (NN) with JAX. And what better model to choose than the Transformer? As JAX grows in popularity, more and more developer teams are starting to experiment with it and incorporate it into their projects. Although it lacks the maturity of TensorFlow or PyTorch, it provides some great features for building and training Deep Learning models.
For a solid understanding of JAX basics, check out my previous article if you haven’t already. You can also find the full code in our GitHub repository.
One of the common problems people have when starting with JAX is the choice of a framework. The people at Google and DeepMind seem to be very busy and have already released a plethora of frameworks on top of JAX. Here is a list of the most famous ones:
Haiku: Haiku is the go-to framework for Deep Learning and it’s used by many Google and DeepMind internal teams. It provides some simple, composable abstractions for machine learning research as well as ready-to-use modules and layers.
Optax: Optax is a gradient processing and optimization library that contains out-of-the-box optimizers and related mathematical operations.
RLax: RLax is a reinforcement learning framework with many RL subcomponents and operations.
Chex: Chex is a library of utilities for testing and debugging JAX code.
Jraph: Jraph is a Graph Neural Networks library in JAX.
Flax: Flax is another neural network library with a variety of ready-to-use modules, optimizers, and utilities. It’s most likely the closest thing we have to an all-in-one JAX framework.
Objax: Objax is a third ML library that focuses on object-oriented programming and code readability. Once again, it contains the most popular modules, activation functions, losses, and optimizers, as well as a handful of pre-trained models.
Trax: Trax is an end-to-end library for deep learning that focuses on Transformers.
JAXline: JAXline is a supervised-learning library that is used for distributed JAX training and evaluation.
ACME: ACME is another research framework for reinforcement learning.
JAX-MD: JAX-MD is a niche framework that deals with molecular dynamics.
JAXChem: JAXChem is another niche library that focuses on chemical modeling.
Of course, the question is: which one should I choose?
To be honest I’m not sure.
But if I were you and I wanted to learn JAX, I’d start with the most popular ones. Haiku and Flax seem to be used a lot inside Google/DeepMind and have the most active GitHub communities. For this article, I will start with Haiku and see whether I need another one down the road.
So are you ready to build a Transformer with JAX and Haiku? By the way, I assume that you have a solid understanding of transformers. If you don’t, please refer to our articles on attention and transformers.
Let’s start with the self-attention block.
The self-attention block
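First, we need to import JAX and Haiku, along with NumPy, Optax, and the typing helpers used throughout the code:
```python
import functools
from typing import Any, Mapping, Optional

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
```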
Luckily for us, Haiku has a built-in MultiHeadAttention block that can be extended to build a masked self-attention block. Our block accepts the query, key, and value as well as the mask, and returns the output as a JAX array. You can see that the code looks very similar to standard PyTorch or TensorFlow code. All we do is build the causal mask using np.tril(), which zeroes out all elements above the k-th diagonal (the main diagonal by default), multiply it with our mask, and pass everything into the hk.MultiHeadAttention module.
```python
class SelfAttention(hk.MultiHeadAttention):
    """Self attention with a causal mask applied."""

    def __call__(
        self,
        query: jnp.ndarray,
        key: Optional[jnp.ndarray] = None,
        value: Optional[jnp.ndarray] = None,
        mask: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        # Self-attention: fall back to the query for keys and values.
        key = key if key is not None else query
        value = value if value is not None else query

        # Lower-triangular [T, T] mask: position t may only attend to positions <= t.
        seq_len = query.shape[1]
        causal_mask = np.tril(np.ones((seq_len, seq_len)))
        # Combine with the optional padding mask by (broadcast) multiplication.
        mask = mask * causal_mask if mask is not None else causal_mask

        return super().__call__(query, key, value, mask)
```
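To make the masking step concrete, here is a small NumPy sketch of what a causal mask does to attention weights. It assumes, as is common practice (and, as far as we know, roughly what hk.MultiHeadAttention does internally), that masked logits are replaced by a very large negative number before the softmax. The sequence length and the all-zero logits are made-up values chosen to keep the output readable.

```python
import numpy as np

seq_len = 4
# np.tril keeps the lower triangle (k=0: main diagonal included) and zeroes the rest.
causal_mask = np.tril(np.ones((seq_len, seq_len)))

# Hypothetical attention logits for one head; all equal so the effect is easy to read.
logits = np.zeros((seq_len, seq_len))

# Masked positions get a huge negative logit so that softmax sends them to ~0.
masked_logits = np.where(causal_mask > 0, logits, -1e30)

# Row-wise softmax, subtracting the row max for numerical stability.
weights = np.exp(masked_logits - masked_logits.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
```

Each row still sums to 1, but token t spreads its attention only over positions 0..t, so it never sees the future.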
The SelfAttention class above allows me to introduce the first key principle of Haiku: all modules should be a subclass of hk.Module. This means that they implement __init__ and __call__, alongside any other methods they need. In a sense, it’s the same pattern as PyTorch modules, where we implement an __init__ and a forward.
To make that crystal clear, let’s build a simple 2-layer Multilayer Perceptron (MLP) as an hk.Module, which will conveniently be used in the Transformer below.
The linear layer
A simple 2-layer MLP will look like this. Once again, you can notice how familiar it looks.
```python
class DenseBlock(hk.Module):
    """A 2-layer MLP."""

    def __init__(self,
                 init_scale: float,
                 widening_factor: int = 4,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._init_scale = init_scale
        self._widening_factor = widening_factor

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        hiddens = x.shape[-1]
        initializer = hk.initializers.VarianceScaling(self._init_scale)
        # Expand to widening_factor * hiddens, apply GELU, then project back to hiddens.
        x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
        x = jax.nn.gelu(x)
        return hk.Linear(hiddens, w_init=initializer)(x)
```
A few things to notice here:
- Haiku provides a set of weight initializers under hk.initializers, covering the most common approaches.
- It also has many popular layers and modules built in, such as hk.Linear. For the complete list, take a peek at the official documentation.
- Activation functions are not provided because JAX already has a subpackage called jax.nn, where we can find activation functions such as relu or softmax.
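Since the activations come from jax.nn rather than from Haiku, a quick check of the ones mentioned above may help. The inputs are arbitrary; note that jax.nn.gelu uses its tanh approximation by default.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, 0.0, 2.0])

relu_out = jax.nn.relu(x)    # elementwise max(x, 0)
gelu_out = jax.nn.gelu(x)    # smooth ReLU-like curve, used inside DenseBlock
probs = jax.nn.softmax(x)    # normalized over the last axis by default

print(relu_out, gelu_out, probs, probs.sum())
```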
The normalization layer
Layer normalization is another integral block of the transformer architecture, which we can also find in the common modules inside Haiku.
```python
def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
    """Apply a unique LayerNorm to x with default settings."""
    return hk.LayerNorm(axis=-1,
                        create_scale=True,
                        create_offset=True,
                        name=name)(x)
```
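Under the hood, layer normalization standardizes each feature vector over its last axis and then applies a learned scale and offset. Here is a NumPy sketch of that computation, not Haiku's code; eps=1e-5 and the input values are our own choices, and the scale and offset are set to ones and zeros, which to our knowledge is where hk.LayerNorm starts them when create_scale and create_offset are enabled.

```python
import numpy as np

def layer_norm_np(x, scale, offset, eps=1e-5):
    """NumPy sketch of LayerNorm over the last axis."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    normalized = (x - mean) / np.sqrt(var + eps)
    # scale and offset play the role of the learned create_scale/create_offset params.
    return normalized * scale + offset

x = np.array([[1.0, 2.0, 3.0, 4.0],
              [10.0, 10.0, 20.0, 20.0]])
hidden = x.shape[-1]
y = layer_norm_np(x, scale=np.ones(hidden), offset=np.zeros(hidden))
```

After normalization, every row has (approximately) zero mean and unit variance, regardless of its original magnitude.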
The transformer
And now for the good stuff. Below you can find a very simple Transformer, which makes use of our predefined modules. Inside __init__, we define the basic variables such as the number of layers, the number of attention heads, and the dropout rate. Inside __call__, we compose a stack of blocks using a for loop.
As you can see, each block includes:
- Two normalization layers
- A self-attention block
- A 2-layer dense block
- Two dropout layers
- Two skip connections (h = h + h_attn and h = h + h_dense)
Note that each normalization layer is applied before its sub-block rather than after it. After the last block, we also add a final normalization layer.
```python
class Transformer(hk.Module):
    """A transformer stack."""

    def __init__(self,
                 num_heads: int,
                 num_layers: int,
                 dropout_rate: float,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._num_layers = num_layers
        self._num_heads = num_heads
        self._dropout_rate = dropout_rate

    def __call__(self,
                 h: jnp.ndarray,
                 mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """Connects the transformer.

        Args:
          h: Inputs, [B, T, H].
          mask: Padding mask, [B, T].
          is_training: Whether we're training or not.

        Returns:
          Array of shape [B, T, H].
        """
        init_scale = 2. / self._num_layers
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            # [B, T] -> [B, 1, 1, T] so it broadcasts over heads and query positions.
            mask = mask[:, None, None, :]

        for i in range(self._num_layers):
            # Attention sub-block: LayerNorm -> self-attention -> dropout -> residual.
            # By default the attention output has num_heads * key_size features,
            # so H must equal num_heads * 64 for the residual addition to work.
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = SelfAttention(num_heads=self._num_heads,
                                   key_size=64,
                                   w_init_scale=init_scale,
                                   name=f'h{i}_attn')(h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn

            # MLP sub-block: LayerNorm -> DenseBlock -> dropout -> residual.
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense

        h = layer_norm(h, name='ln_f')
        return h
```
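The mask[:, None, None, :] line in __call__ is easiest to follow by tracing shapes. This NumPy sketch uses a made-up batch of two sequences (the second one ending in a padding position) and shows how the expanded padding mask combines with the causal mask built in SelfAttention; the singleton axis is what lets one mask broadcast over all attention heads.

```python
import numpy as np

# Hypothetical padding mask: 1 = real token, 0 = padding.
padding = np.array([[1, 1, 1, 1],
                    [1, 1, 1, 0]], dtype=np.float32)   # [B, T]
B, T = padding.shape

expanded = padding[:, None, None, :]                  # [B, 1, 1, T]
causal = np.tril(np.ones((T, T), dtype=np.float32))   # [T, T]
combined = expanded * causal                          # [B, 1, T, T]

# combined[b, 0, t, s] == 1 iff query t may attend to key s in sequence b.
print(expanded.shape, combined.shape)
```

In the second sequence no query can attend to the padded key at position 3, while the causal structure still applies to every row.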
I think that by now you have realized that building a Neural Network with JAX is dead simple.
The embeddings layer
For completeness, let’s also include the embeddings layer. Haiku provides an embedding layer, hk.Embed, which maps the token ids of our input sentence to learned embedding vectors. The token embeddings are then added to learned positional embeddings, which produces the final input.
```python
def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int):
    """Embeds the tokens and adds learned positional embeddings."""
    # Note: d_model is not an argument here; it is read from module scope.
    tokens = data['obs']
    input_mask = jnp.greater(tokens, 0)  # token id 0 is treated as padding
    seq_length = tokens.shape[1]

    # Embed the input tokens and positions.
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)
    positional_embeddings = hk.get_parameter(
        'pos_embs', [seq_length, d_model], init=embed_init)
    input_embeddings = token_embs + positional_embeddings
    return input_embeddings, input_mask
```
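Conceptually, hk.Embed is a row lookup into a learned [vocab_size, d_model] table, and the positional embeddings are a second learned [seq_length, d_model] table added by broadcasting over the batch. This NumPy sketch mimics those shapes with small, made-up sizes and randomly drawn tables (standing in for the TruncatedNormal initializer); it is an illustration, not Haiku's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d_model, seq_length = 10, 8, 5

# Hypothetical learned tables: one row per token id and one row per position.
token_table = rng.normal(0.0, 0.02, size=(vocab_size, d_model))
pos_table = rng.normal(0.0, 0.02, size=(seq_length, d_model))

tokens = np.array([[3, 7, 1, 0, 0]])           # [B, T]; id 0 = padding
token_embs = token_table[tokens]               # [B, T, d_model] via row lookup
input_embeddings = token_embs + pos_table      # [T, d_model] broadcasts over B
input_mask = tokens > 0                        # same test as jnp.greater(tokens, 0)
```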
hk.get_parameter(param_name, ...) is used to create (during initialization) and access the trainable parameters of a module. But you may ask: why not just use object attributes, as we do in PyTorch? This is where the second key principle of Haiku comes into play. We use this API so that the code can be converted into a pure function using hk.transform. This is not trivial to grasp, but I will try to make it as clear as possible.
Why pure functions?
The power of JAX lies in its function transformations: automatic differentiation with grad, the ability to vectorize a function with vmap, automatic parallelization with pmap, and just-in-time compilation with jit. The caveat is that in order to transform a function, it needs to be pure.
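As a quick illustration of these transformations, here is a toy pure function (squared_norm, made up for this example) passed through grad, vmap and jit. pmap is left out because it only pays off with multiple devices.

```python
import jax
import jax.numpy as jnp

def squared_norm(x):
    # Pure: the output depends only on x, and nothing outside is modified.
    return jnp.sum(x ** 2)

grad_fn = jax.grad(squared_norm)   # gradient of sum(x^2) is 2x
batched = jax.vmap(squared_norm)   # applies squared_norm to each row of a batch
fast = jax.jit(squared_norm)       # compiled with XLA on the first call

x = jnp.array([1.0, 2.0, 3.0])
xs = jnp.array([[1.0, 0.0], [3.0, 4.0]])

print(grad_fn(x), batched(xs), fast(x))
```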
A pure function is a function that has the following properties:
The function return values are identical for identical arguments (no variation with local static variables, non-local variables, mutable reference arguments, or input streams).
The function application has no side effects (no mutation of local static variables, non-local variables, mutable reference arguments, or input/output streams).
Source: Scala pure functions by O'Reilly
In practice, this means that a pure function will always:
- return the same result when invoked with the same inputs
- receive all its input data through its arguments and output all its results through its return values
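A classic way to see why purity matters: jax.jit traces a function once per input signature and reuses the compiled result, so a global value read inside the function is frozen at trace time. This made-up example (impure_scale and pure_scale are illustrative names) shows the difference; JAX's "Sharp Bits" documentation covers this pitfall.

```python
import jax

scale = 2.0

@jax.jit
def impure_scale(x):
    # Reads a global that is not one of the arguments: not pure.
    return x * scale

@jax.jit
def pure_scale(x, s):
    # Everything the result depends on is passed in explicitly.
    return x * s

first = impure_scale(1.0)
scale = 10.0
second = impure_scale(1.0)      # reuses the cached trace, which baked in scale=2.0
third = pure_scale(1.0, scale)  # sees the new value because it is an argument
```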
Haiku provides a function transformation, called hk.transform, that turns functions built from object-oriented, functionally “impure” modules into a pair of pure functions, init and apply, that can be used with JAX. To see that in practice, let’s continue with the training of our Transformer model.
The forward pass
A typical forward pass includes:
- Taking the input and computing the input embeddings
- Running them through the Transformer’s blocks
- Returning the output
The aforementioned steps can be easily composed with JAX as follows:
```python
def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
                     num_layers: int, dropout_rate: float):
    """Create the model's forward pass."""

    def forward_fn(data: Mapping[str, jnp.ndarray],
                   is_training: bool = True) -> jnp.ndarray:
        """Forward pass."""
        # Note: embeddings() uses the module-level d_model, so keep it equal
        # to the d_model passed to this builder.
        input_embeddings, input_mask = embeddings(data, vocab_size)

        # Run the transformer over the inputs.
        transformer = Transformer(
            num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
        output_embeddings = transformer(input_embeddings, input_mask, is_training)

        # Reverse the embeddings (untied).
        return hk.Linear(vocab_size)(output_embeddings)

    return forward_fn
```
Although the code is straightforward, its structure might seem a bit odd. The actual forward pass is executed by the forward_fn function. However, we wrap it in the build_forward_fn function, which returns forward_fn. What the heck?
Down the road, we will need to transform the forward_fn function into a pure function using hk.transform, so that we can take advantage of automatic differentiation, parallelization, etc.
This will be accomplished by:
```python
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                              num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
```
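That’s why, instead of simply defining a function, we wrap it and return the function itself (or, to be more precise, a callable). The wrapper captures the hyperparameters (vocab_size, num_heads, num_layers, dropout_rate) in a closure, so the returned callable only takes the data and the is_training flag. This callable can then be passed into hk.transform and become a pure function. If this is clear, let’s continue with our loss function.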
The loss function
The loss function is the well-known cross-entropy, with the difference that we also take the padding mask into consideration, so that padded positions do not contribute to the average. Once again, JAX provides the building blocks: jax.nn.one_hot and jax.nn.log_softmax.
```python
def lm_loss_fn(forward_fn,
               vocab_size: int,
               params,
               rng,
               data: Mapping[str, jnp.ndarray],
               is_training: bool = True) -> jnp.ndarray:
    """Compute the loss on data wrt params."""
    logits = forward_fn(params, rng, data, is_training)
    targets = jax.nn.one_hot(data['target'], vocab_size)
    assert logits.shape == targets.shape

    # Ignore padded positions (token id 0) when averaging.
    mask = jnp.greater(data['obs'], 0)
    loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    loss = jnp.sum(loss * mask) / jnp.sum(mask)
    return loss
```
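To check the masking logic, here is the same computation in NumPy on a tiny made-up batch: one sequence of three positions, the last of which is padding (observation id 0). The predictions at the real positions are uniform, so their loss is log(3); the padded position is deliberately given a bad prediction to show that it does not affect the average.

```python
import numpy as np

def log_softmax(logits):
    # Numerically stable log-softmax over the last axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

vocab_size = 3
obs = np.array([[1, 2, 0]])          # input tokens; id 0 marks padding
target = np.array([[1, 2, 0]])       # next-token targets
logits = np.zeros((1, 3, vocab_size))
logits[0, 2] = [-10.0, 0.0, 0.0]     # a poor prediction at the padded position

targets = np.eye(vocab_size)[target]                         # one-hot, like jax.nn.one_hot
mask = obs > 0                                               # like jnp.greater(obs, 0)
per_token = -np.sum(targets * log_softmax(logits), axis=-1)  # [B, T]
loss = np.sum(per_token * mask) / np.sum(mask)
```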
If you are still with me, take a sip of coffee because things are going to get serious from now on. It’s time to build our training loop.
The training loop
Because neither JAX nor Haiku provides a full-featured optimization library, we will make use of another library, Optax. As mentioned at the beginning, Optax is the go-to package for gradient processing and optimization.
First, here are some things you need to know about Optax. The key abstraction in Optax is the GradientTransformation. A transformation is defined by a pair of pure functions, init and update. init initializes the state, and update transforms the gradients given the state and (optionally) the current value of the parameters, returning the transformed updates together with the new state:
```python
state = init(params)
updates, state = update(grads, state, params=None)
```
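To make the init/update contract tangible, here is a hypothetical toy, not Optax's source: a momentum-SGD transformation written in NumPy. The names ToyGradientTransformation and toy_momentum_sgd are made up; the update rule is ordinary heavy-ball momentum.

```python
from typing import Callable, NamedTuple

import numpy as np

class ToyGradientTransformation(NamedTuple):
    """Hypothetical stand-in mirroring the (init, update) pair."""
    init: Callable
    update: Callable

def toy_momentum_sgd(learning_rate: float, momentum: float = 0.9):
    def init(params):
        # State: one velocity buffer per parameter, starting at zero.
        return np.zeros_like(params)

    def update(grads, state, params=None):
        velocity = momentum * state + grads
        updates = -learning_rate * velocity   # the step to add to the params
        return updates, velocity

    return ToyGradientTransformation(init, update)

opt = toy_momentum_sgd(learning_rate=0.1)
params = np.array([1.0, -2.0])
state = opt.init(params)
grads = np.array([0.5, 0.5])
updates, state = opt.update(grads, state, params)
params = params + updates   # what optax.apply_updates does for plain array leaves
```

Because init and update keep no hidden state, the optimizer state travels alongside the params, just like the model parameters in Haiku.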
One more thing to know before we see the code is the functools.partial function from Python’s standard library. The functools module deals with higher-order functions and operations on callable objects.
A function is called a higher-order function if it takes other functions as arguments or returns a function as its output.
partial returns a new function based on an original one, but with some of its arguments fixed. A partial of another function can also be used as a decorator, as we will see with jax.jit below. For example, if f multiplies two values x and y, partial(f, 2) creates a new function in which x is fixed and equal to 2:
```python
from functools import partial

def f(x, y):
    return x * y

# Creates a new function that multiplies by 2 (x is fixed and equal to 2).
g = partial(f, 2)
print(g(4))  # prints 8
```
After this short detour, let’s proceed. To keep our main function uncluttered, we will extract the gradient update logic into its own class.
First of all, the GradientUpdater accepts the model’s init function, the loss function, and an optimizer.
- The model will be the forward_fn function, made pure by hk.transform:
```python
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                              num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
```
- The loss function will be lm_loss_fn with forward_fn.apply and vocab_size fixed via partial:
```python
loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
```
- The optimizer is a set of optimization transformations that run sequentially (transformations can be combined using optax.chain):
```python
optimizer = optax.chain(
    optax.clip_by_global_norm(grad_clip_value),
    optax.adam(learning_rate, b1=0.9, b2=0.99))
```
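The first link of that chain rescales the gradients whenever their global norm (the L2 norm over all gradient arrays taken together) exceeds grad_clip_value; optax.chain then hands the result to adam. Here is a NumPy sketch of that clipping step over a made-up dictionary of gradients; it reflects our reading of what optax.clip_by_global_norm computes, not its actual code.

```python
import numpy as np

def clip_by_global_norm_np(grads, max_norm):
    """NumPy sketch of global-norm clipping over a dict of gradient arrays."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads.values()))
    # Leave gradients untouched if already small enough, else rescale all of them.
    factor = 1.0 if global_norm < max_norm else max_norm / global_norm
    return {name: g * factor for name, g in grads.items()}, global_norm

grads = {"w": np.array([3.0, 0.0]), "b": np.array([4.0])}
clipped, norm = clip_by_global_norm_np(grads, max_norm=1.0)
```

Scaling every array by the same factor preserves the direction of the overall gradient while bounding its size.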
The GradientUpdater will be initialized as follows:
```python
updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)
```
and will look like this:
```python
class GradientUpdater:
    """A stateless abstraction around an init_fn/update_fn pair.

    This extracts some common boilerplate from the training loop.
    """

    def __init__(self, net_init, loss_fn,
                 optimizer: optax.GradientTransformation):
        self._net_init = net_init
        self._loss_fn = loss_fn
        self._opt = optimizer

    @functools.partial(jax.jit, static_argnums=0)
    def init(self, master_rng, data):
        """Initializes state of the updater."""
        out_rng, init_rng = jax.random.split(master_rng)
        params = self._net_init(init_rng, data)
        opt_state = self._opt.init(params)
        out = dict(
            step=np.array(0),
            rng=out_rng,
            opt_state=opt_state,
            params=params,
        )
        return out

    @functools.partial(jax.jit, static_argnums=0)
    def update(self, state: Mapping[str, Any], data: Mapping[str, jnp.ndarray]):
        """Updates the state using some data and returns metrics."""
        rng, new_rng = jax.random.split(state['rng'])
        params = state['params']
        loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)

        updates, opt_state = self._opt.update(g, state['opt_state'])
        params = optax.apply_updates(params, updates)

        new_state = {
            'step': state['step'] + 1,
            'rng': new_rng,
            'opt_state': opt_state,
            'params': params,
        }

        metrics = {
            'step': state['step'],
            'loss': loss,
        }
        return new_state, metrics
```
Inside the init method, we initialize the model parameters with self._net_init and the optimizer with self._opt.init(params), and we declare the state of the optimization. The state will be a dictionary with:
- The current step
- The optimizer state
- The trainable parameters
- A random number generator key, to be split with jax.random.split
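The rng entry is what keeps randomness (here, dropout) explicit. Here is a short sketch of the key-splitting pattern that init and update follow: split the key, consume one half, carry the other forward. The seed and sample shapes are arbitrary.

```python
import jax

key = jax.random.PRNGKey(428)

# Split once: one key to carry forward, one key to consume now.
key_next, key_now = jax.random.split(key)
sample_a = jax.random.normal(key_now, (3,))

# Splitting the same key again reproduces the same subkeys and samples.
_, key_now_again = jax.random.split(key)
sample_b = jax.random.normal(key_now_again, (3,))

# The next step splits the carried-forward key and gets fresh numbers.
_, key_later = jax.random.split(key_next)
sample_c = jax.random.normal(key_later, (3,))
```

Since the key is just data in the state dictionary, the same state always reproduces the same dropout masks, which keeps update a pure function.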
The update method computes the loss and its gradients, then updates both the state of the optimizer and the trainable parameters. Finally, it returns the new state together with some metrics (the step and the loss).
```python
updates, opt_state = self._opt.update(g, state['opt_state'])
params = optax.apply_updates(params, updates)
```
Two more things to notice here:
- jax.value_and_grad() transforms a function into one that returns both the function’s value and its gradient (with respect to the first argument by default).
- Both init and update are decorated with @functools.partial(jax.jit, static_argnums=0), which triggers the just-in-time compiler and compiles them with XLA when they are first called. static_argnums=0 marks self as a static argument, so it is treated as a compile-time constant instead of being traced. Note that if we hadn’t transformed forward_fn into a pure function, this wouldn’t be possible.
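Here is a small standalone sketch of jax.value_and_grad on a made-up squared-error loss whose params are a dictionary, much like the nested params pytree that Haiku produces. Wrapping it in jax.jit works because loss_fn is pure.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Squared error of a 1-D linear model; params is a dict (a pytree).
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.array(2.0), "b": jnp.array(0.0)}
x = jnp.array([1.0, 2.0])
y = jnp.array([1.0, 2.0])

# One call returns the loss and gradients with the same structure as params.
loss, grads = jax.jit(jax.value_and_grad(loss_fn))(params, x, y)
```

Because grads mirrors the structure of params, it can be handed straight to an Optax update, as GradientUpdater.update does.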
Finally, we are ready to build the entire training loop, which combines all the ideas and code mentioned so far.
```python
import logging
import time

def main():
    # load() and the hyperparameters used below (batch_size, sequence_length,
    # d_model, num_heads, num_layers, dropout_rate, grad_clip_value,
    # learning_rate, MAX_STEPS) are not shown in this article; they are
    # assumed to be defined at module level, as in the full code.

    # Create the dataset.
    train_dataset, vocab_size = load(batch_size, sequence_length)

    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                                  num_layers, dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(
        optax.clip_by_global_norm(grad_clip_value),
        optax.adam(learning_rate, b1=0.9, b2=0.99))

    updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
```
Notice how we incorporate the GradientUpdater. It’s just two lines of code:
```python
state = updater.init(rng, data)
state, metrics = updater.update(state, data)
```
And that’s it. I hope that by now you have a clearer understanding of JAX and its capabilities.
Acknowledgments
The code presented is heavily inspired by the official examples of the Haiku framework. It has been modified to fit the needs of this article. For the complete list of examples, check the official repository.
Conclusion
In this article, we saw how one can develop and train a vanilla Transformer in JAX using Haiku. Although the code isn’t necessarily hard to grasp, it still lacks the readability of PyTorch or TensorFlow. I highly recommend playing around with it, discovering the strengths and weaknesses of JAX, and seeing whether it would be a good fit for your next project. In my experience, JAX is very strong for research applications that require high performance, but still quite immature for real-life projects. Let us know what you think in our Discord channel.