Understanding SWAV: self-supervised learning with contrasting cluster assignments

Tim Kaiser, Nikolas Adaloglou on 2021-10-07 · 7 mins
Unsupervised Learning

Self-supervised learning aims to extract representations from unlabeled visual data, and it is hugely popular in computer vision nowadays. This article covers SWAV, a robust self-supervised learning method, from a mathematical perspective. To that end, we provide insights and intuitions for why this method works. Additionally, we discuss the optimal transport problem with an entropy constraint and its fast approximation, a key part of SWAV that is easy to miss when you read the paper.

In any case, if you want to learn more about general aspects of self-supervised learning, like augmentation, intuitions, softmax with temperature, and contrastive learning, consult our previous article.

SWAV Method overview

Definitions

Let two image features $\mathbf{z}_t$ and $\mathbf{z}_s$ be computed from two different augmentations of the same image. The augmented views are generated by applying stochastic transformations $t \sim T$ to the same image $\mathbf{X}$.

image-augmentations-creation Source: BYOL

  • Our actual targets: Let $\mathbf{q}_t$ and $\mathbf{q}_s$ be the respective codes of the image views. Codes can be regarded as a soft class of the image.

  • Prototypes: consider a set of $K$ prototypes $\{\mathbf{c}_1, \dots, \mathbf{c}_K\}$ lying on the unit sphere. The prototypes are trainable vectors that move towards the dataset’s over-represented (frequent) features. If the dataset consists only of cars, the prototypes will capture frequent car parts such as wheels, windows, lights, and mirrors. One way to think about them is as a low-dimensional projection of the whole dataset.

swav-overview-definitions Source: SWAV paper, Caron et al 2020

Clusters and prototypes are used interchangeably throughout this article. Don’t confuse them with “codes” though! Codes and assignments are also used interchangeably.

SWAV [1] compares the features $\mathbf{z}_t$ and $\mathbf{z}_s$ using the intermediate codes (soft classes) $\mathbf{q}_t$ and $\mathbf{q}_s$. For now, ignore how we compute the codes. Treat them as the targets in a standard supervised classification problem.

Intuition: If $\mathbf{z}_t$ and $\mathbf{z}_s$ capture similar information, we can predict the code $\mathbf{q}_s$ (soft class) from the other feature $\mathbf{z}_t$. In other words, if the two views share the same semantics, their targets (codes) will be similar. This is the whole “swapping” idea.

Difference between SWAV and SimCLR

In contrastive learning methods, the features from different transformations of the same images are compared directly to each other. SWAV does not directly compare image features. Why?

In SwAV, there is the intermediate “codes” step ($\mathbf{Q}$). To create the codes (targets), we need to assign the image features to prototype vectors. We then solve a “swapped” prediction problem wherein the codes (targets) are swapped between the two image views.

swav-vs-simclr Source: SWAV paper, Caron et al 2020

Prototype vectors $\{\mathbf{c}_1, \dots, \mathbf{c}_K\}$ are learned, but they are constrained to lie on the unit sphere, meaning their L2 norm is always 1.

The unit sphere and its implications

By definition, a unit sphere is the set of points with L2 distance equal to 1 from a fixed central point, here the origin. Note that this is different from a unit ball, where the L2 distance is less than or equal to 1 from the centre.

Moving on the surface of the sphere corresponds to a smooth change in assignments. In fact, many self-supervised methods, and contrastive methods in particular, use this L2-norm trick. SWAV also applies L2-normalization to the features as well as to the prototypes throughout training.
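A minimal NumPy sketch of this L2-norm trick follows. The dimensions `D`, `B`, `K` and the random stand-ins for features and prototypes are illustrative assumptions, not values from the paper. Once both sets of vectors sit on the unit sphere, a plain dot product is the cosine similarity and stays within [-1, 1].

```python
import numpy as np

rng = np.random.default_rng(0)
D, B, K = 8, 4, 3  # feature dim, batch size, number of prototypes (illustrative)

def l2_normalize(x, axis=0, eps=1e-12):
    """Project vectors onto the unit sphere along the given axis."""
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)

Z = l2_normalize(rng.normal(size=(D, B)))  # columns are features z_1..z_B
C = l2_normalize(rng.normal(size=(D, K)))  # columns are prototypes c_1..c_K

# (K, B): cosine similarity of every prototype/feature pair
scores = C.T @ Z
```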

SWAV method Steps

Let’s recap the steps of SWAV:

  1. Create $N$ views from input image $\mathbf{X}$ using a set of stochastic transformations $T$

  2. Calculate image feature representations $\mathbf{z}$

  3. Calculate softmax-normalized similarities between all $\mathbf{z}$ and $\mathbf{c}$: $\text{softmax}(\mathbf{z}^T \mathbf{c})$

  4. Calculate code matrix $\mathbf{Q}$ iteratively. We intentionally ignored this part. See further on for this step.

  5. Calculate cross-entropy loss between the softmax-normalized similarities of view $t$, computed from $\mathbf{z}_t$, and the code of view $s$, aka $\mathbf{q}_s$

  6. Average loss between all $N$ views.
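Steps 3 and 5 can be sketched in NumPy for two views. This is a sketch under assumptions rather than the authors' implementation: the codes are passed in as fixed targets (here uniform placeholders), the `temperature` argument follows the softmax-with-temperature idea (0.1 is the paper's default, to my knowledge), and all names and sizes are made up for illustration.

```python
import numpy as np

def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # shift for numerical stability
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def swapped_loss(z_t, z_s, q_t, q_s, C, temperature=0.1):
    """Cross-entropy between the prediction of one view and the code of the other.

    z_t, z_s: (B, D) L2-normalized features of the two views
    q_t, q_s: (B, K) codes, treated as fixed targets (each row sums to 1)
    C:        (D, K) L2-normalized prototypes
    """
    log_p_t = log_softmax(z_t @ C / temperature)  # (B, K)
    log_p_s = log_softmax(z_s @ C / temperature)  # (B, K)
    loss_t = -(q_s * log_p_t).sum(axis=1).mean()  # predict q_s from z_t
    loss_s = -(q_t * log_p_s).sum(axis=1).mean()  # predict q_t from z_s
    return loss_t + loss_s

# Toy usage with random stand-ins for features and prototypes
rng = np.random.default_rng(0)
B, D, K = 4, 8, 3
z_t = rng.normal(size=(B, D))
z_t /= np.linalg.norm(z_t, axis=1, keepdims=True)
z_s = rng.normal(size=(B, D))
z_s /= np.linalg.norm(z_s, axis=1, keepdims=True)
C = rng.normal(size=(D, K))
C /= np.linalg.norm(C, axis=0, keepdims=True)
q_t = np.full((B, K), 1.0 / K)  # placeholder codes: uniform soft classes
q_s = np.full((B, K), 1.0 / K)
loss = swapped_loss(z_t, z_s, q_t, q_s, C)
```

The loss is symmetric in the two views, which is exactly the swap: each view has to predict the other view's code.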

method-swav-overview-unit-sphere Source: SWAV paper, Caron et al 2020

Again, notice the difference between cluster assignments (codes) and cluster prototype vectors ($\mathbf{c}$). Here is a detailed explanation of the loss function:
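For reference, the swapped prediction loss for two views, as I read it in the SWAV paper [1], is written below. The temperature $\tau$ is a hyperparameter of the softmax that this post does not discuss further.

$$L(\mathbf{z}_t, \mathbf{z}_s) = \ell(\mathbf{z}_t, \mathbf{q}_s) + \ell(\mathbf{z}_s, \mathbf{q}_t), \qquad \ell(\mathbf{z}_t, \mathbf{q}_s) = - \sum_{k=1}^{K} \mathbf{q}_s^{(k)} \log \mathbf{p}_t^{(k)}, \qquad \mathbf{p}_t^{(k)} = \frac{\exp \big( \frac{1}{\tau} \mathbf{z}_t^T \mathbf{c}_k \big)}{\sum_{k'} \exp \big( \frac{1}{\tau} \mathbf{z}_t^T \mathbf{c}_{k'} \big)}$$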

Digging into SWAV’s math: approximating Q

Understanding the Optimal Transport problem with Entropy

As discussed, the code vectors $\mathbf{q}_1,\dots,\mathbf{q}_B$ act as targets in the cross-entropy loss term. In SWAV, these code vectors are computed online during every iteration. Online means that we approximate $\mathbf{Q}$ in each forward pass by an iterative process. No gradients or backpropagation are needed to estimate $\mathbf{Q}$.

For $K$ prototypes and batch size $B$, the optimal code matrix $\mathbf{Q} = [\mathbf{q}_1,\dots,\mathbf{q}_B] \in \mathbb{R}^{K \times B}$ is defined as the solution to an optimal transport problem with entropic constraint. The solution is approximated using the iterative Sinkhorn-Knopp algorithm [3].

For a formal and very detailed formulation, analysis and solution of said problem, I recommend having a look at the paper [3].

For SwAV, we define the optimal code matrix $\mathbf{Q}^*$ as:

$$\mathbf{Q}^* = \max_{\mathbf{Q} \in \mathcal{Q}} \text{Tr} (\mathbf{Q}^T \mathbf{C}^T \mathbf{Z}) + \varepsilon H(\mathbf{Q}),$$

$$\mathcal{Q} = \big\{ \mathbf{Q} \in \mathbb{R}^{K \times B}_+ \; | \; \mathbf{Q} \mathbf{1}_B = \frac{1}{K} \mathbf{1}_K, \; \mathbf{Q}^T \mathbf{1}_K = \frac{1}{B} \mathbf{1}_B \big\},$$

with $H$ being the entropy $H(\mathbf{Q}) = - \sum_{ij} \mathbf{Q}_{ij} \log \mathbf{Q}_{ij}$ and $\varepsilon$ being a hyperparameter of the method.

The trace $\text{Tr}$ is defined to be the sum of the elements on the main diagonal.

A matrix $\mathbf{Q}$ from the set $\mathcal{Q}$ is constrained in three ways:

  1. All its entries have to be non-negative.

  2. The sum of each row has to be $1/K$.

  3. The sum of each column has to be $1/B$.

Note that this also implies that the sum of all entries is $1$, hence these matrices allow for a probabilistic interpretation, for example, w.r.t. entropy. However, such a matrix is not a stochastic matrix.

A simple matrix in this set is a matrix whose entries are all $1/(BK)$, which corresponds to a uniform distribution over all entries. This matrix maximizes the entropy $H(\mathbf{Q})$.
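The constraints are easy to check numerically. The sketch below uses made-up sizes ($K=2$, $B=4$) and hypothetical helper names. It verifies that the uniform matrix lies in the set and reaches the entropy $\log(BK)$, while a "hard" matrix that also satisfies the constraints has lower entropy.

```python
import numpy as np

K, B = 2, 4  # illustrative number of prototypes and batch size

def entropy(Q):
    """H(Q) = -sum_ij Q_ij log Q_ij, with the convention 0 * log 0 = 0."""
    nz = Q[Q > 0]
    return float(-(nz * np.log(nz)).sum())

def in_transport_set(Q, atol=1e-9):
    """Check the three constraints: non-negative entries, row sums 1/K, column sums 1/B."""
    n_rows, n_cols = Q.shape
    return bool(
        np.all(Q >= 0)
        and np.allclose(Q.sum(axis=1), 1.0 / n_rows, atol=atol)
        and np.allclose(Q.sum(axis=0), 1.0 / n_cols, atol=atol)
    )

Q_uniform = np.full((K, B), 1.0 / (K * B))

# Each sample is assigned entirely to one prototype, two samples per prototype
Q_hard = np.array([[0.25, 0.25, 0.0, 0.0],
                   [0.0, 0.0, 0.25, 0.25]])
```

Both matrices sum to 1 overall, but only the uniform one spreads every sample evenly over all prototypes.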

With a good intuition on the set $\mathcal{Q}$, we can examine the target function.

Optimal transport without entropy

Ignoring the entropy term for now, we can go step-by-step through the first term:

$$\mathbf{Q}^* = \max_{\mathbf{Q} \in \mathcal{Q}} \text{Tr} (\mathbf{Q}^T \mathbf{C}^T \mathbf{Z})$$

Since both $\mathbf{C}$ and $\mathbf{Z}$ are L2 normalized, the matrix product $\mathbf{C}^T \mathbf{Z}$ computes the cosine similarity scores between all possible combinations of feature vectors $\mathbf{z}_1,\dots,\mathbf{z}_B$ and prototypes $\mathbf{c}_1,\dots,\mathbf{c}_K$. For example, with 2 prototypes and batch size 3:

$$\mathbf{C}^T \mathbf{Z} = \begin{bmatrix} \mathbf{c}_1^T \\ \mathbf{c}_2^T \end{bmatrix} \begin{bmatrix} \mathbf{z}_1 & \mathbf{z}_2 & \mathbf{z}_3 \end{bmatrix} = \begin{bmatrix} \mathbf{c}_1^T \mathbf{z}_1 & \mathbf{c}_1^T \mathbf{z}_2 & \mathbf{c}_1^T \mathbf{z}_3 \\ \mathbf{c}_2^T \mathbf{z}_1 & \mathbf{c}_2^T \mathbf{z}_2 & \mathbf{c}_2^T \mathbf{z}_3 \end{bmatrix}$$

The first column of $\mathbf{C}^T \mathbf{Z}$ contains the similarity scores for the first feature vector $\mathbf{z}_1$ and all prototypes.

This means that the first diagonal entry of $\mathbf{Q}^T \mathbf{C}^T \mathbf{Z}$ is a weighted sum of the similarity scores of $\mathbf{z}_1$. For 2 prototypes and batch size 3 the first diagonal element will be:

$$q_{11} \, \mathbf{c}_1^T \mathbf{z}_1 + q_{21} \, \mathbf{c}_2^T \mathbf{z}_1$$

While its entropy term will be:

$$- \varepsilon [ q_{11} \log q_{11} + q_{21} \log q_{21} ]$$

Similarly, the second diagonal entry of $\mathbf{Q}^T \mathbf{C}^T \mathbf{Z}$ is a weighted sum of the similarity scores for $\mathbf{z}_2$ with different weights.

Doing this for all diagonal entries and taking the sum results in $\text{Tr}(\mathbf{Q}^T \mathbf{C}^T \mathbf{Z})$.

Intuition: While the optimal matrix $\mathbf{Q}^*$ is highly non-trivial, it is easy to see that $\mathbf{Q}^*$ will assign large weights to larger similarity scores and small weights to smaller ones while conforming to the row-sum and column-sum constraint.
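The weighted-sum reading of the trace can be checked numerically: $\text{Tr}(\mathbf{Q}^T \mathbf{C}^T \mathbf{Z})$ equals the sum over all entries of $\mathbf{Q}$ multiplied elementwise with the similarity scores. The sketch below uses random stand-ins and, as an arbitrary choice, the uniform $\mathbf{Q}$. Any matrix of the same shape would satisfy the identity.

```python
import numpy as np

rng = np.random.default_rng(0)
D, K, B = 5, 2, 3  # 2 prototypes and batch size 3; D is an arbitrary feature dim

def unit_columns(x):
    return x / np.linalg.norm(x, axis=0, keepdims=True)

C = unit_columns(rng.normal(size=(D, K)))  # prototypes as columns
Z = unit_columns(rng.normal(size=(D, B)))  # features as columns
S = C.T @ Z                                # (K, B) similarity scores

Q = np.full((K, B), 1.0 / (K * B))         # a feasible Q (uniform), chosen for simplicity

trace_term = np.trace(Q.T @ S)
weighted_sum = (Q * S).sum()               # sum_ij Q_ij * S_ij

# First diagonal entry: q11 * (c1 . z1) + q21 * (c2 . z1)
first_diag = Q[0, 0] * S[0, 0] + Q[1, 0] * S[1, 0]
```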

Based on this design, such a method is more prone to mode collapse, where one prototype is chosen all the time, than to collapsing to a uniform distribution.

Solution? Enforcing entropy to the rescue!

The entropy constraint

So why do we need the entropy term at all?

Well, while the resulting code vectors $\mathbf{q}_1,\dots,\mathbf{q}_B$ are already a 'soft' target compared to one-hot vectors (in SimCLR), the addition of the entropy term in the target function gives us control over the smoothness of the solution.

For $\varepsilon \rightarrow \infty$ the solution tends towards the trivial solution where all entries of $\mathbf{Q}$ are $1/(BK)$. Basically, all feature vectors are assigned uniformly to all prototypes.

When $\varepsilon = 0$ we have no smoothness term to further regularize $\mathbf{Q}$.

Finally, small values for $\varepsilon$ result in a slightly smoothed $\mathbf{Q}^*$.
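The role of $\varepsilon$ can be illustrated on the term $\exp(\mathbf{C}^T \mathbf{Z} / \varepsilon)$ from which the solution is built. The sketch below is a simplification: it only rescales this term to sum to 1 and does not enforce the row and column constraints, and the random scores and the chosen $\varepsilon$ values are made up. Even so, it shows the trend described above: a larger $\varepsilon$ gives a flatter matrix, approaching the uniform entropy $\log(BK)$.

```python
import numpy as np

rng = np.random.default_rng(0)
K, B = 4, 6
S = rng.uniform(-1.0, 1.0, size=(K, B))  # stand-in for cosine similarities C^T Z

def normalized_kernel(S, eps):
    """exp(S / eps), rescaled so that all entries sum to 1."""
    G = np.exp((S - S.max()) / eps)  # shifting by the max only changes the overall scale
    return G / G.sum()

def entropy(P):
    nz = P[P > 0]
    return float(-(nz * np.log(nz)).sum())

eps_values = (0.05, 0.5, 5.0, 500.0)
entropies = [entropy(normalized_kernel(S, eps)) for eps in eps_values]
```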

Revisiting the constraints for $\mathcal{Q}$, the row-sum and column-sum constraints imply an equal amount of total weight is assigned to each prototype and each feature vector respectively.

The constraints impose a strong regularization that results in avoiding mode collapse, where all feature vectors are assigned to the same prototype all the time.

Online estimation of Q* for SWAV

What is left now is to compute $\mathbf{Q}^*$ in every iteration of the training process, which luckily turns out to be very efficient using the results of [3].

Using Lemma 2 (page 5) of [3], we know that the solution takes the form:

$$\mathbf{Q}^* = \text{Diag}(\mathbf{u}) \exp \bigg(\frac{\mathbf{C}^T\mathbf{Z}}{\varepsilon} \bigg) \text{Diag}(\mathbf{v}),$$

where $\mathbf{u} \in \mathbb{R}^K$ and $\mathbf{v} \in \mathbb{R}^B$ act as row and column normalization vectors respectively: left-multiplying by $\text{Diag}(\mathbf{u})$ rescales the rows, and right-multiplying by $\text{Diag}(\mathbf{v})$ rescales the columns. An exact computation here is inefficient. However, the Sinkhorn-Knopp algorithm provides a fast, iterative alternative. We can initialize a matrix $\mathbf{Q}$ as the exponential term from $\mathbf{Q}^*$ and then alternate between normalizing the rows and columns of this matrix.

Sinkhorn-Knopp Code analysis

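Here is the authors' pseudocode for approximating Q from the similarity scores, written out as runnable PyTorch:

import torch

# Sinkhorn-Knopp
@torch.no_grad()  # Q is a fixed target, no gradients flow through it
def sinkhorn(scores, eps=0.05, niters=3):
    # scores: (B, K) similarities between B features and K prototypes
    Q = torch.exp(scores / eps).T  # (K, B)
    Q /= Q.sum()  # all entries sum to 1
    K, B = Q.shape
    r = torch.ones(K, device=Q.device) / K  # target row sums
    c = torch.ones(B, device=Q.device) / B  # target column sums
    for _ in range(niters):
        u = Q.sum(dim=1)
        Q *= (r / u).unsqueeze(1)  # row norm
        Q *= (c / Q.sum(dim=0)).unsqueeze(0)  # column norm
    # rescale each column to sum to 1, so every sample's code is a distribution
    return (Q / Q.sum(dim=0, keepdim=True)).T  # (B, K)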

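To approximate $\mathbf{Q}$, we take as input only the similarity score matrix and output our estimate of $\mathbf{Q}$. Note that the function expects the scores as a $B \times K$ matrix, i.e. $\mathbf{Z}^T \mathbf{C}$, which it transposes in its first line to obtain the $K \times B$ layout of $\mathbf{C}^T \mathbf{Z}$. The returned codes are transposed back to $B \times K$.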
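To see that alternating row and column normalization really lands in the constraint set, here is a NumPy sketch of the same idea. It is an illustrative port, not the authors' code: it works directly in the $K \times B$ layout, and it uses a larger `eps` and many more iterations than the SWAV defaults so that the toy example converges tightly.

```python
import numpy as np

def sinkhorn_np(S, eps=0.5, niters=500):
    """Approximate Q* for a (K, B) score matrix S by alternating normalization."""
    Q = np.exp((S - S.max()) / eps)  # the shift rescales Q uniformly and avoids overflow
    Q /= Q.sum()
    K, B = Q.shape
    for _ in range(niters):
        Q *= (1.0 / K) / Q.sum(axis=1, keepdims=True)  # make rows sum to 1/K
        Q *= (1.0 / B) / Q.sum(axis=0, keepdims=True)  # make columns sum to 1/B
    return Q

rng = np.random.default_rng(0)
K, B = 3, 8
S = rng.uniform(-1.0, 1.0, size=(K, B))  # stand-in for cosine similarities C^T Z

Q = sinkhorn_np(S)
codes = Q / Q.sum(axis=0, keepdims=True)  # per-sample codes: each column sums to 1
```

With only a few iterations, as used during training, the row sums match $1/K$ only approximately, which is the price of the fast approximation.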

Intuition on the clusters/prototypes

So what is actually learned in these clusters/prototypes?

Well, the prototypes’ main purpose is to summarize the dataset. So SWAV is still a form of contrastive learning. In fact, it can also be interpreted as a way of contrasting image views by comparing their cluster assignments instead of their features.

Ultimately, we contrast with the clusters and not the whole dataset. SimCLR uses batch information, called negative samples, but a batch is not always representative of the whole dataset. That makes the SWAV objective more tractable.

This can be observed from the experiments. Compared to SimCLR, SWAV pretraining converges faster and is less sensitive to the batch size. Moreover, SWAV is not that sensitive to the number of clusters. Typically 3K clusters are used for ImageNet. In general, it is recommended to use approximately one order of magnitude more clusters than the number of real classes. For STL10, which has 10 classes, 512 clusters would be enough.

The multi-crop idea: augmenting views with smaller images

Every time I read about contrastive self-supervised learning methods I think: why just 2 views? Well, this obvious question is also answered in the SWAV paper.

Multi-crop. Source: SWAV paper, Caron et al 2020

To this end, SwAV proposes a multi-crop augmentation strategy where the same image is cropped randomly into 2 global (i.e. 224x224) views and $N=4$ local (i.e. 96x96) views.
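With more than two views, the swapped prediction needs a rule for which view predicts which code. To my understanding of the paper, codes are computed only from the global views and every other view (global or local) predicts them. The sketch below simply enumerates these pairs. The function name and the view indexing are hypothetical.

```python
def swapped_pairs(n_global=2, n_local=4):
    """List (code_view, predicting_view) index pairs for one image.

    Views 0..n_global-1 are the global crops, the remaining ones are local crops.
    Assumption: only global views provide codes; every other view predicts them.
    """
    n_views = n_global + n_local
    return [(t, s) for t in range(n_global) for s in range(n_views) if s != t]

pairs = swapped_pairs()
n_pairs = len(pairs)  # 2 * (2 + 4 - 1) loss terms per image
```

Under this assumption, the local views add loss terms without adding any extra code computations.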

As shown below, multi-crop is a very general trick to improve self-supervised learning representations. It can be used out of the box for any method with surprisingly good results: a ~2% improvement on SimCLR!

multi-crop-comparison Source: SWAV paper, Caron et al 2020

The authors also observed that mapping small parts of a scene to more global views significantly boosts the performance.

Results

To evaluate the learned representation of $f$, the backbone model (i.e. ResNet-50) is frozen. A single linear layer is trained on top. This protocol, called linear evaluation, is a fair comparison of the learned representations. Below are the results of SWAV compared to other state-of-the-art methods.

swav-results (left) Comparison between clustering-based and contrastive instance methods and impact of multi-crop. (right) Performance as a function of epochs. Source: SWAV paper, Caron et al 2020

Left: Classification accuracy on ImageNet is reported. The linear layers are trained on frozen features from different self-supervised methods with a standard ResNet-50. Right: Performance of ResNet-50 models widened by factors of 2, 4, and 5.

Conclusion

In this post an overview of SWAV and its hidden math is provided. We covered the details of optimal transport with and without the entropy constraint. This post would not be possible without the detailed mathematical analysis of Tim.

Finally you can check out this interview on SWAV by its first author (Mathilde Caron).

For further reading, take a look at self-supervised representation learning on videos or SWAV’s experimental report. You can even run your own experiments with the official code if you have a multi-GPU machine!

Finally, I have to say that I’m a bit biased on the work of FAIR on visual self-supervised learning. This team really rocks!

References

  1. SWAV Paper

  2. SWAV Code

  3. Ref paper on optimal transport

  4. SWAV’s Report in WANDB

  5. Optimal Transport problem

Cite as:

@article{kaiser2021swav,
title = "Understanding SWAV: self-supervised learning with contrasting cluster assignments",
author = "Kaiser, Tim and Adaloglou, Nikolaos",
journal = "https://theaisummer.com/",
year = "2021",
howpublished = {https://theaisummer.com/swav/},
}
