您现在的位置是:首页 >技术杂谈 >MTN模型LOSS均衡相关论文解读网站首页技术杂谈
MTN模型LOSS均衡相关论文解读
一、综述
MTN模型主要用于两个方面,1.将多个模型合为一个显著降低车载芯片负载。2.将多个任务模型合为一个,有助于不同模型在共享层的特征可以进行互补,提高模型泛化性能的同时,也有可能提高指标。传统的方法是直接不同任务loss相加或者人为设置权重,这样很费时,也很难找到最优解。接下来的论文将会为大家介绍一些更优秀的MTN方法。
二、依据任务不确定性加权多任务损失
论文地址: https://arxiv.org/pdf/1705.07115.pdf
文章通过对损失函数求最大似然估计,同时引入不同任务的不确定性(可以理解为噪声),最大似然估计的推理结果如下:
1.两个任务都为回归任务
2.一个任务为回归任务,一个任务为分类任务
可以看到损失是由不确定性估计的倒数来加权的,后面的log(不确定性)是为了防止不确定性变得太大(类似于正则项)。当模型不确定性变小后,任务权重会增大,造成无效学习,所以论文里使用的(annealing the lr with a power law) 不会翻译。。。。。同时论文里也表明这个不确定性的初始化很鲁棒,都可以收敛的很好。
代码:
log_vars = nn.Parameter(torch.zeros((2)))
def criterion(y_pred, y_true, log_vars):
loss = 0
for i in range(len(y_pred)):
weight = torch.exp(-log_vars[i])
diff = (y_pred[i]-y_true[i])**2.
loss += torch.sum(weight * diff + log_vars[i], -1)
return torch.mean(loss)
其中diff表示一个任务的loss,log_vars是可学习的参数。权重为weight = torch.exp(-log_vars[i]),后面加个log_vars[i]是一个惩罚项,防止任务不确定性变得太大,导致权重很小不更新。大家跟论文里对比一下会发现,模型学习参数log_vars[i] = log(不确定性**2),使用log是为了让其更加平滑稳定,便于学习。
在实际应用中,训mtn有的时候loss会变负,就是因为log_vars有可能为负的,把loss带成负的了。
所以后续有论文对其进行了改进,论文地址:https://arxiv.org/pdf/1805.06334.pdf
将正则项变为log(1+log_vars**2),这样正则项就不会为负,同时还能起到正则化效果。具体代码如下:
class AutomaticWeightedLoss(nn.Module):
def __init__(self, num=2):
super(AutomaticWeightedLoss, self).__init__()
params = torch.ones(num, requires_grad=True)
self.params = torch.nn.Parameter(params)
def forward(self, *x):
loss_sum = 0
for i, loss in enumerate(x):
loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2)
return loss_sum