
# JAX vs Tensorflow vs Pytorch: Building a Variational Autoencoder (VAE)

Sergios Karagiannakos on 2021-04-01 · 4 mins

I was very curious to see how JAX compares to Pytorch and Tensorflow. I figured that the best way to compare frameworks is to build the same thing from scratch in each of them. And that’s exactly what I did. In this article, I develop a Variational Autoencoder with JAX, Tensorflow and Pytorch at the same time. I will present the code for each component side by side in order to find differences, similarities, weaknesses and strengths.

Shall we begin?

## Prologue

Some things to note before we explore the code:

• I will use Flax on top of JAX. Flax is a neural network library developed by Google that contains many ready-to-use deep learning modules, layers, functions, and operations.

• For the Tensorflow implementation, I will rely on Keras abstractions.

• For Pytorch, I will use the standard nn.Module.

Because most of us are somewhat familiar with Tensorflow and Pytorch, we will pay more attention to JAX and Flax. That’s why I will explain things along the way that may be unfamiliar to many. So you can consider this article a light tutorial on Flax as well.

Also, I assume that you are familiar with the basic principles behind VAEs. If not, you can consult my previous article on latent variable models. If everything seems clear, let’s continue.

Quick recap: The vanilla Autoencoder consists of an Encoder and a Decoder. The encoder converts the input to a latent representation $z$ and the decoder tries to reconstruct the input based on that representation. In Variational Autoencoders, stochasticity is added to the mix: the encoder outputs the parameters of a probability distribution over $z$ instead of a single point. Sampling from that distribution while keeping the model differentiable is made possible by the reparameterization trick.

Image by author
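To make the recap concrete, here is a minimal NumPy sketch of the reparameterization trick, independent of any of the three frameworks. The function name `reparameterize`, the shapes and the seed are illustrative assumptions; the encoder output is assumed to be a mean and a log-variance.

```python
import numpy as np

def reparameterize(rng, mean, logvar):
    """Sample z ~ N(mean, exp(logvar)) as a deterministic function of noise."""
    std = np.exp(0.5 * logvar)              # log-variance -> standard deviation
    eps = rng.standard_normal(mean.shape)   # all randomness lives in eps ~ N(0, 1)
    return mean + eps * std                 # differentiable w.r.t. mean and logvar

rng = np.random.default_rng(0)
mean = np.full((100_000, 2), 3.0)            # hypothetical encoder output: mean = 3
logvar = np.full((100_000, 2), np.log(4.0))  # variance = 4, i.e. std = 2
z = reparameterize(rng, mean, logvar)
```

Because the noise is drawn separately and only scaled and shifted, gradients can flow through `mean` and `logvar`, which is what lets the encoder be trained with backpropagation.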

## The encoder

For the encoder, a simple linear layer followed by a ReLU activation should be enough for a toy example. The output of the encoder will be both the mean and the log-variance of the probability distribution.

The basic building block of the Flax API is the Module abstraction, which is what we’ll use to implement our encoder in JAX. The module is part of the linen subpackage. Similar to Pytorch’s nn.Module, we again need to define our class arguments. In Pytorch, we are used to declaring them inside the __init__ function and implementing the forward pass inside the forward method. In Flax, things are a little different. Arguments are defined either as dataclass attributes or as method arguments. Usually, fixed properties are defined as dataclass attributes, while dynamic properties are passed as method arguments. Also, instead of implementing a forward method, we implement __call__.

The dataclasses module was introduced in Python 3.7 as a utility for creating structured classes, especially ones that store data. These classes hold certain properties and functions to deal specifically with the data and its representation. They also remove a lot of boilerplate code compared to regular classes.
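As a quick illustration of what dataclasses save us, here is a minimal standard-library sketch. The class `EncoderConfig` and its fields are hypothetical and not part of Flax; the point is only that annotated class attributes become constructor arguments, which is the same mechanism Flax modules rely on for their hyperparameters.

```python
from dataclasses import dataclass

@dataclass
class EncoderConfig:
    latents: int          # required field, becomes an __init__ argument
    hidden: int = 500     # optional field with a default value

cfg = EncoderConfig(latents=20)
# __init__, __repr__ and __eq__ were generated from the annotations
summary = repr(cfg)
```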

So to create a new module in Flax, we need to:

• Define a class that inherits from flax.linen.Module (nn.Module once linen is imported as nn)

• Define the static arguments as dataclass attributes

• Implement the forward pass inside the __call__ method.

To tie the arguments to the model and to be able to define submodules directly within the module, we also need to annotate the __call__ method with @nn.compact.

Note that instead of using the @nn.compact annotation, we could have declared all submodules inside a setup method, in the same way as we do in Pytorch’s or Tensorflow’s __init__.

##### JAX

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from flax import optim


class Encoder(nn.Module):
    latents: int  # size of the latent space, declared as a dataclass attribute

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(500, name='fc1')(x)
        x = nn.relu(x)
        mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
        logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
        return mean_x, logvar_x
```

##### Tensorflow

```python
import tensorflow as tf
from tensorflow.keras import layers


class Encoder(layers.Layer):
    def __init__(self, latent_dim=20, name='encoder', **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.enc1 = layers.Dense(500, activation='relu')
        self.mean_x = layers.Dense(latent_dim)
        self.logvar_x = layers.Dense(latent_dim)

    def call(self, inputs):
        x = self.enc1(inputs)
        z_mean = self.mean_x(x)
        z_log_var = self.logvar_x(x)
        return z_mean, z_log_var
```

##### Pytorch

```python
import torch
import torch.nn.functional as F


class Encoder(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(Encoder, self).__init__()
        self.enc1 = torch.nn.Linear(784, 500)
        self.mean_x = torch.nn.Linear(500, latent_dim)
        self.logvar_x = torch.nn.Linear(500, latent_dim)

    def forward(self, inputs):
        x = self.enc1(inputs)
        x = F.relu(x)
        z_mean = self.mean_x(x)
        z_log_var = self.logvar_x(x)
        return z_mean, z_log_var
```

A few more things to notice here before we proceed:

• Flax’s flax.linen package (imported as nn) contains most deep learning layers and operations, such as Dense, relu, and many more

• The code in Flax, Tensorflow, and Pytorch is almost indistinguishable from each other.

## The decoder

In a very similar fashion, we can develop the decoder in all 3 frameworks. The decoder will be two linear layers that receive the latent representation $z$ and output the reconstructed input.

Again the implementations are very similar.

##### JAX

```python
class Decoder(nn.Module):

    @nn.compact
    def __call__(self, z):
        z = nn.Dense(500, name='fc1')(z)
        z = nn.relu(z)
        z = nn.Dense(784, name='fc2')(z)
        return z
```

##### Tensorflow

```python
class Decoder(layers.Layer):
    def __init__(self, name='decoder', **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dec1 = layers.Dense(500, activation='relu')
        self.out = layers.Dense(784)

    def call(self, z):
        z = self.dec1(z)
        return self.out(z)
```

##### Pytorch

```python
class Decoder(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(Decoder, self).__init__()
        self.dec1 = torch.nn.Linear(latent_dim, 500)
        self.out = torch.nn.Linear(500, 784)

    def forward(self, z):
        z = self.dec1(z)
        z = F.relu(z)
        return self.out(z)
```

## Variational Autoencoder

To combine the encoder and the decoder, let’s have one more class, called VAE, that will represent the entire architecture. Here we also need to write some code for the reparameterization trick. Overall we have: the latent variable from the encoder is reparameterized and fed to the decoder, which produces the reconstructed input.

As a reminder, here is an intuitive image that explains the reparameterization trick:

Notice that this time, in JAX we make use of the setup method instead of the nn.compact annotation. Also, check out how similar the reparameterization functions are. Sure, each framework uses its own functions and operations, but the overall picture is almost identical.

##### JAX

```python
class VAE(nn.Module):
    latents: int = 20

    def setup(self):
        self.encoder = Encoder(self.latents)
        self.decoder = Decoder()

    def __call__(self, x, z_rng):
        mean, logvar = self.encoder(x)
        z = reparameterize(z_rng, mean, logvar)
        recon_x = self.decoder(z)
        return recon_x, mean, logvar


def reparameterize(rng, mean, logvar):
    std = jnp.exp(0.5 * logvar)
    eps = random.normal(rng, logvar.shape)
    return mean + eps * std


def model():
    return VAE(latents=LATENTS)
```

##### Tensorflow

```python
class VAE(tf.keras.Model):
    def __init__(self, latent_dim=20, name='vae', **kwargs):
        super(VAE, self).__init__(name=name, **kwargs)
        self.encoder = Encoder(latent_dim=latent_dim)
        self.decoder = Decoder()

    def call(self, inputs):
        z_mean, z_log_var = self.encoder(inputs)
        z = self.reparameterize(z_mean, z_log_var)
        reconstructed = self.decoder(z)
        return reconstructed, z_mean, z_log_var

    def reparameterize(self, mean, logvar):
        # tf.shape also works when the batch dimension is unknown at trace time
        eps = tf.random.normal(shape=tf.shape(mean))
        return mean + eps * tf.exp(logvar * .5)
```

##### Pytorch

```python
class VAE(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, inputs):
        z_mean, z_log_var = self.encoder(inputs)
        z = self.reparameterize(z_mean, z_log_var)
        reconstructed = self.decoder(z)
        return reconstructed, z_mean, z_log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + (eps * std)
```

## Loss and Training step

Things start to differ when we begin implementing the training step and the loss function. But not by much.

1. In order to take full advantage of JAX’s capabilities, we need to add automatic vectorization and XLA compilation to our code. This can be done easily with the help of the vmap and jit transformations, used here as decorators.

2. Moreover, we have to enable automatic differentiation, which can be accomplished with the jax.value_and_grad transformation (named grad_fn in the code).

3. We use the flax.optim package for optimization algorithms (newer Flax releases have replaced it with Optax).
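The first two points can be tried out in isolation. The following is a minimal sketch in plain JAX, using a hypothetical per-example squared-error loss and a toy linear model rather than the VAE loss; the names `per_example_loss`, `loss_fn` and `w` are made up for illustration.

```python
import jax
import jax.numpy as jnp

@jax.vmap  # written for a single example, mapped over the leading (batch) axis
def per_example_loss(pred, target):
    return jnp.sum(jnp.square(pred - target))

def loss_fn(w, x, y):
    pred = x @ w
    loss = per_example_loss(pred, y).mean()
    return loss, pred  # pred is returned as an auxiliary, non-differentiated output

# jit compiles the function with XLA; value_and_grad differentiates w.r.t. w
grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))

w = jnp.zeros((3, 2))   # toy weights
x = jnp.ones((4, 3))    # batch of 4 inputs
y = jnp.ones((4, 2))    # batch of 4 targets
(loss, pred), grad = grad_fn(w, x, y)
```

With `has_aux=True` the call returns `((loss, aux), grad)`, which is why the training step in this article unpacks the result as `_, grad`.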

Another small difference that we need to be aware of is how we pass data to our model. This can be achieved through the apply method in the form of model().apply({'params': params}, batch, z_rng), where batch is our training data.

##### JAX

```python
@jax.vmap
def kl_divergence(mean, logvar):
    return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))


@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
    logits = nn.log_sigmoid(logits)
    return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))


@jax.jit
def train_step(optimizer, batch, z_rng):
    def loss_fn(params):
        recon_x, mean, logvar = model().apply({'params': params}, batch, z_rng)
        bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
        kld_loss = kl_divergence(mean, logvar).mean()
        loss = bce_loss + kld_loss
        return loss, recon_x

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    _, grad = grad_fn(optimizer.target)
    optimizer = optimizer.apply_gradient(grad)
    return optimizer
```

##### Tensorflow

```python
def kl_divergence(mean, logvar):
    return -0.5 * tf.reduce_sum(
        1 + logvar - tf.square(mean) - tf.exp(logvar), axis=1)


def binary_cross_entropy_with_logits(logits, labels):
    # log-sigmoid rather than a plain log: the decoder outputs unbounded logits
    logits = tf.math.log_sigmoid(logits)
    return -tf.reduce_sum(
        labels * logits + (1 - labels) * tf.math.log(-tf.math.expm1(logits)),
        axis=1)


@tf.function
def train_step(model, x, optimizer):
    with tf.GradientTape() as tape:
        recon_x, mean, logvar = model(x)
        # the input x is also the reconstruction target
        bce_loss = tf.reduce_mean(binary_cross_entropy_with_logits(recon_x, x))
        kld_loss = tf.reduce_mean(kl_divergence(mean, logvar))
        loss = bce_loss + kld_loss

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
```

##### Pytorch

```python
def final_loss(reconstruction, train_x, mu, logvar):
    BCE = torch.nn.BCEWithLogitsLoss(reduction='sum')(reconstruction, train_x)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD


def train_step(train_x):
    # `model` and `optimizer` are assumed to be module-level objects
    train_x = torch.from_numpy(train_x)
    optimizer.zero_grad()
    reconstruction, mu, logvar = model(train_x)
    loss = final_loss(reconstruction, train_x, mu, logvar)
    loss.backward()
    optimizer.step()
    return loss.item()  # returned so the caller can accumulate a running loss
```

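Remember that VAEs are trained by maximizing the evidence lower bound, known as ELBO. The loss minimized in the code above is its negative: the reconstruction term (binary cross-entropy) plus the KL divergence.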

$L_{\theta,\phi}(x) = \mathbb{E}_{q_{\phi}(z|x)} [ \log p_{\theta}(x|z) ] - \mathrm{KL}(q_{\phi}(z|x) \,\|\, p_{\theta}(z))$
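To connect the formula with the loss functions, the following NumPy sketch evaluates both terms of the negative ELBO for a single example. It assumes a diagonal Gaussian posterior, a standard normal prior and a Bernoulli decoder over pixels, with the expectation approximated by one sample of $z$; the input values are made up. It also uses the identity $\log(1 - \sigma(x)) = \log(-\mathrm{expm1}(\log \sigma(x)))$ that appears in the JAX and Tensorflow versions.

```python
import numpy as np

def log_sigmoid(x):
    # numerically stable log(sigmoid(x)) = -log(1 + exp(-x))
    return -np.logaddexp(0.0, -x)

def kl_divergence(mean, logvar):
    # closed form of KL(N(mean, exp(logvar)) || N(0, I)) for a diagonal Gaussian
    return -0.5 * np.sum(1 + logvar - np.square(mean) - np.exp(logvar))

def bce_with_logits(logits, labels):
    # negative Bernoulli log-likelihood, i.e. -log p(x|z) summed over pixels
    log_p = log_sigmoid(logits)
    log_1mp = np.log(-np.expm1(log_p))   # log(1 - sigmoid(logits))
    return -np.sum(labels * log_p + (1.0 - labels) * log_1mp)

logits = np.array([2.0, -1.0, 0.5])      # hypothetical decoder outputs
labels = np.array([1.0, 0.0, 1.0])       # hypothetical binarized pixels
mean, logvar = np.array([0.5, -0.3]), np.array([0.1, -0.2])

neg_elbo = bce_with_logits(logits, labels) + kl_divergence(mean, logvar)
```

The KL term vanishes exactly when the posterior equals the prior (zero mean, zero log-variance), which is a handy sanity check for any of the three implementations.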

## Training loop

Finally, it’s time for the entire training loop which will execute the train_step function iteratively.

In Flax, the model has to be initialized before training, which is done with the init function, for example: params = model().init(key, init_data, rng)['params']. A similar initialization is necessary for the optimizer as well: optimizer = optim.Adam( learning_rate = LEARNING_RATE ).create( params ).

jax.device_put is used to transfer the optimizer (and the parameters it holds) to the accelerator’s memory, for example the GPU.

##### JAX

```python
rng = random.PRNGKey(0)
rng, key = random.split(rng)

init_data = jnp.ones((BATCH_SIZE, 784), jnp.float32)
params = model().init(key, init_data, rng)['params']

# create the optimizer from the initial parameters and move it to the device
optimizer = optim.Adam(learning_rate=LEARNING_RATE).create(params)
optimizer = jax.device_put(optimizer)

rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, LATENTS))

steps_per_epoch = 50000 // BATCH_SIZE

for epoch in range(NUM_EPOCHS):
    for _ in range(steps_per_epoch):
        batch = next(train_ds)
        rng, key = random.split(rng)
        optimizer = train_step(optimizer, batch, key)
```

##### Tensorflow

```python
vae = VAE(latent_dim=LATENTS)
optimizer = tf.keras.optimizers.Adam(1e-4)

for epoch in range(NUM_EPOCHS):
    for train_x in train_ds:
        train_step(vae, train_x, optimizer)
```

##### Pytorch

```python
from torch import optim

# train_step reads `model` and `optimizer` from module scope
model = VAE(LATENTS)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


def train(training_data):
    for epoch in range(NUM_EPOCHS):
        for i, train_x in enumerate(training_data, 0):
            train_step(train_x)


train(train_ds)
```

One thing I haven’t mentioned is data. How do we load and preprocess data in Flax? Well, Flax doesn’t include data manipulation packages yet, besides the basic operations of jax.numpy. Right now, our best option is to borrow packages from other frameworks, such as Tensorflow Datasets (tfds) or Torchvision. To make the article self-contained, I will include the code I used to load a sample training dataset with tfds. Feel free, though, to use your own dataloader if you’re planning to run the implementations presented in this article.

```python
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

# hide the GPU from Tensorflow so that it does not claim memory needed by JAX
tf.config.experimental.set_visible_devices([], 'GPU')


def prepare_image(x):
    x = tf.cast(x['image'], tf.float32)
    x = tf.reshape(x, (-1,))  # flatten 28x28x1 to 784
    return x


ds_builder = tfds.builder('binarized_mnist')
ds_builder.download_and_prepare()
train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)
train_ds = train_ds.map(prepare_image)
train_ds = train_ds.cache()
train_ds = train_ds.repeat()
train_ds = train_ds.shuffle(50000)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = iter(tfds.as_numpy(train_ds))

test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)
test_ds = test_ds.map(prepare_image).batch(10000)
test_ds = np.array(list(test_ds)[0])
```
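If you would rather not depend on tfds, a hand-rolled loader can play the same role. The sketch below is one possible version in plain NumPy; the generator name `batch_iterator` is hypothetical, and random binary pixels stand in for a real dataset. Like the tfds pipeline, it flattens each image to 784 values and yields shuffled batches indefinitely.

```python
import numpy as np

def batch_iterator(images, batch_size, seed=0):
    """Yield shuffled, flattened float32 batches forever."""
    rng = np.random.default_rng(seed)
    flat = images.reshape(len(images), -1).astype(np.float32)
    while True:
        order = rng.permutation(len(flat))  # reshuffle at the start of every pass
        # an incomplete final batch is dropped so every batch has the same shape
        for start in range(0, len(flat) - batch_size + 1, batch_size):
            yield flat[order[start:start + batch_size]]

# stand-in data: 256 random binary 28x28 "images" instead of binarized MNIST
fake_images = np.random.default_rng(1).integers(0, 2, size=(256, 28, 28, 1))
train_ds = batch_iterator(fake_images, batch_size=32)
batch = next(train_ds)
```

Keeping the batch shape fixed avoids recompilation in jit-compiled training steps, which are traced per input shape.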

## Final observations

To close the article, let’s discuss a few final observations that appear after a close analysis of the code:

• All 3 frameworks have reduced the boilerplate code to a minimum, with Flax being the one that requires a bit more, especially in the training part. However, this is only to ensure that we exploit all the available transformations such as automatic differentiation, vectorization and just-in-time compilation.

• The definition of modules, layers and models is almost identical in all of them.

• Flax and JAX are by design quite flexible and expandable.