您现在的位置是:首页 >技术杂谈 >MTN模型LOSS均衡相关论文解读网站首页技术杂谈

MTN模型LOSS均衡相关论文解读

CVplayer111 2024-06-15 06:31:11
简介MTN模型LOSS均衡相关论文解读

一、综述

MTN模型主要用于两个方面,1.将多个模型合为一个显著降低车载芯片负载。2.将多个任务模型合为一个,有助于不同模型在共享层的特征可以进行互补,提高模型泛化性能的同时,也有可能提高指标。传统的方法是直接不同任务loss相加或者人为设置权重,这样很费时,也很难找到最优解。接下来的论文将会为大家介绍一些更优秀的MTN方法。

二、依据任务不确定性加权多任务损失

论文地址: https://arxiv.org/pdf/1705.07115.pdf

 文章通过对损失函数求最大似然估计,同时引入不同任务的不确定性(可以理解为噪声),最大似然估计的推理结果如下:

1.两个任务都为回归任务

 2.一个任务为回归任务,一个任务为分类任务

 可以看到损失是由不确定性估计的倒数来加权的,后面的log(不确定性)是为了防止不确定性变得太大(类似于正则项)。当模型不确定性变小后,任务权重会增大,造成无效学习,所以论文里使用的(annealing the lr with a power law) 不会翻译。。。。。同时论文里也表明这个不确定性的初始化很鲁棒,都可以收敛的很好。

代码:

log_vars = nn.Parameter(torch.zeros((2)))

def criterion(y_pred, y_true, log_vars):
  loss = 0
  for i in range(len(y_pred)):
    weight = torch.exp(-log_vars[i])
    diff = (y_pred[i]-y_true[i])**2.
    loss += torch.sum(weight * diff + log_vars[i], -1)
return torch.mean(loss)

其中diff表示一个任务的loss,log_vars是可学习的参数。权重为weight = torch.exp(-log_vars[i]),后面加个log_vars[i]是一个惩罚项,防止任务不确定性变得太大,导致权重很小不更新。大家跟论文里对比一下会发现,模型学习参数log_vars[i] = log(不确定性**2),使用log是为了让其更加平滑稳定,便于学习。

在实际应用中,训mtn有的时候loss会变负,就是因为log_vars有可能为负的,把loss带成负的了。

所以后续有论文对其进行了改进,论文地址:https://arxiv.org/pdf/1805.06334.pdf

 将正则项变为log(1+log_vars**2),这样正则项就不会为负,同时还能起到正则化效果。具体代码如下:

class AutomaticWeightedLoss(nn.Module):
    def __init__(self, num=2):
        super(AutomaticWeightedLoss, self).__init__()
        params = torch.ones(num, requires_grad=True)
        self.params = torch.nn.Parameter(params)

    def forward(self, *x):
        loss_sum = 0
        for i, loss in enumerate(x):
            loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2)
        return loss_sum

Reference 

多任务权重自动学习论文介绍和代码实现 - 知乎 (zhihu.com)

风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。