What is: SimCLR?
Source | A Simple Framework for Contrastive Learning of Visual Representations |
Year | 2020 |
Data Source | CC BY-SA - https://paperswithcode.com |
SimCLR is a framework for contrastive learning of visual representations. It learns representations by maximizing agreement between differently augmented views of the same data example via a contrastive loss in the latent space. It consists of:
- A stochastic data augmentation module that randomly transforms any given data example, producing two correlated views of the same example, denoted $\tilde{x}_i$ and $\tilde{x}_j$, which are considered a positive pair. SimCLR sequentially applies three simple augmentations: random cropping followed by resizing back to the original size, random color distortion, and random Gaussian blur. The authors find that the combination of random crop and color distortion is crucial for good performance.
- A neural network base encoder $f(\cdot)$ that extracts representation vectors from augmented data examples. The framework allows various choices of network architecture without any constraints. The authors opt for simplicity and adopt ResNet to obtain $h_i = f(\tilde{x}_i) = \mathrm{ResNet}(\tilde{x}_i)$, where $h_i \in \mathbb{R}^d$ is the output after the average pooling layer.
- A small neural network projection head $g(\cdot)$ that maps representations to the space where the contrastive loss is applied. The authors use an MLP with one hidden layer to obtain $z_i = g(h_i) = W^{(2)} \sigma(W^{(1)} h_i)$, where $\sigma$ is a ReLU nonlinearity. The authors find it beneficial to define the contrastive loss on the $z_i$'s rather than the $h_i$'s.
- A contrastive loss function defined for a contrastive prediction task. Given a set $\{\tilde{x}_k\}$ including a positive pair of examples $\tilde{x}_i$ and $\tilde{x}_j$, the contrastive prediction task aims to identify $\tilde{x}_j$ in $\{\tilde{x}_k\}_{k \neq i}$ for a given $\tilde{x}_i$.
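The first three components can be sketched end to end in a few lines of NumPy. This is a minimal illustration under loud assumptions, not the authors' implementation: `random_crop_resize` covers only the crop-and-resize augmentation (color distortion and Gaussian blur are omitted), `toy_encoder` is a hypothetical stand-in for the ResNet encoder, and all sizes and weight scales are arbitrary. Only `projection_head` follows the form stated above, a one-hidden-layer MLP with a ReLU.

```python
import numpy as np

def random_crop_resize(image, rng, min_scale=0.5):
    """Crop a random window and resize it back to the input size.

    Uses nearest-neighbor resizing for brevity (an assumption; real
    pipelines typically interpolate). `image` has shape (H, W, C).
    """
    height, width = image.shape[:2]
    crop_h = rng.integers(max(1, int(min_scale * height)), height + 1)
    crop_w = rng.integers(max(1, int(min_scale * width)), width + 1)
    top = rng.integers(0, height - crop_h + 1)
    left = rng.integers(0, width - crop_w + 1)
    crop = image[top:top + crop_h, left:left + crop_w]
    # Map every output pixel to its nearest source pixel in the crop.
    rows = np.arange(height) * crop_h // height
    cols = np.arange(width) * crop_w // width
    return crop[rows][:, cols]

def toy_encoder(images, w_enc):
    """Hypothetical placeholder for the base encoder (NOT a ResNet):
    flatten each image, apply one linear map, then a ReLU."""
    flat = images.reshape(images.shape[0], -1)
    return np.maximum(flat @ w_enc.T, 0.0)

def projection_head(h, w1, w2):
    """One-hidden-layer MLP: z = W2 ReLU(W1 h), applied to each row of h."""
    return np.maximum(h @ w1.T, 0.0) @ w2.T

rng = np.random.default_rng(0)
batch = rng.random((4, 16, 16, 3))  # 4 toy RGB images of size 16 x 16

# Two independently augmented views of every image form the positive pairs.
view_a = np.stack([random_crop_resize(x, rng) for x in batch])
view_b = np.stack([random_crop_resize(x, rng) for x in batch])

# Arbitrary random weights; in practice these would be learned.
w_enc = rng.normal(size=(32, 16 * 16 * 3)) * 0.05
w1 = rng.normal(size=(32, 32)) * 0.1
w2 = rng.normal(size=(8, 32)) * 0.1

z_a = projection_head(toy_encoder(view_a, w_enc), w1, w2)
z_b = projection_head(toy_encoder(view_b, w_enc), w1, w2)
print(z_a.shape, z_b.shape)  # (4, 8) (4, 8)
```

Row `k` of `z_a` and row `k` of `z_b` come from the same source image, so they are the pair the contrastive loss would pull together.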
A minibatch of $N$ examples is randomly sampled, and the contrastive prediction task is defined on pairs of augmented examples derived from the minibatch, resulting in $2N$ data points. Negative examples are not sampled explicitly. Instead, given a positive pair, the other $2(N-1)$ augmented examples within the minibatch are treated as negative examples. The loss function used is NT-Xent, the normalized temperature-scaled cross-entropy loss.
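For reference, the SimCLR paper defines the loss for a positive pair $(i, j)$ as

$$\ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)}$$

where $\mathrm{sim}$ is cosine similarity and $\tau$ is a temperature parameter; the final loss averages $\ell_{i,j}$ and $\ell_{j,i}$ over all positive pairs in the minibatch. The sketch below is one possible NumPy rendering of that formula, not the reference implementation. The function name `nt_xent_loss`, the argument layout (two arrays of matched rows), and the default temperature of 0.5 are assumptions made for illustration.

```python
import numpy as np

def nt_xent_loss(z_a, z_b, temperature=0.5):
    """NT-Xent loss for N positive pairs (z_a[k], z_b[k]).

    z_a and z_b have shape (N, d). The temperature default is an
    arbitrary illustrative value, not a recommendation.
    """
    n = z_a.shape[0]
    z = np.concatenate([z_a, z_b], axis=0)             # (2N, d)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # dot product = cosine similarity
    logits = (z @ z.T) / temperature                   # (2N, 2N) scaled similarities
    np.fill_diagonal(logits, -np.inf)                  # an example is never its own negative

    # Row k is paired with row k + N, and row k + N with row k.
    positive_index = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])

    # Numerically stable log of the denominator (log-sum-exp over k != i).
    row_max = logits.max(axis=1, keepdims=True)
    log_denominator = row_max[:, 0] + np.log(np.exp(logits - row_max).sum(axis=1))

    log_prob_positive = logits[np.arange(2 * n), positive_index] - log_denominator
    return -log_prob_positive.mean()

rng = np.random.default_rng(0)
z_a = rng.normal(size=(8, 16))

# Nearly identical partners: each positive stands out from the 14 negatives.
loss_aligned = nt_xent_loss(z_a, z_a + 0.01 * rng.normal(size=z_a.shape))
# Unrelated partners: the positive looks like any other example.
loss_random = nt_xent_loss(z_a, rng.normal(size=z_a.shape))
print(loss_aligned, loss_random)
```

In this toy setup the aligned pairs should give a clearly lower loss than the unrelated ones, which is the behavior the objective rewards. Note that the negatives are never sampled: they are simply the other rows of the same similarity matrix.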