深度学习方法——NLLloss简单概括

网上很多文章介绍很详细啊,但是还顺带介绍了各种参数等等,如果只是想要了解他的作用,看着未免太费劲,所以我就一个公式简单总结一下:

\large \mathbf{NLL(log(softmax(input)),target) = -\Sigma_{i=1}^n OneHot(target)_i\times log(softmax(input)_i)}

\large (\mathbf{input\in R^{m\times n}})

不难看出NLLloss+log+softmax就是CrossEntropyLoss(softmax版的交叉熵损失函数),而其中的NLLloss就是在做交叉熵损失函数的最后一步:预测结果的取负求和。

而且它还顺带还帮你省了个OneHot编码,因为它是直接在 log(softmax(input)) 矩阵中,取出每个样本的target值对应的下标位置(该位置在onehot中为1,其余位置在onehot中为0)的预测结果进行取负求和运算。

纯文字的描述又会导致抽象,所以最后上两段代码来说明一下:

import torch
from torch import nn

# NLLLoss+LogSoftmax
# logsoftmax=log(softmax(x))
m = nn.LogSoftmax(dim=1) #横向计算
loss = nn.NLLLoss()
torch.manual_seed(2)
# 3行5列的输入,即3个样本各包含5个特征,每个样本通过softmax产生5个输出
input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, 0, 4])
# NLL将取输出矩阵中第0行的第1列、第1行的第0列、第2行的第4列加负号求和
output = loss(m(input), target)
output
tensor(2.1280, grad_fn=<NllLossBackward0>)
import torch
from torch import nn

# CrossEntropyLoss交叉熵损失
# 等价于NLLloss+log+softmax
loss = nn.CrossEntropyLoss()
torch.manual_seed(2)
input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, 0, 4])
output = loss(input, target)
output
tensor(2.1280, grad_fn=<NllLossBackward0>)

可见交叉熵损失函数的与logsoftmax+NLL损失函数结果一致,我们可以借此理解 CrossEntropyLoss 的底层实现。

来源:时生丶

物联沃分享整理
物联沃-IOTWORD物联网 » 深度学习方法——NLLloss简单概括

发表评论