
What is: Channel-wise Cross Attention?

Source: UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer
Year: 2021
Data Source: CC BY-SA - https://paperswithcode.com

Channel-wise Cross Attention is a module for semantic segmentation used in the UCTransNet architecture. It fuses features of inconsistent semantics between the Channel Transformer and the U-Net decoder: it guides channel-wise information filtration of the Transformer features and eliminates ambiguity with the decoder features.

Mathematically, we take the $i$-th level Transformer output $\mathbf{O}_i \in \mathbb{R}^{C \times H \times W}$ and the $i$-th level decoder feature map $\mathbf{D}_i \in \mathbb{R}^{C \times H \times W}$ as the inputs of Channel-wise Cross Attention. Spatial squeeze is performed by a global average pooling (GAP) layer, producing the vector $\mathcal{G}(\mathbf{X}) \in \mathbb{R}^{C \times 1 \times 1}$ whose $k$-th channel is $\mathcal{G}(\mathbf{X})_k = \frac{1}{H \times W}\sum_{i=1}^{H}\sum_{j=1}^{W}\mathbf{X}^{k}(i, j)$. We use this operation to embed the global spatial information and then generate the attention mask:
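As a quick illustration, the spatial squeeze $\mathcal{G}(\cdot)$ is simply a per-channel spatial mean. A minimal PyTorch sketch (tensor shapes are illustrative, not taken from the paper's released code):

```python
import torch
import torch.nn.functional as F

# Spatial squeeze G(X): global average pooling over the H and W dimensions.
x = torch.randn(2, 64, 32, 32)      # (batch, C, H, W); shapes are illustrative
g_x = F.adaptive_avg_pool2d(x, 1)   # (batch, C, 1, 1)
# The k-th channel of g_x is the mean of x's k-th channel over all H x W positions.
assert torch.allclose(g_x, x.mean(dim=(2, 3), keepdim=True))
```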

$$\mathbf{M}_i = \mathbf{L}_1 \cdot \mathcal{G}(\mathbf{O}_i) + \mathbf{L}_2 \cdot \mathcal{G}(\mathbf{D}_i)$$

where $\mathbf{L}_1 \in \mathbb{R}^{C \times C}$ and $\mathbf{L}_2 \in \mathbb{R}^{C \times C}$ are the weights of two Linear layers and $\delta(\cdot)$ is the ReLU operator. The equation above encodes the channel-wise dependencies. Following ECA-Net, which empirically showed that avoiding dimensionality reduction is important for learning channel attention, the authors use a single Linear layer and a sigmoid function to build the channel attention map. The resulting vector is used to recalibrate or excite $\mathbf{O}_i$ to $\bar{\mathbf{O}}_i = \sigma(\mathbf{M}_i) \cdot \mathbf{O}_i$, where the activation $\sigma(\mathbf{M}_i)$ indicates the importance of each channel. Finally, the masked $\bar{\mathbf{O}}_i$ is concatenated with the up-sampled features of the $i$-th level decoder.
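For concreteness, here is a minimal PyTorch sketch of the module described above, assuming $\mathbf{O}_i$ and $\mathbf{D}_i$ share the same channel count $C$. It follows the equations in this entry rather than the authors' released implementation, and the exact placement of the ReLU may differ in the official code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelWiseCrossAttention(nn.Module):
    # Sketch of Channel-wise Cross Attention (CCA) following the equations above.
    # Assumes O_i and D_i have the same channel count C; illustrative only,
    # not the authors' released implementation.
    def __init__(self, channels: int):
        super().__init__()
        self.linear_o = nn.Linear(channels, channels)  # L1, applied to G(O_i)
        self.linear_d = nn.Linear(channels, channels)  # L2, applied to G(D_i)

    def forward(self, o_i: torch.Tensor, d_i: torch.Tensor) -> torch.Tensor:
        b, c, _, _ = o_i.shape
        # Spatial squeeze: global average pooling, G(X) in R^{C x 1 x 1}
        g_o = F.adaptive_avg_pool2d(o_i, 1).view(b, c)
        g_d = F.adaptive_avg_pool2d(d_i, 1).view(b, c)
        # Attention mask M_i = L1 . G(O_i) + L2 . G(D_i), with ReLU delta(.)
        # (ReLU placement follows the description above; the official code may differ)
        m_i = F.relu(self.linear_o(g_o) + self.linear_d(g_d))
        # Channel attention map sigma(M_i), built without dimensionality reduction
        attn = torch.sigmoid(m_i).view(b, c, 1, 1)
        # Recalibrate (excite) the Transformer features channel-wise: O_bar_i = sigma(M_i) * O_i
        return attn * o_i
```

The recalibrated $\bar{\mathbf{O}}_i$ returned by `forward` would then be concatenated with the up-sampled decoder features, e.g. `torch.cat([o_bar_i, up_d_i], dim=1)`, before the next decoder block.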