PyTorch's CrossEntropyLoss() Explained with Examples, and Computing the Loss with One-Hot Encoded Labels

When doing deep learning with the PyTorch framework, especially classification tasks, you will often see code like the following:

import torch.nn as nn

criterion = nn.CrossEntropyLoss().cuda()
loss = criterion(output, target)

That is, torch.nn.CrossEntropyLoss() is used as the loss function.

So what does nn.CrossEntropyLoss() actually do internally?

nn.CrossEntropyLoss() is a class packaged in torch.nn, corresponding to cross_entropy in torch.nn.functional.

Moreover, nn.CrossEntropyLoss() is the combination of nn.LogSoftmax() and nn.NLLLoss() (the two fused into a single class).

nn.LogSoftmax()

It is defined as:

$$\text{LogSoftmax}(x_i) = \log\!\left(\frac{\exp(x_i)}{\sum_j \exp(x_j)}\right)$$

From the formula, it is simply softmax followed by log.
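A minimal sketch verifying this equivalence on arbitrary random scores (in practice log_softmax is preferred over a separate log(softmax(x)), which is less numerically stable):

import torch
import torch.nn as nn

x = torch.randn(2, 4)  # arbitrary scores: batch of 2, 4 classes
via_module = nn.LogSoftmax(dim=1)(x)
via_log_of_softmax = torch.log(torch.softmax(x, dim=1))
print(torch.allclose(via_module, via_log_of_softmax))  # True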

nn.NLLLoss()

It is defined as:

For log-probabilities $x$ and targets $y$, with the default mean reduction over a batch of size $N$:

$$\ell(x, y) = -\frac{1}{N}\sum_{n=1}^{N} x_{n,\,y_n}$$

This loss expects the target to contain class indices (0 to N-1, where N = number of classes).

Example 1:

import torch
import torch.nn as nn

m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# input is of size nBatch x nClasses = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
# each element in target must satisfy 0 <= value < nClasses
target = torch.tensor([1, 0, 4])
output = loss(m(input), target)

As you can see, nn.NLLLoss's target consists of class indices, not one-hot encodings. Keep this in mind!
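To make that concrete, here is a hand-rolled sketch of what NLLLoss computes (with the default mean reduction): gather the log-probability at each sample's target index and average the negated values.

import torch
import torch.nn.functional as F

log_probs = F.log_softmax(torch.randn(3, 5), dim=1)
target = torch.tensor([1, 0, 4])

manual = -log_probs[torch.arange(3), target].mean()  # pick entry y_n in row n
builtin = F.nll_loss(log_probs, target)
print(torch.allclose(manual, builtin))  # True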

nn.CrossEntropyLoss()

It is defined as:

$$\text{loss}(x, \text{class}) = -\log\!\left(\frac{\exp(x[\text{class}])}{\sum_j \exp(x[j])}\right) = -x[\text{class}] + \log\sum_j \exp(x[j])$$

Look closely at the formula: it is exactly nn.LogSoftmax() + nn.NLLLoss().

When calling it, the arguments are:

input: the model output containing a score for each class; a 2-D tensor of shape batch × nClasses.

target: a 1-D tensor of size batch, containing the class index (0 to nClasses - 1) for each sample.

Note that CrossEntropyLoss()'s target is likewise class indices, not one-hot encodings.
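A quick sketch checking the formula above against the built-in loss, on random scores:

import torch
import torch.nn.functional as F

x = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])

# -x[class] + log(sum_j exp(x[j])), averaged over the batch
manual = (-x[torch.arange(3), target] + torch.logsumexp(x, dim=1)).mean()
print(torch.allclose(manual, F.cross_entropy(x, target)))  # True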

Example 2:

import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
# input is of size nBatch x nClasses = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
# each element in target must satisfy 0 <= value < nClasses
target = torch.tensor([1, 0, 4])
output = loss(input, target)

Given the same input, Examples 1 and 2 produce identical results.

How do we compute the loss when the labels are one-hot encoded?

for images, target in train_loader:
    images, target = images.cuda(), target.cuda()
    N = target.size(0)  # N is the batch size
    # C is the number of classes
    labels = torch.full(size=(N, C), fill_value=0.0).cuda()
    labels.scatter_(dim=1, index=torch.unsqueeze(target, dim=1), value=1)
    score = model(images)
    log_prob = torch.nn.functional.log_softmax(score, dim=1)
    loss = -torch.sum(log_prob * labels) / N
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Here N is the batch size, C is the number of classes, and labels is a 2-D tensor holding the one-hot encoded targets.

To use this, the target from Examples 1 and 2 must first be converted into one-hot labels; the loss computed this way can then replace the loss computations in Examples 1 and 2.
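As a side note, on recent PyTorch versions torch.nn.functional.one_hot can replace the torch.full + scatter_ construction:

import torch
import torch.nn.functional as F

target = torch.tensor([1, 0, 4])
labels = F.one_hot(target, num_classes=5).float()  # one_hot returns int64; cast to float
print(labels)
# tensor([[0., 1., 0., 0., 0.],
#         [1., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 1.]])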

A complete runnable example covering all of the above:

import torch
import torch.nn as nn
import torch.nn.functional as F

# LogSoftmax + NLLLoss
m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, 0, 4])
output = loss(m(input), target)
print('logsoftmax + nllloss output is {}'.format(output.item()))

# CrossEntropyLoss (reusing the same input)
loss = nn.CrossEntropyLoss()
target = torch.tensor([1, 0, 4])
output = loss(input, target)
print('crossentropy output is {}'.format(output.item()))

# one-hot label loss
C = 5
target = torch.tensor([1, 0, 4])
print('target is {}'.format(target))
N = target.size(0)  # N is the batch size
# C is the number of classes
labels = torch.full(size=(N, C), fill_value=0.0)
print('labels shape is {}'.format(labels.shape))
labels.scatter_(dim=1, index=torch.unsqueeze(target, dim=1), value=1)
print('labels is {}'.format(labels))
log_prob = F.log_softmax(input, dim=1)
loss = -torch.sum(log_prob * labels) / N
print('N is {}'.format(N))
print('one-hot loss is {}'.format(loss.item()))

The output:

logsoftmax + nllloss output is 3.005390167236328
crossentropy output is 3.005390167236328
target is tensor([1, 0, 4])
labels shape is torch.Size([3, 5])
labels is tensor([[0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.]])
N is 3
one-hot loss is 3.005390167236328

As you can see, given the same input all three computations are equivalent.

Addendum:

The cross-entropy-related functions and their correspondence between torch.nn and torch.nn.functional are as follows:

nn.Softmax            ->  F.softmax
nn.LogSoftmax         ->  F.log_softmax
nn.NLLLoss            ->  F.nll_loss
nn.CrossEntropyLoss   ->  F.cross_entropy
nn.BCELoss            ->  F.binary_cross_entropy
nn.BCEWithLogitsLoss  ->  F.binary_cross_entropy_with_logits

The difference between torch.nn and torch.nn.functional is that each class in torch.nn is essentially a wrapper around the corresponding function in F (torch.nn.functional).
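For instance, nn.CrossEntropyLoss is just a module wrapping F.cross_entropy, so the two are interchangeable; a quick sketch:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])

module_loss = nn.CrossEntropyLoss()(x, target)  # class-based API
functional_loss = F.cross_entropy(x, target)    # functional API
print(torch.allclose(module_loss, functional_loss))  # True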

Reference

Original post: https://blog.csdn.net/c2250645962/article/details/106014693

