
# Understanding SWAV: self-supervised learning with contrasting cluster assignments

Self-supervised learning aims to extract representations from unlabeled visual data, and it is super popular in computer vision nowadays. This article covers SWAV, a robust self-supervised learning method, from a mathematical perspective. To that end, we provide insights and intuitions for why this method works. Additionally, we discuss the optimal transport problem with an entropy constraint and its fast approximation, a key point of the SWAV method that stays hidden when you read the paper.

In any case, if you want to learn more about general aspects of self-supervised learning, like augmentation, intuitions, softmax with temperature, and contrastive learning, consult our previous article.

## SWAV Method overview

### Definitions

Let $\mathbf{z}_t$ and $\mathbf{z}_s$ be the features of two different augmentations (views) of the same image $\mathbf{X}$. The views are generated by applying stochastic augmentations $t, s \sim T$ to $\mathbf{X}$.

• Our actual targets: Let $\mathbf{q}_t$ and $\mathbf{q}_s$ be the respective codes of the image views. Codes can be regarded as a soft class of the image.

• Prototypes: consider a set of $K$ prototypes $\{\mathbf{c}_1, \dots, \mathbf{c}_K\}$ lying on the unit sphere. The prototypes are trainable vectors that move towards the dataset's over-represented (frequent) features. If the dataset consists only of cars, the prototypes will mostly capture car parts such as wheels, windows, lights, and mirrors. One way to think about them is as a compact, low-dimensional summary of the whole dataset.

Clusters and prototypes are used interchangeably throughout this article. Don't confuse them with "codes", though! Codes and assignments, in turn, are also used interchangeably.

SWAV [1] compares the features $\mathbf{z}_t$ and $\mathbf{z}_s$ using the intermediate codes (soft classes) $\mathbf{q}_t$ and $\mathbf{q}_s$. For now, ignore how we compute the codes. Think of them as the targets in a standard supervised classification problem.

Intuition: If $\mathbf{z}_t$ and $\mathbf{z}_s$ capture similar information, we can predict the code $\mathbf{q}_s$ (soft class) from the other feature $\mathbf{z}_t$. In other words, if the two views share the same semantics, their targets (codes) will be similar. This is the whole “swapping” idea.

### Difference between SWAV and SimCLR

In contrastive learning methods, the features from different transformations of the same image are compared directly to each other. SWAV does not directly compare image features. Why?

In SwAV, there is an intermediate "codes" step ($\mathbf{Q}$). To create the codes (targets), we need to assign the image features to prototype vectors. We then solve a "swapped" prediction problem wherein the codes (targets) are swapped between the two image views: the code of one view is predicted from the features of the other.

The prototype vectors $\{\mathbf{c}_1, \dots, \mathbf{c}_K\}$ are learned, but they are kept on the unit sphere, meaning their L2 norm is always 1.

### The unit sphere and its implications

By definition, a unit sphere is the set of points with L2 distance equal to 1 from a fixed central point, here the origin. Note that this is different from a unit ball, where the L2 distance is less than or equal to 1 from the centre.

Moving on the surface of the sphere corresponds to a smooth change in assignments. In fact, many self-supervised methods, especially contrastive ones, use this L2-normalization trick. SWAV also applies L2 normalization to the features as well as to the prototypes throughout training.
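To make this concrete, here is a minimal NumPy sketch that projects random toy features and prototypes onto the unit sphere. The helper `l2_normalize` and the sizes are illustrative assumptions. Features are stored one per row, so `scores = Z @ C` is the $B \times K$ transpose of the matrix $\mathbf{C}^T \mathbf{Z}$ used later in this article; because every vector has unit norm, each score is a cosine similarity in $[-1, 1]$.

```python
import numpy as np

def l2_normalize(x, axis):
    """Scale vectors along `axis` to unit L2 norm, i.e. project them onto the unit sphere."""
    return x / np.linalg.norm(x, axis=axis, keepdims=True)

rng = np.random.default_rng(0)
B, D, K = 4, 8, 3  # toy batch size, feature dimension, number of prototypes

Z = l2_normalize(rng.normal(size=(B, D)), axis=1)  # one feature vector per row
C = l2_normalize(rng.normal(size=(D, K)), axis=0)  # one prototype per column

# Every entry is a dot product of two unit vectors, i.e. a cosine similarity in [-1, 1].
scores = Z @ C  # shape (B, K)
```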

### SWAV method Steps

Let’s recap the steps of SWAV:

1. Create $N$ views from the input image $\mathbf{X}$ using a set of stochastic transformations $T$.

2. Calculate the L2-normalized image feature representations $\mathbf{z}$.

3. Calculate softmax-normalized similarities between all features $\mathbf{z}$ and prototypes $\mathbf{c}$: $\text{softmax}(\mathbf{z}^T \mathbf{c} / \tau)$, where $\tau$ is a temperature parameter.

4. Calculate the code matrix $\mathbf{Q}$ iteratively. We intentionally skip this part for now; it is covered further below.

5. Calculate the cross-entropy loss between the prediction from view $t$ (computed from $\mathbf{z}_t$) and the code of view $s$, $\mathbf{q}_s$, and vice versa.

6. Average the loss over all $N$ views.

Again, notice the difference between cluster assignments (codes $\mathbf{q}$) and cluster prototype vectors ($\mathbf{c}$).
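The swapped prediction loss of steps 3 and 5 can be sketched in NumPy as below. This is a minimal illustration, not the official implementation: the function names are hypothetical, the codes `q_t` and `q_s` are assumed to be given (e.g. by the Sinkhorn-Knopp step discussed below) and treated as fixed targets, and `tau=0.1` is an assumed temperature. No gradients are computed here; in training they would flow into the features and prototypes, never into the codes.

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Subtract the row maximum before exponentiating, for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def swapped_loss(z_t, z_s, C, q_t, q_s, tau=0.1):
    """Swapped prediction loss between two views of the same images.

    z_t, z_s: (B, D) L2-normalized features of views t and s
    C:        (D, K) L2-normalized prototypes, one per column
    q_t, q_s: (B, K) codes of views t and s (rows sum to 1), used as fixed targets
    tau:      softmax temperature (assumed value)
    """
    log_p_t = log_softmax(z_t @ C / tau, axis=1)  # predicted assignments from view t
    log_p_s = log_softmax(z_s @ C / tau, axis=1)  # predicted assignments from view s
    loss_t = -(q_s * log_p_t).sum(axis=1)  # predict the code of s from the features of t
    loss_s = -(q_t * log_p_s).sum(axis=1)  # predict the code of t from the features of s
    return float(np.mean(loss_t + loss_s))  # average over the batch
```

As a sanity check, if every feature is orthogonal to all prototypes, the predictions are uniform and each of the two cross-entropy terms equals $\log K$, whatever the codes are.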

## Digging into SWAV’s math: approximating Q

### Understanding the Optimal Transport problem with Entropy

As discussed, the code vectors $\mathbf{q}_1,...,\mathbf{q}_B$ act as targets in the cross-entropy loss term. In SWAV, these code vectors are computed online during every iteration. Online means that we approximate $\mathbf{Q}$ in each forward pass with an iterative process. No gradients are computed or backpropagated to estimate $\mathbf{Q}$.

For $K$ prototypes and batch size $B$, the optimal code matrix $\mathbf{Q} = [\mathbf{q}_1,...,\mathbf{q}_B] \in \mathbb{R}^{K \times B}$ is defined as the solution to an optimal transport problem with an entropic constraint. The solution is approximated using the iterative Sinkhorn-Knopp algorithm [3].

For a formal and very detailed formulation, analysis, and solution of this problem, I recommend having a look at [3].

For SwAV, we define the optimal code matrix $\mathbf{Q}^*$ as:

$\mathbf{Q}^* = \arg\max_{\mathbf{Q} \in \mathcal{Q}} \text{Tr} (\mathbf{Q}^T \mathbf{C}^T \mathbf{Z}) + \varepsilon H(\mathbf{Q}),$
$\mathcal{Q} = \big\{ \mathbf{Q} \in \mathbb{R}^{K \times B}_+ \mid \mathbf{Q} \mathbf{1}_B = \frac{1}{K} \mathbf{1}_K, \mathbf{Q}^T \mathbf{1}_K = \frac{1}{B} \mathbf{1}_B \big\},$

with $H$ being the entropy $H(\mathbf{Q}) = - \sum_{ij} \mathbf{Q}_{ij} \log \mathbf{Q}_{ij}$, $\varepsilon$ being a hyperparameter of the method, and $\mathbf{1}_K$ denoting the all-ones vector of size $K$.

The trace $\text{Tr}$ is defined to be the sum of the elements on the main diagonal.

A matrix $\mathbf{Q}$ from the set $\mathcal{Q}$ is constrained in three ways:

1. All its entries have to be non-negative.

2. The sum of each row has to be $1/K$.

3. The sum of each column has to be $1/B$.

Note that this also implies that all entries sum to $1$ ($K$ rows, each summing to $1/K$), hence these matrices allow for a probabilistic interpretation, for example, w.r.t. entropy. However, $\mathbf{Q}$ is not a stochastic matrix, since its rows sum to $1/K$ rather than $1$.

A simple matrix in this set is the one whose entries are all $1/(BK)$, which corresponds to a uniform distribution over all entries. This matrix maximizes the entropy $H(\mathbf{Q})$, reaching $\log(BK)$.
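A small NumPy sketch can check these constraints numerically. The helper names `in_transport_polytope` and `entropy` are illustrative, not taken from the paper. For the uniform matrix, the entropy equals $\log(BK)$, the largest value any matrix in $\mathcal{Q}$ can reach.

```python
import numpy as np

def in_transport_polytope(Q, atol=1e-9):
    """Check the three constraints defining the set of admissible code matrices."""
    K, B = Q.shape
    nonneg = np.all(Q >= 0)                                   # 1. non-negative entries
    rows_ok = np.allclose(Q.sum(axis=1), 1.0 / K, atol=atol)  # 2. every row sums to 1/K
    cols_ok = np.allclose(Q.sum(axis=0), 1.0 / B, atol=atol)  # 3. every column sums to 1/B
    return bool(nonneg and rows_ok and cols_ok)

def entropy(Q):
    """H(Q) = -sum_ij Q_ij log Q_ij, with the usual convention 0 log 0 = 0."""
    p = Q[Q > 0]
    return float(-(p * np.log(p)).sum())

K, B = 3, 4
uniform = np.full((K, B), 1.0 / (B * K))  # every entry equals 1/(BK)
print(in_transport_polytope(uniform), entropy(uniform), np.log(B * K))
```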

With a good intuition on the set $\mathcal{Q}$, we can examine the target function.

### Optimal transport without entropy

Ignoring the entropy term for now, we can go step by step through the first term:

$\mathbf{Q}^* = \arg\max_{\mathbf{Q} \in \mathcal{Q}} \text{Tr} (\mathbf{Q}^T \mathbf{C}^T \mathbf{Z})$

Since both $\mathbf{C}$ and $\mathbf{Z}$ are L2 normalized, the matrix product $\mathbf{C}^T \mathbf{Z}$ computes the cosine similarity scores between all possible combinations of feature vectors $\mathbf{z}_1,...,\mathbf{z}_B$ and prototypes $\mathbf{c}_1,...,\mathbf{c}_K$.

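For example, with $K=2$ prototypes and $B=3$ feature vectors:

$\mathbf{C}^T \mathbf{Z} = \begin{bmatrix} \mathbf{c}_1^T \\ \mathbf{c}_2^T \end{bmatrix} \begin{bmatrix} \mathbf{z}_1 & \mathbf{z}_2 & \mathbf{z}_3 \end{bmatrix} = \begin{bmatrix} \mathbf{c}_1^T \mathbf{z}_1 & \mathbf{c}_1^T \mathbf{z}_2 & \mathbf{c}_1^T \mathbf{z}_3 \\ \mathbf{c}_2^T \mathbf{z}_1 & \mathbf{c}_2^T \mathbf{z}_2 & \mathbf{c}_2^T \mathbf{z}_3 \end{bmatrix}$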

The first column of $\mathbf{C}^T \mathbf{Z}$ contains the similarity scores for the first feature vector $\mathbf{z}_1$ and all prototypes.

This means that the first diagonal entry of $\mathbf{Q}^T \mathbf{C}^T \mathbf{Z}$ is a weighted sum of the similarity scores of $\mathbf{z}_1$, with weights taken from the first column of $\mathbf{Q}$. For 2 prototypes and batch size 3, the first diagonal element is:

$q_{11} \mathbf{c}_1^T \mathbf{z}_1 + q_{21} \mathbf{c}_2^T \mathbf{z}_1$

The corresponding contribution of the first column of $\mathbf{Q}$ to the entropy term is:

$- \varepsilon [ q_{11} \log q_{11} + q_{21} \log q_{21} ]$

Similarly, the second diagonal entry of $\mathbf{Q}^T \mathbf{C}^T \mathbf{Z}$ is a weighted sum of the similarity scores of $\mathbf{z}_2$, with weights taken from the second column of $\mathbf{Q}$.

Doing this for all diagonal entries and taking the sum results in $\text{Tr}(\mathbf{Q}^T \mathbf{C}^T \mathbf{Z})$.
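In index form, the trace is simply the sum of the element-wise product of $\mathbf{Q}$ and the score matrix (the standard identity $\text{Tr}(\mathbf{A}^T\mathbf{B}) = \sum_{kj} A_{kj} B_{kj}$):

$\text{Tr}(\mathbf{Q}^T \mathbf{C}^T \mathbf{Z}) = \sum_{j=1}^{B} (\mathbf{Q}^T \mathbf{C}^T \mathbf{Z})_{jj} = \sum_{j=1}^{B} \sum_{k=1}^{K} q_{kj} \, \mathbf{c}_k^T \mathbf{z}_j$

For $K=2$ and $B=3$, the $j=1$ term is exactly the first diagonal element $q_{11} \mathbf{c}_1^T \mathbf{z}_1 + q_{21} \mathbf{c}_2^T \mathbf{z}_1$ discussed above.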

Intuition: While the optimal matrix $\mathbf{Q}^*$ is highly non-trivial, it is easy to see that $\mathbf{Q}^*$ will assign large weights to larger similarity scores and small weights to smaller ones, while conforming to the row-sum and column-sum constraints.

Based on this design, the solution is biased towards hard assignments, where each feature vector puts (almost) all of its weight on a single prototype, rather than towards a smooth distribution over prototypes.

Solution? The entropy term to the rescue!

### The entropy constraint

So why do we need the entropy term at all?

Well, while the resulting code vectors $\mathbf{q}_1,...,\mathbf{q}_B$ are already a 'soft' target compared to one-hot vectors (in SimCLR), the addition of the entropy term in the target function gives us control over the smoothness of the solution.

For $\varepsilon \rightarrow \infty$ the solution tends towards the trivial solution where all entries of $\mathbf{Q}$ are $1/(BK)$. Basically, all feature vectors are assigned uniformly to all prototypes.

When $\varepsilon = 0$ we have no smoothness term to further regularize $\mathbf{Q}$.

Finally, small values for $\varepsilon$ result in a slightly smoothed $\mathbf{Q}^*$.
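The trade-off can be seen on a toy example. The NumPy sketch below does not solve the optimization problem; it only compares the objective $\text{Tr}(\mathbf{Q}^T \mathbf{C}^T \mathbf{Z}) + \varepsilon H(\mathbf{Q})$ for two hand-picked matrices in $\mathcal{Q}$ (a hard assignment and the uniform matrix), assuming $K = B = 2$ and made-up scores. The function name `objective` and all numbers are illustrative.

```python
import numpy as np

def objective(Q, scores, eps):
    """Tr(Q^T C^T Z) + eps * H(Q), where scores = C^T Z has shape (K, B)."""
    trace_term = float((Q * scores).sum())  # Tr(Q^T S) equals the sum of the element-wise product
    p = Q[Q > 0]                            # convention: 0 log 0 = 0
    return trace_term + eps * float(-(p * np.log(p)).sum())

# Made-up scores for K = 2 prototypes and B = 2 features:
# feature 1 is close to prototype 1, feature 2 is close to prototype 2.
scores = np.array([[1.0, -1.0],
                   [-1.0, 1.0]])
hard = np.array([[0.5, 0.0],
                 [0.0, 0.5]])    # each feature fully assigned to its closest prototype
uniform = np.full((2, 2), 0.25)  # each feature spread evenly over both prototypes

for eps in (0.05, 10.0):
    print(f"eps={eps}: hard={objective(hard, scores, eps):.3f}, "
          f"uniform={objective(uniform, scores, eps):.3f}")
```

Working it out by hand, the hard matrix scores $1 + \varepsilon \log 2$ and the uniform one scores $2\varepsilon \log 2$, so the uniform matrix wins as soon as $\varepsilon > 1/\log 2 \approx 1.44$. Small values of $\varepsilon$ favour confident assignments, while large values push towards uniformity.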

Revisiting the constraints for $\mathcal{Q}$, the row-sum and column-sum constraints imply that an equal amount of total weight is assigned to each prototype and to each feature vector, respectively.

These constraints act as a strong regularizer that prevents mode collapse, i.e. the trivial solution where all feature vectors are assigned to the same prototype all the time.

### Online estimation of Q* for SWAV

What is left now is to compute $\mathbf{Q}^*$ in every iteration of the training process, which luckily turns out to be very efficient using the results of [3].

Using Lemma 2 from page 5 of [3], we know that the solution takes the form:

$\mathbf{Q}^* = \text{Diag}(\mathbf{u}) \exp \bigg(\frac{\mathbf{C}^T\mathbf{Z}}{\varepsilon} \bigg) \text{Diag}(\mathbf{v}),$

where $\mathbf{u} \in \mathbb{R}^K$ and $\mathbf{v} \in \mathbb{R}^B$ act as row and column normalization vectors, respectively, and $\exp$ is applied element-wise. An exact computation here is inefficient. However, the **Sinkhorn-Knopp** algorithm provides a fast, iterative alternative. We can initialize a matrix $\mathbf{Q}$ as the exponential term from $\mathbf{Q}^*$ and then alternate between normalizing the rows and columns of this matrix.
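For intuition, here is a minimal NumPy sketch of the classical scaling-vector form of Sinkhorn-Knopp, which keeps the solution in the $\text{Diag}(\mathbf{u}) \exp(\mathbf{C}^T\mathbf{Z}/\varepsilon) \text{Diag}(\mathbf{v})$ form at every step. It is an illustration, not the authors' code (their version follows in the next section); the function name `sinkhorn_uv` and the default of 50 iterations are assumptions. Here `scores` is the $K \times B$ matrix $\mathbf{C}^T \mathbf{Z}$.

```python
import numpy as np

def sinkhorn_uv(scores, eps=0.05, n_iters=50):
    """Scaling-vector Sinkhorn-Knopp for the entropic transport problem.

    scores: (K, B) similarity matrix C^T Z.
    Returns Q = Diag(u) M Diag(v), with M proportional to exp(scores / eps), plus u and v.
    """
    K, B = scores.shape
    # Shifting by the max only rescales M by a constant, which u absorbs; it avoids overflow.
    M = np.exp((scores - scores.max()) / eps)
    r = np.full(K, 1.0 / K)  # target row sums
    c = np.full(B, 1.0 / B)  # target column sums
    u, v = np.ones(K), np.ones(B)
    for _ in range(n_iters):
        u = r / (M @ v)    # rescale rows so that they sum to 1/K
        v = c / (M.T @ u)  # rescale columns so that they sum to 1/B
    Q = u[:, None] * M * v[None, :]
    return Q, u, v
```

After the last `v` update the column sums are exact, and the row sums converge to $1/K$ as the iterations proceed.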

#### Sinkhorn-Knopp Code analysis

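Here is the pseudocode given by the authors for approximating $\mathbf{Q}$ from the similarity scores, written out as runnable PyTorch:

```python
import torch

# Sinkhorn-Knopp
@torch.no_grad()  # Q is a target: no gradients flow through it
def sinkhorn(scores, eps=0.05, niters=3):
    # scores: (B, K) similarities between the B features and the K prototypes
    Q = torch.exp(scores / eps).T  # (K, B), the exponential term of Q*
    Q /= Q.sum()  # make all entries sum to 1
    K, B = Q.shape
    r = torch.ones(K, device=Q.device) / K  # target row sums
    c = torch.ones(B, device=Q.device) / B  # target column sums
    for _ in range(niters):
        u = Q.sum(dim=1)
        Q *= (r / u).unsqueeze(1)  # row norm: rows sum to 1/K
        Q *= (c / Q.sum(dim=0)).unsqueeze(0)  # column norm: columns sum to 1/B
    # Columns sum to 1, so each code is a distribution over prototypes; back to (B, K)
    return (Q / Q.sum(dim=0, keepdim=True)).T
```

To approximate $\mathbf{Q}$, the function takes as input only the similarity scores $\mathbf{C}^T \mathbf{Z}$ (stored as a $B \times K$ matrix, hence the transposes) and outputs our estimate of $\mathbf{Q}$. The final normalization makes every code $\mathbf{q}_b$ a probability distribution over the $K$ prototypes.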

## Intuition on the clusters/prototypes

So what is actually learned in these clusters/prototypes?

Well, the prototypes' main purpose is to summarize the dataset. Nevertheless, SWAV is still a form of contrastive learning: it can be interpreted as a way of contrasting image views by comparing their cluster assignments instead of their features.

Ultimately, we contrast against the clusters and not against the whole dataset. SimCLR relies on batch information, namely the other samples in the batch used as negatives, which are not always representative of the whole dataset. Contrasting against a small set of prototypes makes the SWAV objective more tractable.

This can be observed in the experiments. Compared to SimCLR, SWAV pretraining converges faster and is less sensitive to the batch size. Moreover, SWAV is not very sensitive to the number of clusters. Typically, 3,000 clusters are used for ImageNet. In general, it is recommended to use approximately one order of magnitude more clusters than the number of real classes. For STL10, which has 10 classes, 512 clusters would be enough.

## The multi-crop idea: augmenting views with smaller images

Every time I read about contrastive self-supervised learning methods, I think: why just 2 views? Well, the SWAV paper answers this obvious question too.

To this end, SwAV proposes a multi-crop augmentation strategy where the same image is randomly cropped into 2 global (e.g. 224x224) views and $N=4$ local (e.g. 96x96) views.

As shown below, multi-crop is a very general trick to improve self-supervised representations. It can be used out of the box for any method, with surprisingly good results: a ~2% improvement on SimCLR!

The authors also observed that mapping small parts of a scene to more global views significantly boosts the performance.
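A simplified multi-crop pipeline can be sketched in NumPy as below. This is an assumption-laden illustration rather than SwAV's augmentation code: the helper `random_square_crop` is hypothetical, crops are square, resizing uses nearest-neighbour sampling, no other augmentations (flips, color jitter, blur) are applied, and the crop-area ranges `global_scale` and `local_scale` are assumed values. Only the view counts and sizes (2 views of 224x224 and 4 views of 96x96) come from the text above.

```python
import numpy as np

def random_square_crop(img, out_size, scale, rng):
    """Hypothetical helper: crop a random square covering a fraction `scale` of the
    image area, then resize it to out_size x out_size by nearest-neighbour sampling."""
    H, W = img.shape[:2]
    area = rng.uniform(*scale) * H * W
    side = int(np.clip(np.sqrt(area), 1, min(H, W)))
    top = rng.integers(0, H - side + 1)
    left = rng.integers(0, W - side + 1)
    crop = img[top:top + side, left:left + side]
    idx = (np.arange(out_size) * side / out_size).astype(int)  # source pixel per output pixel
    return crop[idx][:, idx]

def multi_crop(img, rng, n_global=2, n_local=4, global_size=224, local_size=96,
               global_scale=(0.14, 1.0), local_scale=(0.05, 0.14)):
    """Return n_global large views and n_local small views of the same image.
    The scale ranges are assumed values for illustration."""
    views = [random_square_crop(img, global_size, global_scale, rng) for _ in range(n_global)]
    views += [random_square_crop(img, local_size, local_scale, rng) for _ in range(n_local)]
    return views

rng = np.random.default_rng(0)
image = rng.random((256, 320, 3))  # stand-in for a real RGB image
views = multi_crop(image, rng)
print([v.shape for v in views])
```

Because the small views cost far fewer pixels than the global ones, adding them increases the number of views without a proportional increase in compute.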

## Results

To evaluate the representation learned by the backbone $f$ (a ResNet-50), the backbone is frozen and a single linear layer is trained on top. This protocol, called linear evaluation, gives a fair comparison of the learned representations. Below are the results of SWAV compared to other state-of-the-art methods.

(left) Comparison between clustering-based and contrastive instance methods and impact of multi-crop. (right) Performance as a function of epochs. Source: SWAV paper, Caron et al. 2020

Left: Classification accuracy on ImageNet. The linear layers are trained on frozen features from different self-supervised methods with a standard ResNet-50. Right: Performance of ResNet-50 widened by factors of 2, 4, and 5.
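The linear evaluation protocol can be illustrated with a minimal NumPy sketch: the backbone's features are computed once and kept fixed, and only a linear softmax classifier is trained on top with plain gradient descent. The function `linear_eval`, its learning rate, the number of steps, and the toy features are illustrative assumptions, not the training recipe used in the paper.

```python
import numpy as np

def linear_eval(train_feats, train_labels, test_feats, test_labels, n_classes,
                lr=0.5, n_steps=200):
    """Train a linear softmax classifier on frozen features; return test accuracy.

    The backbone does not appear here: its features were computed once and are never updated.
    """
    n, d = train_feats.shape
    W = np.zeros((d, n_classes))
    b = np.zeros(n_classes)
    Y = np.eye(n_classes)[train_labels]  # one-hot targets
    for _ in range(n_steps):
        logits = train_feats @ W + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        P = np.exp(logits)
        P /= P.sum(axis=1, keepdims=True)            # softmax probabilities
        grad = (P - Y) / n                           # d(mean cross-entropy)/d(logits)
        W -= lr * train_feats.T @ grad               # update only the linear layer
        b -= lr * grad.sum(axis=0)
    preds = np.argmax(test_feats @ W + b, axis=1)
    return float(np.mean(preds == test_labels))

# Toy "frozen features" for two linearly separable classes.
feats = np.array([[1.0, 0.0], [2.0, 0.0], [-1.0, 0.0], [-2.0, 0.0]])
labels = np.array([0, 0, 1, 1])
print(linear_eval(feats, labels, feats, labels, n_classes=2))
```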

## Conclusion

In this post, we provided an overview of SWAV and its hidden math. We covered the details of optimal transport with and without the entropy constraint. This post would not have been possible without the detailed mathematical analysis of Tim.

You can also check out this interview on SWAV with its first author (Mathilde Caron).

For further reading, take a look at self-supervised representation learning on videos or SWAV’s experimental report. You can even run your own experiments with the official code if you have a multi-GPU machine!

Finally, I have to say that I'm a bit biased towards FAIR's work on visual self-supervised learning. This team really rocks!

1. SWAV Paper

2. SWAV Code

3. Ref paper on optimal transport

4. SWAV’s Report in WANDB

5. Optimal Transport problem

## Cite as:

@article{kaiser2021swav,    title   = "Understanding SWAV: self-supervised learning with contrasting cluster assignments",    author  = "Kaiser, Tim and Adaloglou, Nikolaos",    journal = "https://theaisummer.com/",    year    = "2021",    howpublished = {https://theaisummer.com/swav/},  }
