I was very curious to see how JAX compares to PyTorch and TensorFlow. I figured that the best way to compare frameworks is to build the same thing from scratch in each of them. And that’s exactly what I did. In this article, I develop a Variational Autoencoder with JAX, TensorFlow and PyTorch at the same time. I will present the code for each component side by side in order to find differences, similarities, weaknesses and strengths.

Shall we begin?

## Prologue

Some things to note before we explore the code:

- I will use Flax on top of JAX. Flax is a neural network library developed by Google that contains many ready-to-use deep learning modules, layers, functions, and operations.
- For the TensorFlow implementation, I will rely on Keras abstractions.
- For PyTorch, I will use the standard `nn.Module`.

Because most of us are somewhat familiar with TensorFlow and PyTorch, we will pay more attention to JAX and Flax. That’s why I will explain things along the way that may be unfamiliar to many. So you can consider this article a light tutorial on Flax as well.

Also, I assume that you are familiar with the basic principles behind VAEs. If not, you can consult my previous article on latent variable models. If everything seems clear, let’s continue.

**Quick recap**: The vanilla Autoencoder consists of an Encoder and a Decoder. The encoder converts the input to a latent representation $z$ and the decoder tries to reconstruct the input based on that representation. In Variational Autoencoders, stochasticity is added to the mix: the encoder outputs the parameters of a probability distribution over the latent space, and $z$ is sampled from that distribution. The reparameterization trick is what keeps this sampling step differentiable.
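In symbols, and only as a sketch in the notation of the code that follows (a per-dimension mean $\mu$ and log-variance $\log\sigma^2$ produced by the encoder, with $\epsilon$ an auxiliary noise variable introduced here for illustration), the trick writes the sample as a deterministic function of the distribution parameters:

$z = \mu + \sigma \odot \epsilon, \qquad \sigma = \exp\left(\tfrac{1}{2}\log\sigma^2\right), \qquad \epsilon \sim \mathcal{N}(0, I)$

Because all the randomness lives in $\epsilon$, gradients can flow through $\mu$ and $\sigma$ back to the encoder weights.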

## The encoder

For the encoder, a simple linear layer followed by a ReLU activation should be enough for a toy example. Two further linear layers then map its output to the mean and the log-variance of the latent probability distribution.

The basic building block of the Flax API is the `Module` abstraction, which is what we’ll use to implement our encoder in JAX. `Module` is part of the `linen` subpackage. Similar to PyTorch’s `nn.Module`, we again need to define our class arguments. In PyTorch, we are used to declaring them inside the `__init__` function and implementing the forward pass inside the `forward` method. In Flax, things are a little different. Arguments are defined either as dataclass attributes or as method arguments. Usually, fixed properties are defined as dataclass attributes while dynamic properties are passed as method arguments. Also, instead of implementing a `forward` method, we implement `__call__`.

The `dataclasses` module was introduced in Python 3.7 as a utility for writing structured classes, especially ones that store data. Such classes declare their fields as annotated attributes and get generated methods such as `__init__` and `__repr__`, which removes a lot of boilerplate code compared to regular classes.
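As a quick illustration of the idea, here is a minimal plain-Python dataclass. The `LayerConfig` class and its fields are hypothetical and only serve to show the syntax that Flax modules reuse for their hyperparameters.

```python
from dataclasses import dataclass


@dataclass
class LayerConfig:
    # Fields are declared as annotated class attributes, the same way
    # `latents: int` is declared on a Flax module.
    features: int
    name: str = "dense"  # a field with a default value


# __init__, __repr__ and __eq__ are generated automatically.
cfg = LayerConfig(features=500)
same_cfg = LayerConfig(features=500, name="dense")
print(cfg)
print(cfg == same_cfg)
```

No constructor had to be written by hand, which is the boilerplate reduction mentioned above.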

So to create a new module in Flax, we need to:

- Define a class that inherits from `flax.linen.Module`
- Define the static arguments as dataclass attributes
- Implement the forward pass inside the `__call__` method.

To tie the arguments to the model and to be able to define submodules directly within the module, we also need to decorate the `__call__` method with `@nn.compact`.

Note that instead of defining the submodules inline under the `@nn.compact` decorator, we could have declared them inside a `setup` method, much as we do in PyTorch’s or TensorFlow’s `__init__`.

##### JAX

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax import random
    from flax import linen as nn
    from flax import optim

    class Encoder(nn.Module):
        latents: int  # size of the latent space

        @nn.compact
        def __call__(self, x):
            x = nn.Dense(500, name='fc1')(x)
            x = nn.relu(x)
            mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
            logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
            return mean_x, logvar_x

##### Tensorflow

    import tensorflow as tf
    from tensorflow.keras import layers

    class Encoder(layers.Layer):
        def __init__(self, latent_dim=20, name='encoder', **kwargs):
            super(Encoder, self).__init__(name=name, **kwargs)
            self.enc1 = layers.Dense(500, activation='relu')
            self.mean_x = layers.Dense(latent_dim)
            self.logvar_x = layers.Dense(latent_dim)

        def call(self, inputs):
            x = self.enc1(inputs)
            z_mean = self.mean_x(x)
            z_log_var = self.logvar_x(x)
            return z_mean, z_log_var

##### Pytorch

    import torch
    import torch.nn.functional as F

    class Encoder(torch.nn.Module):
        def __init__(self, latent_dim=20):
            super(Encoder, self).__init__()
            self.enc1 = torch.nn.Linear(784, 500)
            self.mean_x = torch.nn.Linear(500, latent_dim)
            self.logvar_x = torch.nn.Linear(500, latent_dim)

        def forward(self, inputs):
            x = self.enc1(inputs)
            x = F.relu(x)
            z_mean = self.mean_x(x)
            z_log_var = self.logvar_x(x)
            return z_mean, z_log_var

A few more things to notice here before we proceed:

- Flax’s `flax.linen` package contains most deep learning layers and operations, such as `Dense`, `relu`, and many more.
- The code in Flax, TensorFlow, and PyTorch is almost indistinguishable from each other.

## The decoder

In a very similar fashion, we can develop the decoder in all 3 frameworks. The decoder will be two linear layers that receive the latent representation $z$ and output the reconstructed input.

Again the implementations are very similar.

##### JAX

    class Decoder(nn.Module):

        @nn.compact
        def __call__(self, z):
            z = nn.Dense(500, name='fc1')(z)
            z = nn.relu(z)
            z = nn.Dense(784, name='fc2')(z)  # raw logits, one per pixel
            return z

##### Tensorflow

    class Decoder(layers.Layer):
        def __init__(self, name='decoder', **kwargs):
            super(Decoder, self).__init__(name=name, **kwargs)
            self.dec1 = layers.Dense(500, activation='relu')
            self.out = layers.Dense(784)

        # the argument is named z so that the body refers to a defined variable
        def call(self, z):
            z = self.dec1(z)
            return self.out(z)

##### Pytorch

    class Decoder(torch.nn.Module):
        def __init__(self, latent_dim=20):
            super(Decoder, self).__init__()
            self.dec1 = torch.nn.Linear(latent_dim, 500)
            self.out = torch.nn.Linear(500, 784)

        def forward(self, z):
            z = self.dec1(z)
            z = F.relu(z)
            return self.out(z)

## Variational Autoencoder

To combine the encoder and the decoder, let’s have one more class, called `VAE`, that will represent the entire architecture. Here we also need to write some code for the reparameterization trick. Overall, the encoder produces the mean and log-variance of the latent distribution, a latent sample is drawn via the reparameterization trick, and that sample is fed to the decoder, which produces the reconstructed input.

As a reminder, here is an intuitive image that explains the reparameterization trick:

*Source: Alexander Amini and Ava Soleimany, Deep Generative Modeling | MIT 6.S191, http://introtodeeplearning.com/*

Notice that this time, in JAX we make use of the `setup` method instead of the `nn.compact` decorator. Also, check out how similar the reparameterization functions are. Sure, each framework uses its own functions and operations, but the overall picture is almost identical.

##### JAX

    class VAE(nn.Module):
        latents: int = 20

        def setup(self):
            self.encoder = Encoder(self.latents)
            self.decoder = Decoder()

        def __call__(self, x, z_rng):
            mean, logvar = self.encoder(x)
            z = reparameterize(z_rng, mean, logvar)
            recon_x = self.decoder(z)
            return recon_x, mean, logvar

    def reparameterize(rng, mean, logvar):
        std = jnp.exp(0.5 * logvar)
        eps = random.normal(rng, logvar.shape)
        return mean + eps * std

    # LATENTS is a hyperparameter constant assumed to be defined beforehand
    def model():
        return VAE(latents=LATENTS)

##### Tensorflow

    class VAE(tf.keras.Model):
        def __init__(self, latent_dim=20, name='vae', **kwargs):
            super(VAE, self).__init__(name=name, **kwargs)
            self.encoder = Encoder(latent_dim=latent_dim)
            self.decoder = Decoder()

        def call(self, inputs):
            z_mean, z_log_var = self.encoder(inputs)
            z = self.reparameterize(z_mean, z_log_var)
            reconstructed = self.decoder(z)
            return reconstructed, z_mean, z_log_var

        def reparameterize(self, mean, logvar):
            # tf.shape gives the runtime shape, so an unknown batch size is fine
            eps = tf.random.normal(shape=tf.shape(mean))
            return mean + eps * tf.exp(logvar * .5)

##### Pytorch

    class VAE(torch.nn.Module):
        def __init__(self, latent_dim=20):
            super(VAE, self).__init__()
            self.encoder = Encoder(latent_dim)
            self.decoder = Decoder(latent_dim)

        def forward(self, inputs):
            z_mean, z_log_var = self.encoder(inputs)
            z = self.reparameterize(z_mean, z_log_var)
            reconstructed = self.decoder(z)
            return reconstructed, z_mean, z_log_var

        def reparameterize(self, mu, log_var):
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            return mu + (eps * std)
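One difference that is easy to miss is that the JAX version receives an explicit random key (`z_rng`), while the other two draw from a global random state. The following is a small sketch of what that means in practice. It reuses the `reparameterize` function from above on dummy inputs, and the variable names (`subkey_a`, `subkey_b`, and so on) are purely illustrative.

```python
import jax.numpy as jnp
from jax import random


def reparameterize(rng, mean, logvar):
    std = jnp.exp(0.5 * logvar)
    eps = random.normal(rng, logvar.shape)
    return mean + eps * std


# Dummy encoder outputs: a batch of 2 examples with 4 latent dimensions.
mean = jnp.zeros((2, 4))
logvar = jnp.zeros((2, 4))

# Keys are never mutated; new ones are derived by splitting.
key = random.PRNGKey(0)
key, subkey_a, subkey_b = random.split(key, 3)

z_a = reparameterize(subkey_a, mean, logvar)
z_a_again = reparameterize(subkey_a, mean, logvar)  # same key, same sample
z_b = reparameterize(subkey_b, mean, logvar)        # different key, different sample
```

This is why the JAX training loop splits its key before every step: reusing a key would reproduce exactly the same noise.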

## Loss and Training step

Things start to differ when we begin implementing the training step and the loss function. But not by much.

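- In order to take full advantage of JAX’s capabilities, we need to add automatic vectorization and XLA compilation to our code. This can be done easily with the `vmap` and `jit` transformations, applied here as decorators.
- Moreover, we have to enable automatic differentiation, which is accomplished with the `jax.value_and_grad` transformation (its result is named `grad_fn` in the code).
- We use the `flax.optim` package for optimization algorithms. Note that later Flax releases deprecated and removed this package in favor of Optax, so the code here targets an older Flax version.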

Another small difference that we need to be aware of is how we pass data to our model. A Flax module does not hold its own parameters, so parameters and inputs are both passed to the `apply` method, in the form `model().apply({'params': params}, batch, z_rng)`, where `batch` is our training data.
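To see `vmap`, `jit` and `value_and_grad` in isolation before they appear in the VAE code, here is a minimal sketch on a toy linear model. The function names and the data are made up for illustration and are not part of the VAE implementation.

```python
import jax
import jax.numpy as jnp


def squared_error(w, x, y):
    # Loss for a single example: x has shape (3,), y is a scalar.
    return (jnp.dot(w, x) - y) ** 2


# vmap: map over the leading axis of x and y, share w across examples.
batched_error = jax.vmap(squared_error, in_axes=(None, 0, 0))


def loss_fn(w, xs, ys):
    return batched_error(w, xs, ys).mean()


# value_and_grad differentiates with respect to the first argument (w);
# jit compiles the resulting function with XLA.
grad_fn = jax.jit(jax.value_and_grad(loss_fn))

w = jnp.zeros(3)
xs = jnp.ones((4, 3))
ys = jnp.ones(4)
loss, grads = grad_fn(w, xs, ys)
```

The per-example function is written without any batch dimension, and the batching, differentiation and compilation are layered on top as function transformations. The VAE training step follows the same pattern.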

##### JAX

    @jax.vmap
    def kl_divergence(mean, logvar):
        return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))

    @jax.vmap
    def binary_cross_entropy_with_logits(logits, labels):
        logits = nn.log_sigmoid(logits)
        return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))

    @jax.jit
    def train_step(optimizer, batch, z_rng):
        def loss_fn(params):
            recon_x, mean, logvar = model().apply({'params': params}, batch, z_rng)
            bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
            kld_loss = kl_divergence(mean, logvar).mean()
            loss = bce_loss + kld_loss
            return loss, recon_x

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        _, grad = grad_fn(optimizer.target)
        optimizer = optimizer.apply_gradient(grad)
        return optimizer

##### Tensorflow

    def kl_divergence(mean, logvar):
        return -0.5 * tf.reduce_sum(
            1 + logvar - tf.square(mean) - tf.exp(logvar), axis=1)

    def binary_cross_entropy_with_logits(logits, labels):
        # the decoder returns raw logits, so take log(sigmoid(x)), not log(x)
        logits = tf.math.log_sigmoid(logits)
        return -tf.reduce_sum(
            labels * logits + (1 - labels) * tf.math.log(-tf.math.expm1(logits)),
            axis=1)

    @tf.function
    def train_step(model, x, optimizer):
        with tf.GradientTape() as tape:
            recon_x, mean, logvar = model(x)
            # the targets are the inputs themselves
            bce_loss = tf.reduce_mean(binary_cross_entropy_with_logits(recon_x, x))
            kld_loss = tf.reduce_mean(kl_divergence(mean, logvar))
            loss = bce_loss + kld_loss
        # tf.print runs on every step; a plain print only fires while tracing
        tf.print(loss, kld_loss, bce_loss)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

##### Pytorch

    def final_loss(reconstruction, train_x, mu, logvar):
        BCE = torch.nn.BCEWithLogitsLoss(reduction='sum')(reconstruction, train_x)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD

    def train_step(train_x):
        # `model` and `optimizer` are expected to exist at module level
        train_x = torch.from_numpy(train_x)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(train_x)
        loss = final_loss(reconstruction, train_x, mu, logvar)
        loss.backward()
        optimizer.step()
        # return the scalar loss so that the caller can accumulate it
        return loss.item()
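For reference, the KL terms in the code above are the closed-form KL divergence between the encoder’s diagonal Gaussian and a standard normal prior. As a sketch, assuming $q_{\phi}(z|x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$, $p_{\theta}(z) = \mathcal{N}(0, I)$ and a latent dimension $d$ (notation introduced here for illustration):

$\textbf{KL}(q_{\phi}(z|x) \,\|\, p_{\theta}(z)) = -\frac{1}{2} \sum_{j=1}^{d} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)$

With `logvar` standing for $\log \sigma^2$, the last term $\sigma_j^2$ becomes `exp(logvar)`, so the formula matches the code term by term.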

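Remember that VAEs are trained by maximizing the evidence lower bound, known as ELBO. The loss minimized in the code is its negative: the binary cross-entropy plays the role of the reconstruction term and the KL divergence is the regularization term.

$L_{\theta,\phi}(x) = \mathbb{E}_{q_{\phi}(z|x)} [ \log p_{\theta}(x|z) ] - \textbf{KL}(q_{\phi}(z|x) \,\|\, p_{\theta}(z))$

## Training loop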

Finally, it’s time for the entire training loop, which will execute the `train_step` function iteratively.

In Flax, the model has to be initialized before training, which is done by the `init` function: `params = model().init(key, init_data, rng)['params']`. A similar initialization is necessary for the optimizer as well: `optimizer = optim.Adam(learning_rate=LEARNING_RATE).create(params)`.

`jax.device_put` is used to transfer the optimizer to the memory of the default device, for example the GPU.

##### JAX

    rng = random.PRNGKey(0)
    rng, key = random.split(rng)

    # dummy batch, only used to infer parameter shapes
    init_data = jnp.ones((BATCH_SIZE, 784), jnp.float32)
    params = model().init(key, init_data, rng)['params']

    optimizer = optim.Adam(learning_rate=LEARNING_RATE).create(params)
    optimizer = jax.device_put(optimizer)

    rng, z_key, eval_rng = random.split(rng, 3)
    z = random.normal(z_key, (64, LATENTS))

    steps_per_epoch = 50000 // BATCH_SIZE

    for epoch in range(NUM_EPOCHS):
        for _ in range(steps_per_epoch):
            batch = next(train_ds)
            rng, key = random.split(rng)  # fresh key for every step
            optimizer = train_step(optimizer, batch, key)

##### Tensorflow

    vae = VAE(latent_dim=LATENTS)
    optimizer = tf.keras.optimizers.Adam(1e-4)

    # train_ds is assumed to be a finite iterable of batches for this loop
    for epoch in range(NUM_EPOCHS):
        for train_x in train_ds:
            train_step(vae, train_x, optimizer)

##### Pytorch

    # module-level objects used by train_step
    model = VAE(LATENTS)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # training_data is assumed to be a finite iterable of NumPy batches
    def train(training_data):
        running_loss = 0.0
        for epoch in range(NUM_EPOCHS):
            for i, train_x in enumerate(training_data, 0):
                running_loss += train_step(train_x)

    train(train_ds)

## Load and Process Data

One thing I haven’t mentioned is data. How do we load and preprocess data in Flax? Well, Flax doesn’t include data manipulation packages yet, besides the basic operations of `jax.numpy`. Right now, our best option is to borrow packages from other frameworks, such as TensorFlow Datasets (tfds) or Torchvision. To make the article self-contained, I will include the code I used to load a sample training dataset with `tfds`. Feel free though to use your own dataloader if you’re planning to run the implementations presented in this article.

    import numpy as np
    import tensorflow as tf
    import tensorflow_datasets as tfds

    # hide GPUs from TensorFlow so the input pipeline does not claim GPU memory
    tf.config.experimental.set_visible_devices([], 'GPU')

    def prepare_image(x):
        x = tf.cast(x['image'], tf.float32)
        x = tf.reshape(x, (-1,))  # flatten each image into a vector
        return x

    ds_builder = tfds.builder('binarized_mnist')
    ds_builder.download_and_prepare()

    train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)
    train_ds = train_ds.map(prepare_image)
    train_ds = train_ds.cache()
    train_ds = train_ds.repeat()
    train_ds = train_ds.shuffle(50000)
    train_ds = train_ds.batch(BATCH_SIZE)
    train_ds = iter(tfds.as_numpy(train_ds))

    test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)
    test_ds = test_ds.map(prepare_image).batch(10000)
    test_ds = np.array(list(test_ds)[0])

## Final observations

To close the article, let’s discuss a few final observations that appear after a close analysis of the code:

- All 3 frameworks have reduced the boilerplate code to a minimum, with Flax being the one that requires a bit more, especially in the training part. However, this is only to ensure that we exploit all the available transformations such as automatic differentiation, vectorization and just-in-time compilation.
- The definition of modules, layers and models is almost identical in all of them.
- Flax and JAX are by design quite flexible and extensible.
- Flax doesn’t have data loading and processing capabilities yet.
- In terms of ready-to-use layers and optimizers, Flax doesn’t need to be jealous of TensorFlow and PyTorch. For sure it lacks the giant library of its competitors, but it’s gradually getting there.
