What is: Temporal Distribution Matching?

Source: AdaRNN: Adaptive Learning and Forecasting of Time Series
Year: 2000
Data Source: CC BY-SA - https://paperswithcode.com

Temporal Distribution Matching, or TDM, is a module used in the AdaRNN architecture to match the distributions of the discovered periods to build a time series prediction model $\mathcal{M}$. Given the learned time periods, the TDM module is designed to learn the common knowledge shared by different periods by matching their distributions. Thus, the learned model $\mathcal{M}$ is expected to generalize better on unseen test data than methods that rely only on local or statistical information.

Within the context of AdaRNN, Temporal Distribution Matching aims to adaptively match the distributions between the RNN cells of two periods while capturing the temporal dependencies. TDM introduces the importance vector $\mathbf{\alpha} \in \mathbb{R}^{V}$ to learn the relative importance of the $V$ hidden states inside the RNN, where all the hidden states are weighted with a normalized $\alpha$. Note that there is one $\mathbf{\alpha}$ for each pair of periods, and we omit the subscript when there is no confusion. In this way, we can dynamically reduce the distribution divergence across periods.

Given a period pair $\left(\mathcal{D}_{i}, \mathcal{D}_{j}\right)$, the loss of temporal distribution matching is formulated as:

$$\mathcal{L}_{tdm}\left(\mathcal{D}_{i}, \mathcal{D}_{j} ; \theta\right)=\sum_{t=1}^{V} \alpha_{i, j}^{t} \, d\left(\mathbf{h}_{i}^{t}, \mathbf{h}_{j}^{t} ; \theta\right)$$

where $\alpha_{i, j}^{t}$ denotes the distribution importance between the periods $\mathcal{D}_{i}$ and $\mathcal{D}_{j}$ at state $t$.
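
To make the weighted sum concrete, here is a minimal Python sketch of the TDM loss for a single period pair. It is an illustration under stated assumptions, not the AdaRNN implementation: the text above does not say how the importance vector is normalized or which distance $d$ is used, so the softmax in `normalize` and the squared distance between batch means in `mean_distance` are placeholders, and all function names are hypothetical.

```python
import math

def normalize(alpha):
    """Softmax-normalize raw importance scores (assumed normalization)."""
    m = max(alpha)
    exps = [math.exp(a - m) for a in alpha]
    total = sum(exps)
    return [e / total for e in exps]

def mean_vector(batch):
    """Feature-wise mean of a batch of hidden-state vectors."""
    n = len(batch)
    return [sum(col) / n for col in zip(*batch)]

def mean_distance(h_i, h_j):
    """Placeholder for d(., .): squared Euclidean distance between batch means."""
    mu_i, mu_j = mean_vector(h_i), mean_vector(h_j)
    return sum((a - b) ** 2 for a, b in zip(mu_i, mu_j))

def tdm_loss(states_i, states_j, alpha):
    """Weighted sum over the V states: sum_t alpha^t * d(h_i^t, h_j^t).

    states_i and states_j are lists of length V; each entry is a batch
    (list) of hidden-state vectors for that state.
    """
    weights = normalize(alpha)
    return sum(
        w * mean_distance(h_i, h_j)
        for w, h_i, h_j in zip(weights, states_i, states_j)
    )

# Toy example: V = 2 states, batches of 2 vectors, hidden size 2.
states_i = [[[0.0, 0.0], [2.0, 2.0]], [[1.0, 1.0], [1.0, 1.0]]]
states_j = [[[1.0, 1.0], [1.0, 1.0]], [[3.0, 1.0], [3.0, 1.0]]]
print(tdm_loss(states_i, states_j, alpha=[0.0, 0.0]))  # prints 2.0
```

With equal raw scores both states get weight 0.5. The first state contributes nothing because the batch means coincide, and the second contributes 0.5 × 4. Raising the second score shifts more weight onto the state where the two periods differ.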

All the hidden states of the RNN can be easily computed by following the standard RNN computation. Denote by $\delta(\cdot)$ the computation of a next hidden state based on a previous state. The state computation can be formulated as

$$\mathbf{h}_{i}^{t}=\delta\left(\mathbf{x}_{i}^{t}, \mathbf{h}_{i}^{t-1}\right)$$
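
The recurrence can be sketched in a few lines of Python. The text only says "standard RNN computation", so the tanh cell below is an assumption chosen for brevity (a GRU or LSTM step would fill the same role), and the names `delta`, `unroll`, `w_x`, `w_h`, and `b` are hypothetical. The point is that every intermediate state is kept, since TDM compares all $V$ states rather than only the last one.

```python
import math

def delta(x_t, h_prev, w_x, w_h, b):
    """One recurrent step (assumed tanh cell): h^t = tanh(W_x x^t + W_h h^{t-1} + b)."""
    hidden = []
    for k in range(len(b)):
        s = b[k]
        s += sum(w * x for w, x in zip(w_x[k], x_t))
        s += sum(w * h for w, h in zip(w_h[k], h_prev))
        hidden.append(math.tanh(s))
    return hidden

def unroll(xs, h0, w_x, w_h, b):
    """Apply delta along the sequence and return every hidden state h^1..h^V."""
    states, h = [], h0
    for x_t in xs:
        h = delta(x_t, h, w_x, w_h, b)
        states.append(h)
    return states

# Toy example: scalar input, scalar hidden state, V = 3 steps.
states = unroll(
    xs=[[0.0], [1.0], [1.0]],
    h0=[0.0],
    w_x=[[1.0]],
    w_h=[[0.5]],
    b=[0.0],
)
print(len(states))  # prints 3
```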

The final objective of temporal distribution matching (one RNN layer) is:

$$\mathcal{L}(\theta, \mathbf{\alpha})=\mathcal{L}_{\text{pred}}(\theta)+\lambda \frac{2}{K(K-1)} \sum_{i, j}^{i \neq j} \mathcal{L}_{tdm}\left(\mathcal{D}_{i}, \mathcal{D}_{j} ; \theta, \mathbf{\alpha}\right)$$

where $\lambda$ is a trade-off hyper-parameter and $K$ is the number of periods. Note that the second term averages the distribution distances over all pairs of periods. For computation, we take a mini-batch from $\mathcal{D}_{i}$ and one from $\mathcal{D}_{j}$, run the forward pass through the RNN layers, and concatenate all hidden features. Then, we can perform TDM using the above equation.
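
As a rough sketch of how the final objective combines its two terms, the Python snippet below averages a pairwise TDM loss over the periods and adds it to the prediction loss. It assumes the sum runs over unordered pairs $i < j$, which is the reading under which the factor $\frac{2}{K(K-1)}$ is exactly one over the number of pairs. The names `total_loss` and `tdm_pair_loss` are hypothetical, and the pairwise loss is passed in as a callable so that any distance can be plugged in.

```python
from itertools import combinations

def total_loss(pred_loss, tdm_pair_loss, num_periods, lam):
    """L = L_pred + lam * 2 / (K (K - 1)) * sum over period pairs of L_tdm.

    tdm_pair_loss(i, j) returns the TDM loss for periods i and j.
    Assumption: each unordered pair (i < j) is counted once.
    """
    if num_periods < 2:
        raise ValueError("at least two periods are needed to form a pair")
    pairs = combinations(range(num_periods), 2)
    matching = sum(tdm_pair_loss(i, j) for i, j in pairs)
    scale = 2.0 / (num_periods * (num_periods - 1))
    return pred_loss + lam * scale * matching

# Toy example: K = 3 periods, every pair has TDM loss 1.0,
# so the average matching term is 1.0.
loss = total_loss(
    pred_loss=0.5,
    tdm_pair_loss=lambda i, j: 1.0,
    num_periods=3,
    lam=0.1,
)
print(round(loss, 6))  # prints 0.6
```

Setting `lam` to 0 recovers plain prediction training. Larger values push the hidden-state distributions of different periods closer together, at some cost in fit on the training periods.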