Loss Weighting in Multi-task Learning

Background
In the basic multi-task setup, the total loss is a weighted sum of the per-task losses, with weights that are either uniform or manually tuned. The methods below adapt these weights automatically during training.

GradNorm

GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks
Task imbalances impede proper training because they manifest as imbalances between backpropagated gradients.
GradNorm assumes the multi-task loss is a linear combination of the single-task losses:

    \[ L(t)=\sum_{i} w_{i}(t) L_{i}(t) \]
Notations:
W : the subset of the full network weights W \subset \mathcal{W} where GradNorm is applied; generally chosen as the last shared layer of weights to keep the method cheap.
G_{W}^{(i)}(t)=\left\|\nabla_{W} w_{i}(t) L_{i}(t)\right\|_{2} : the L_2 norm of the gradient of the weighted single-task loss w_{i}(t) L_{i}(t) with respect to the chosen weights W.
\bar{G}_{W}(t)=E_{\mathrm{task}}\left[G_{W}^{(i)}(t)\right] : the average gradient norm across all tasks at training step t.
\tilde{L}_{i}(t)=L_{i}(t) / L_{i}(0) : the loss ratio for task i at step t; \tilde{L}_{i}(t) measures the inverse training rate of task i (lower values correspond to a faster training rate).
r_{i}(t)=\tilde{L}_{i}(t) / E_{\text {task }}\left[\tilde{L}_{i}(t)\right] : the relative inverse training rate for task i.
The higher the value of r_{i}(t), the higher the gradient magnitudes should be for task i in order to encourage the task to train more quickly. So the desired or target gradient norm for each task i is:

    \[ G_{W}^{(i)}(t) \mapsto \bar{G}_{W}(t) \times\left[r_{i}(t)\right]^{\alpha} \]

The loss weights w_{i} are updated by minimizing L_{grad}, the L_{1} distance between the actual and target gradient norms at each time step, summed over all tasks:

    \[ L_{\mathrm{grad}}\left(t ; w_{i}(t)\right)=\sum_{i}\left|G_{W}^{(i)}(t)-\bar{G}_{W}(t) \times\left[r_{i}(t)\right]^{\alpha}\right|_{1} \]

When differentiating L_{grad}, the target gradient norm \bar{G}_{W}(t) \times\left[r_{i}(t)\right]^{\alpha} is treated as a fixed constant. L_{grad} is differentiated only with respect to the w_{i}; the computed gradients \nabla_{w_{i}} L_{\text {grad }} are then applied via standard update rules to update each w_{i}. After every step, the weights are renormalized so that \sum_{i} w_{i}(t) equals the number of tasks.

Code implementation
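Below is a minimal PyTorch sketch of a single GradNorm update (not the authors' code): it assumes one shared weight tensor plays the role of W, and the defaults for alpha and the weight learning rate lr are illustrative.

    import torch

    def gradnorm_step(loss_weights, task_losses, initial_losses, shared_w,
                      alpha=1.5, lr=0.025):
        """One GradNorm update of the loss weights w_i (illustrative defaults).
        loss_weights:   1-D leaf tensor with requires_grad=True.
        task_losses:    list of scalar losses L_i(t), graph still attached.
        initial_losses: list of floats L_i(0).
        shared_w:       shared weight tensor playing the role of W."""
        num_tasks = len(task_losses)
        # G_W^(i)(t): norm of the gradient of w_i * L_i w.r.t. the shared weights;
        # create_graph=True so L_grad can be differentiated w.r.t. w_i afterwards
        norms = torch.stack([
            torch.autograd.grad(w_i * L_i, shared_w,
                                retain_graph=True, create_graph=True)[0].norm(2)
            for w_i, L_i in zip(loss_weights, task_losses)])
        with torch.no_grad():
            # relative inverse training rates r_i(t) and the constant target norms
            ratios = torch.stack([L / L0 for L, L0 in zip(task_losses, initial_losses)])
            target = norms.mean() * (ratios / ratios.mean()) ** alpha
        # L_grad: L1 distance between actual and target gradient norms
        grad_loss = (norms - target).abs().sum()
        w_grad = torch.autograd.grad(grad_loss, loss_weights)[0]
        with torch.no_grad():
            loss_weights -= lr * w_grad
            # renormalize so the weights sum to the number of tasks
            loss_weights *= num_tasks / loss_weights.sum()

The network parameters themselves are then updated as usual with the total loss \sum_{i} w_{i}(t) L_{i}(t), treating the refreshed weights as constants.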

 

PE-LTR

ref: A Pareto-Efficient Algorithm for Multiple Objective Optimization in E-Commerce Recommendation.
PE-LTR casts multi-task learning as multi-objective optimization and searches for Pareto-efficient weights of the scalarized loss:

    \[ \mathcal{L}(\boldsymbol{\theta})=\sum_{i=1}^{K} \omega_{i} \mathcal{L}_{i}(\boldsymbol{\theta}) \]

A Pareto-stationary solution requires that there exist weights \omega_{i} satisfying

    \[ \sum_{i=1}^{K} \omega_{i}=1, \quad \omega_{i} \geq c_{i}, \forall i \in\{1, \ldots, K\}, \quad \text { and } \quad \sum_{i=1}^{K} \omega_{i} \nabla_{\boldsymbol{\theta}} \mathcal{L}_{i}(\boldsymbol{\theta})=0 \]

where each c_{i} \geq 0 is a preset lower bound on task i's weight. Finding such weights is posed as a constrained least-squares problem:

    \[ \begin{aligned} &\min _{\omega}\left\|\sum_{i=1}^{K} \omega_{i} \nabla_{\boldsymbol{\theta}} \mathcal{L}_{i}(\boldsymbol{\theta})\right\|_{2}^{2} \\ &\text { s.t. } \sum_{i=1}^{K} \omega_{i}=1, \quad \omega_{i} \geq c_{i}, \forall i \in\{1, \ldots, K\} \end{aligned} \]

Substituting \hat{\omega}_{i}=\omega_{i}-c_{i} shifts the lower bounds to zero:

    \[ \begin{aligned} \min _{\hat{\omega}} &\left\|\sum_{i=1}^{K}\left(\hat{\omega}_{i}+c_{i}\right) \nabla_{\boldsymbol{\theta}} \mathcal{L}_{i}(\boldsymbol{\theta})\right\|_{2}^{2} \\ \text { s.t. } & \sum_{i=1}^{K} \hat{\omega}_{i}=1-\sum_{i=1}^{K} c_{i}, \quad \hat{\omega}_{i} \geq 0, \forall i \in\{1, \ldots, K\} \end{aligned} \]

Dropping the non-negativity constraints gives a relaxed problem with an analytical solution \hat{\omega}^{*}, which is then projected back onto the feasible set:

    \[ \min _{\tilde{\omega}}\left\|\tilde{\omega}-\hat{\omega}^{*}\right\|_{2}^{2} \quad \text { s.t. } \sum_{i=1}^{K} \tilde{\omega}_{i}=1-\sum_{i=1}^{K} c_{i}, \quad \tilde{\omega}_{i} \geq 0, \forall i \in\{1, \ldots, K\} \]

The final weights are recovered as \omega_{i}=\tilde{\omega}_{i}+c_{i}.


Code Implementation
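A NumPy sketch of the two-step solver described above (the names pareto_step and project_simplex are mine, not from the paper): the relaxed equality-constrained problem is solved through its KKT linear system, and the solution is projected onto the shifted simplex so the final weights satisfy \omega_{i} \geq c_{i} and sum to one.

    import numpy as np

    def project_simplex(v, z):
        """Euclidean projection onto {w : w >= 0, sum(w) = z} (sorting method)."""
        u = np.sort(v)[::-1]
        css = np.cumsum(u) - z
        rho = np.nonzero(u - css / np.arange(1, len(v) + 1) > 0)[0][-1]
        return np.maximum(v - css[rho] / (rho + 1.0), 0.0)

    def pareto_step(grads, c):
        """One PE-LTR weight update.
        grads: (K, d) array whose rows are the task gradients dL_i/dtheta.
        c:     (K,) lower bounds c_i with sum(c) < 1.
        Returns task weights omega of shape (K,) with omega_i >= c_i, sum = 1."""
        K = grads.shape[0]
        M = grads @ grads.T                       # Gram matrix of task gradients
        e = np.ones(K)
        # Relaxed problem: min 0.5 * w^T M w  s.t.  e^T w = 1 - sum(c)
        # (non-negativity dropped); solve the KKT linear system for (w_hat, lambda).
        A = np.block([[M, e[:, None]], [e[None, :], np.zeros((1, 1))]])
        b = np.concatenate([np.zeros(K), [1.0 - c.sum()]])
        w_hat = np.linalg.lstsq(A, b, rcond=None)[0][:K]
        # Project back onto the shifted simplex, then undo the substitution.
        return project_simplex(w_hat, 1.0 - c.sum()) + c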

 

PCGrad
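ref: Gradient Surgery for Multi-Task Learning.
PCGrad operates on the per-task gradients directly rather than on loss weights: whenever two task gradients conflict (negative cosine similarity), the conflicting component is projected out of each task's gradient before the gradients are summed. A minimal NumPy sketch (names are mine):

    import numpy as np

    def pcgrad(grads, rng=None):
        """PCGrad gradient surgery.
        grads: (K, d) array of per-task gradients w.r.t. the shared parameters.
        Returns the merged update direction."""
        rng = rng or np.random.default_rng()
        K = grads.shape[0]
        projected = grads.astype(float).copy()
        for i in range(K):
            for j in rng.permutation(K):          # other tasks in random order
                if j == i:
                    continue
                dot = projected[i] @ grads[j]
                if dot < 0.0:                     # conflict: remove the component
                    projected[i] -= dot / (grads[j] @ grads[j]) * grads[j]
        return projected.sum(axis=0)

This is typically applied to the flattened gradients of the shared parameters at every training step.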

 

Uncertainty Weighting

Uncertainty weighting jointly learns regression and classification tasks whose losses have different units and scales.
The optimal weight for each task's loss depends on its measurement scale (e.g. meters, centimetres or millimetres) and ultimately on the magnitude of the task's noise.
In the case of multiple model outputs, we often define the likelihood to factorize over the outputs, given some sufficient statistics.
In typical multi-task learning, the total loss is the weighted linear sum of the losses for each task:

    \[ L_{\text {total }}=\sum_{i} w_{i} L_{i} \]

For a regression task with output \mathrm{f}^{\mathbf{W}}(\mathbf{x}), the likelihood is defined as a Gaussian with mean \mathrm{f}^{\mathbf{W}}(\mathbf{x}) and observation noise scalar \sigma:

    \[ p\left(\mathbf{y} \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right)=\mathcal{N}\left(\mathbf{f}^{\mathbf{W}}(\mathbf{x}), \sigma^{2}\right) \]

For a classification task, the model output \mathrm{f}^{\mathbf{W}}(\mathbf{x}) is passed through the softmax function to obtain a probability:

    \[ p\left(\mathbf{y} \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right)=\operatorname{Softmax}\left(\mathbf{f}^{\mathbf{W}}(\mathbf{x})\right) \]

In general, the multi-task likelihood is formed by factorizing over the outputs, with \mathrm{f}^{\mathbf{W}}(\mathbf{x}) as the sufficient statistics:

    \[ p\left(\mathbf{y}_{1}, \ldots, \mathbf{y}_{K} \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right)=p\left(\mathbf{y}_{1} \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right) \ldots p\left(\mathbf{y}_{K} \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right) \]

In maximum likelihood inference, we maximize the log likelihood of the model. Taking the regression task as an example:

    \[ \log p\left(\mathbf{y} \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right) \propto-\frac{1}{2 \sigma^{2}}\left\|\mathbf{y}-\mathbf{f}^{\mathbf{W}}(\mathbf{x})\right\|^{2}-\log \sigma \]

For two regression tasks \mathbf{y}_{1} and \mathbf{y}_{2}, the joint likelihood factorizes as:

    \[ \begin{aligned} p\left(\mathbf{y}_{1}, \mathbf{y}_{2} \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right) &=p\left(\mathbf{y}_{1} \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right) \cdot p\left(\mathbf{y}_{2} \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right) \\ &=\mathcal{N}\left(\mathbf{y}_{1} ; \mathbf{f}^{\mathbf{W}}(\mathbf{x}), \sigma_{1}^{2}\right) \cdot \mathcal{N}\left(\mathbf{y}_{2} ; \mathbf{f}^{\mathbf{W}}(\mathbf{x}), \sigma_{2}^{2}\right) \end{aligned} \]

Minimizing the negative log likelihood gives the objective:

    \[ \begin{aligned} \mathcal{L}\left(\mathbf{W}, \sigma_{1}, \sigma_{2}\right) &=-\log p\left(\mathbf{y}_{1}, \mathbf{y}_{2} \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right) \\ &\propto \frac{1}{2 \sigma_{1}^{2}}\left\|\mathbf{y}_{1}-\mathbf{f}^{\mathbf{W}}(\mathbf{x})\right\|^{2}+\frac{1}{2 \sigma_{2}^{2}}\left\|\mathbf{y}_{2}-\mathbf{f}^{\mathbf{W}}(\mathbf{x})\right\|^{2}+\log \sigma_{1} \sigma_{2} \\ &=\frac{1}{2 \sigma_{1}^{2}} \mathcal{L}_{1}(\mathbf{W})+\frac{1}{2 \sigma_{2}^{2}} \mathcal{L}_{2}(\mathbf{W})+\log \sigma_{1} \sigma_{2} \end{aligned} \]

For the classification problem, the likelihood is a softmax over the model output scaled by \sigma^{2}:

    \[ p\left(\mathbf{y} \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x}), \sigma\right)=\operatorname{Softmax}\left(\frac{1}{\sigma^{2}} \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right) \]

    \[ \begin{aligned} \log p\left(\mathbf{y}=c \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x}), \sigma\right) &=\frac{1}{\sigma^{2}} f_{c}^{\mathbf{W}}(\mathbf{x}) \\ &-\log \sum_{c^{\prime}} \exp \left(\frac{1}{\sigma^{2}} f_{c^{\prime}}^{\mathbf{W}}(\mathbf{x})\right) \end{aligned} \]

Combining the regression and the classification task:

    \[ \begin{aligned} \mathcal{L}\left(\mathbf{W}, \sigma_{1}, \sigma_{2}\right) &=-\log p\left(\mathbf{y}_{1}, \mathbf{y}_{2}=c \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x})\right) \\ &=-\log \mathcal{N}\left(\mathbf{y}_{1} ; \mathbf{f}^{\mathbf{W}}(\mathbf{x}), \sigma_{1}^{2}\right) \cdot \operatorname{Softmax}\left(\mathbf{y}_{2}=c ; \mathbf{f}^{\mathbf{W}}(\mathbf{x}), \sigma_{2}\right) \\ &=\frac{1}{2 \sigma_{1}^{2}}\left\|\mathbf{y}_{1}-\mathbf{f}^{\mathbf{W}}(\mathbf{x})\right\|^{2}+\log \sigma_{1}-\log p\left(\mathbf{y}_{2}=c \mid \mathbf{f}^{\mathbf{W}}(\mathbf{x}), \sigma_{2}\right) \\ &=\frac{1}{2 \sigma_{1}^{2}} \mathcal{L}_{1}(\mathbf{W})+\frac{1}{\sigma_{2}^{2}} \mathcal{L}_{2}(\mathbf{W})+\log \sigma_{1} \\ &\quad+\log \frac{\sum_{c^{\prime}} \exp \left(\frac{1}{\sigma_{2}^{2}} f_{c^{\prime}}^{\mathbf{W}}(\mathbf{x})\right)}{\left(\sum_{c^{\prime}} \exp \left(f_{c^{\prime}}^{\mathbf{W}}(\mathbf{x})\right)\right)^{\frac{1}{\sigma_{2}^{2}}}} \\ &\approx \frac{1}{2 \sigma_{1}^{2}} \mathcal{L}_{1}(\mathbf{W})+\frac{1}{\sigma_{2}^{2}} \mathcal{L}_{2}(\mathbf{W})+\log \sigma_{1}+\log \sigma_{2} \end{aligned} \]

The last step uses the simplifying assumption \frac{1}{\sigma_{2}} \sum_{c^{\prime}} \exp \left(\frac{1}{\sigma_{2}^{2}} f_{c^{\prime}}^{\mathbf{W}}(\mathbf{x})\right) \approx\left(\sum_{c^{\prime}} \exp \left(f_{c^{\prime}}^{\mathbf{W}}(\mathbf{x})\right)\right)^{\frac{1}{\sigma_{2}^{2}}}, which becomes an equality as \sigma_{2} \rightarrow 1. In practice, the network is trained to predict s:=\log \sigma^{2}, which is more numerically stable than regressing \sigma^{2} directly.

Code implementation
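A PyTorch sketch of the derived two-task objective (the module and parameter names are mine); \log \sigma_{i} is learned jointly with the network weights so that \sigma_{i} stays positive:

    import torch
    import torch.nn as nn

    class RegClsUncertaintyLoss(nn.Module):
        """L = 1/(2*sigma1^2) * L1 + 1/sigma2^2 * L2 + log(sigma1) + log(sigma2)."""
        def __init__(self):
            super().__init__()
            # learn log(sigma) so that sigma stays positive
            self.log_sigma1 = nn.Parameter(torch.zeros(()))
            self.log_sigma2 = nn.Parameter(torch.zeros(()))

        def forward(self, loss_reg, loss_cls):
            inv_var1 = torch.exp(-2.0 * self.log_sigma1)   # 1 / sigma1^2
            inv_var2 = torch.exp(-2.0 * self.log_sigma2)   # 1 / sigma2^2
            return (0.5 * inv_var1 * loss_reg + inv_var2 * loss_cls
                    + self.log_sigma1 + self.log_sigma2)

Its two parameters are simply registered with the same optimizer as the network weights.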

Easy implementation
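An even simpler variant commonly used in practice (a sketch, not the paper's exact objective): learn s_{i}=\log \sigma_{i}^{2} for each of K tasks and weight every loss by \exp \left(-s_{i}\right), dropping the task-specific 1/2 factors:

    import torch
    import torch.nn as nn

    class UncertaintyWeighting(nn.Module):
        """Generic K-task version: sum_i exp(-s_i) * L_i + s_i, s_i = log(sigma_i^2)."""
        def __init__(self, num_tasks):
            super().__init__()
            self.log_vars = nn.Parameter(torch.zeros(num_tasks))

        def forward(self, losses):
            total = 0.0
            for s, loss in zip(self.log_vars, losses):
                total = total + torch.exp(-s) * loss + s
            return total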

 

DWA

ref: End-to-End Multi-Task Learning with Attention.
Dynamic Weight Average (DWA) computes the task weights from the numerical loss values alone, without accessing the network's internal gradients as GradNorm does. The temperature T controls the softness of the task weighting: a larger T produces a more even distribution across tasks.
Loss formulation:

    \[ \mathcal{L}_{t o t}\left(\mathbf{X}, \mathbf{Y}_{1: K}\right)=\sum_{i=1}^{K} \lambda_{i} \mathcal{L}_{i}\left(\mathbf{X}, \mathbf{Y}_{i}\right) \]

Weight update formula:

    \[ \lambda_{k}(t):=\frac{K \exp \left(w_{k}(t-1) / T\right)}{\sum_{i} \exp \left(w_{i}(t-1) / T\right)}, \quad w_{k}(t-1)=\frac{\mathcal{L}_{k}(t-1)}{\mathcal{L}_{k}(t-2)} \]

Here w_{k}(\cdot) estimates the relative descending rate of loss k, \mathcal{L}_{k}(t) is the loss averaged over an epoch, and w_{k}(t)=1 is used for the first two epochs.

Code implementation:
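A NumPy sketch (the function name is mine), assuming loss_history holds the per-epoch average losses of the K tasks; the paper uses T = 2:

    import numpy as np

    def dwa_weights(loss_history, T=2.0):
        """loss_history: list with one vector of the K average task losses per epoch.
        Returns lambda_k of length K (summing to K)."""
        K = len(loss_history[-1])
        if len(loss_history) < 2:       # w_k = 1 for the first two epochs
            return np.ones(K)
        w = np.asarray(loss_history[-1], float) / np.asarray(loss_history[-2], float)
        e = np.exp(w / T)
        return K * e / e.sum()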

 

CoV-Weighting

We generally consider a loss satisfied when its variance has decreased towards zero (the loss has become roughly constant) and the network cannot be optimized further. However, variance alone is not suitable in the multi-task scenario: a loss with a larger magnitude tends to have a larger variance, and vice versa. Accordingly, the authors propose using the coefficient of variation c_{\mathcal{L}} of a loss \mathcal{L}, which expresses the variability of the observed loss relative to its (observed) mean:

    \[ c_{\mathcal{L}}=\frac{\sigma_{\mathcal{L}}}{\mu_{\mathcal{L}}} \]

Instead of the loss value itself, the loss ratio common in the literature can serve as the observed statistic (giving two variants: loss or loss ratio):

    \[ \ell_{t}=\frac{\mathcal{L}_{t}}{\mu_{\mathcal{L}_{t-1}}} \]

The weight \alpha_{i t} for loss \mathcal{L}_{i} at time step t is based on the coefficient of variation of its loss ratio, normalized by z_{t}=\sum_{i} c_{\ell_{i t}} so that the weights sum to one:

    \[ \alpha_{i t}=\frac{1}{z_{t}} c_{\ell_{i t}}=\frac{1}{z_{t}} \frac{\sigma_{\ell_{i t}}}{\mu_{\ell_{i t}}} \]

Online estimation of the loss-ratio and the coefficient of variation using Welford’s algorithm:

    \[ \begin{aligned} \mu_{\mathcal{L}_{t}} &=\left(1-\frac{1}{t}\right) \mu_{\mathcal{L}_{t-1}}+\frac{1}{t} \mathcal{L}_{t}, \\ \mu_{\ell_{t}} &=\left(1-\frac{1}{t}\right) \mu_{\ell_{t-1}}+\frac{1}{t} \ell_{t}, \quad \text { and } \\ \boldsymbol{M}_{\ell_{t}} &=\left(1-\frac{1}{t}\right) \boldsymbol{M}_{\ell_{t-1}}+\frac{1}{t}\left(\ell_{t}-\mu_{\ell_{t-1}}\right)\left(\ell_{t}-\mu_{\ell_{t}}\right), \end{aligned} \]

The standard deviation is given by \sigma_{\ell_{t}}=\sqrt{\boldsymbol{M}_{\ell_{t}}}. Assuming converging losses and ample training iterations, the online mean and standard deviation converge to the true statistics of the observed losses over the data. One drawback of this approach is that the variance is smoothed out over time, so a decaying online estimate may be needed.
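A sketch of the online estimator (class name mine) following the Welford updates above; on the first step the loss ratio is undefined, so the weights start uniform:

    import numpy as np

    class CoVWeighting:
        def __init__(self, num_losses):
            self.t = 0
            self.mu_L = np.zeros(num_losses)   # running mean of the raw losses
            self.mu_l = np.ones(num_losses)    # running mean of the loss ratios
            self.M_l = np.zeros(num_losses)    # running variance of the loss ratios

        def weights(self, losses):
            """Update the running statistics with the current losses and
            return the weights alpha_it = c_l / z_t."""
            losses = np.asarray(losses, dtype=float)
            self.t += 1
            t = self.t
            # loss ratio l_t = L_t / mu_{L_{t-1}}; undefined on the first step
            l = losses / self.mu_L if t > 1 else np.ones_like(losses)
            mu_prev = self.mu_l.copy()
            self.mu_l = (1 - 1 / t) * self.mu_l + (1 / t) * l
            self.M_l = (1 - 1 / t) * self.M_l + (1 / t) * (l - mu_prev) * (l - self.mu_l)
            self.mu_L = (1 - 1 / t) * self.mu_L + (1 / t) * losses
            cov = np.sqrt(self.M_l) / np.maximum(self.mu_l, 1e-12)   # sigma / mu
            z = cov.sum()
            return cov / z if z > 0 else np.full_like(cov, 1.0 / len(cov))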

 

HMoE

A combination of all of the above (self research).
