
What is: LayerScale?

Source: Going deeper with Image Transformers
Year: 2021
Data Source: CC BY-SA - https://paperswithcode.com

LayerScale is a method used in vision transformer architectures to improve training dynamics. It adds a learnable diagonal matrix on the output of each residual block, initialized close to (but not at) 0. Adding this simple layer after each residual block improves training dynamics, allowing deeper high-capacity image transformers to be trained and to actually benefit from their depth.

Specifically, LayerScale is a per-channel multiplication of the vector produced by each residual block, as opposed to a single scalar (see Figure (d) in the original paper). The objective is to group the updates of the weights associated with the same output channel. Formally, LayerScale is a multiplication by a diagonal matrix on the output of each residual block. In other words:

$$x_l' = x_l + \operatorname{diag}\left(\lambda_{l,1}, \ldots, \lambda_{l,d}\right) \times \operatorname{SA}\left(\eta\left(x_l\right)\right)$$

$$x_{l+1} = x_l' + \operatorname{diag}\left(\lambda'_{l,1}, \ldots, \lambda'_{l,d}\right) \times \operatorname{FFN}\left(\eta\left(x_l'\right)\right)$$

where the parameters $\lambda_{l,i}$ and $\lambda'_{l,i}$ are learnable weights. The diagonal values are all initialized to a fixed small value $\varepsilon$: we set it to $\varepsilon = 0.1$ until depth 18, $\varepsilon = 10^{-5}$ for depth 24, and $\varepsilon = 10^{-6}$ for deeper networks.
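
The diagonal multiplication reduces to an element-wise product with a learnable per-channel vector, so it is cheap to implement. Below is a minimal PyTorch sketch of how the two equations above could be wired into a pre-norm transformer block; the class names (`LayerScale`, `Block`), the use of `nn.MultiheadAttention`, and the FFN expansion ratio are illustrative assumptions, not the paper's reference code.

```python
import torch
import torch.nn as nn


class LayerScale(nn.Module):
    """Per-channel learnable scaling of a residual branch."""

    def __init__(self, dim: int, init_eps: float = 1e-5):
        super().__init__()
        # diag(lambda_1, ..., lambda_d) stored as a vector; the element-wise
        # product below is equivalent to the diagonal-matrix multiplication.
        self.gamma = nn.Parameter(init_eps * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gamma * x


class Block(nn.Module):
    """Pre-norm transformer block with LayerScale on both residual branches."""

    def __init__(self, dim: int, num_heads: int, init_eps: float = 1e-5):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)  # eta(.) before self-attention
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ls1 = LayerScale(dim, init_eps)
        self.norm2 = nn.LayerNorm(dim)  # eta(.) before the FFN
        self.ffn = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )
        self.ls2 = LayerScale(dim, init_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x' = x + diag(lambda) x SA(eta(x))
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.ls1(attn_out)
        # x_{l+1} = x' + diag(lambda') x FFN(eta(x'))
        x = x + self.ls2(self.ffn(self.norm2(x)))
        return x
```

With this setup, the initialization value `init_eps` would be chosen according to the depth schedule quoted above (0.1, 1e-5, or 1e-6).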

This formula is akin to other normalization strategies such as ActNorm or LayerNorm, but executed on the output of the residual block. Yet LayerScale seeks a different effect: ActNorm is a data-dependent initialization that calibrates activations so that they have zero mean and unit variance, like BatchNorm. In contrast, in LayerScale, we initialize the diagonal with small values so that the initial contribution of the residual branches to the function implemented by the transformer is small. In that respect the motivation is closer to that of ReZero, SkipInit, Fixup and T-Fixup: to train closer to the identity function and let the network integrate the additional parameters progressively during training. LayerScale offers more diversity in the optimization than adjusting the whole layer by a single learnable scalar as in ReZero/SkipInit, Fixup and T-Fixup.
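
To make the contrast concrete, the following sketch compares a single-scalar gate (ReZero/SkipInit style, initialized at zero) with LayerScale's per-channel vector (initialized at $\varepsilon$). The tensor shapes and variable names are illustrative assumptions.

```python
import torch
import torch.nn as nn

dim = 768

# ReZero/SkipInit-style: one scalar gates the entire residual branch.
alpha = nn.Parameter(torch.zeros(1))

# LayerScale: one learnable weight per output channel (a diagonal matrix),
# so each channel's contribution can be scaled independently during training.
gamma = nn.Parameter(1e-5 * torch.ones(dim))

branch_out = torch.randn(2, 197, dim)   # (batch, tokens, channels) -- example shape
rezero_scaled = alpha * branch_out      # same factor applied to all d channels
layerscale_scaled = gamma * branch_out  # broadcast over the channel axis
```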