对于训练时loss出现负值的情况

在这里插入图片描述
在训练时候loss出现负值,就立马停下来分析一下原因在哪。最有可能是损失函数出现问题,开始只使用交叉熵损失时没有出现过,在加上了dice loss时就出现了问题。于是就去dice loss中寻找原因。
1:首先需要明白语义分割的GT,每一个像素点的值就是像素的类别。

# -*- coding: utf-8 -*-
import numpy as np
from torchvision import transforms
import torch
from PIL import Image
img = Image.open('C:/Users/翰墨大人/Desktop/0003_lable.png') #图像所在位置
img1 = np.array(img)
img1 = torch.from_numpy(img1).type(torch.FloatTensor)
# trans = transforms.ToTensor()
# img1 = trans(img1)a = torch.unique(img1) # 查看图片内的像素值
print(a)
print(img.mode) # 查看图片模式

打印结果:

tensor([ 0.,  1.,  5.,  7.,  8., 26., 29., 38., 40.])
P

原图一共有四十个类别,在003_lable这张GT上只出现了上述的类别。剩下的像素点和上述像素点是重复的。
注意:
这里将np格式转换为tensor格式时候,不能使用transforms.ToTensor(),他会将像素值发生改变。这样新的像素值点和类别就不是一一对应的。

tensor([0.0000, 0.0039, 0.0196, 0.0275, 0.0314, 0.1020, 0.1137, 0.1490, 0.1569])
P

2:语义分割的GT一共有四十类,没有通道,而在模型中pred的输出为一个通道。
在计算二分类dice loss时候,首先要将pred进行sigmoid。GT是四十个类别,如果是二分类的话那么标签就必须是[0,1]。而loss为负值的原因就是没有将标签转换为[0,1]。
X是pred,Y是GT,分子就是X和Y进行矩阵的相乘再相加,当Y中含有大于1的类别,比如30,40的话,而X又是进过sigmoid之后再(0,1)之内,那么分子除以分母的值就会大于1,造成dice loss就变成了负值。
在这里插入图片描述
如何将GT变为[0,1]呢?因为我需要的是对GT进行提边,使用一个拉普拉斯对GT进行卷积,然后再使用一个阈值,大于阈值为1,小于为0。为是边缘和不是边缘。GT就又灰度图像转换为了二值图像。

l = F.conv2d(x,sobel,padding=1,stride=1)
print(l.shape)
ll = torch.unique(l) # 查看图片内的像素值
print(ll)l[l>0.1]=1
l[l<0.1]=0
l_ = torch.unique(l) # 查看图片内的像素值
print(l_)

结果:
在这里插入图片描述
对于多分类的dice loss:
GT需要进行one-hot编码,这样一个通道,像素点为(0-40)的GT就会变为四十个通道,每个通道像素点为(0,1),每个通道都可以看做一个二分类问题,属于该类别和不属于该类别。而pred的输出通道也为40,通过计算pred的每个通道和GT的每个通道的loss,最后求均值得到总loss。
3:代码代码参考
单分类代码:

import torch
import torch.nn as nnclass BinaryDiceLoss(nn.Model):def __init__(self):super(BinaryDiceLoss, self).__init__()def forward(self, input, targets):# 获取每个批次的大小 NN = targets.size()[0]# 平滑变量smooth = 1# 将宽高 reshape 到同一纬度input_flat = input.view(N, -1)targets_flat = targets.view(N, -1)# 计算交集intersection = input_flat * targets_flat N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth)# 计算一个批次中平均每张图的损失loss = 1 - dice_eff.sum() / Nreturn loss

多分类代码:

import torch
import torch.nn as nnclass MultiClassDiceLoss(nn.Module):def __init__(self, weight=None, ignore_index=None, **kwargs):super(MultiClassDiceLoss, self).__init__()self.weight = weightself.ignore_index = ignore_indexself.kwargs = kwargsdef forward(self, input, target):"""input tesor of shape = (N, C, H, W)target tensor of shape = (N, H, W)"""# 先将 target 进行 one-hot 处理,转换为 (N, C, H, W)nclass = input.shape[1]target = one_hot(target.long(), nclass)assert input.shape == target.shape, "predict & target shape do not match"binaryDiceLoss = BinaryDiceLoss()total_loss = 0# 归一化输出logits = F.softmax(input, dim=1)C = target.shape[1]# 遍历 channel,得到每个类别的二分类 DiceLossfor i in range(C):dice_loss = binaryDiceLoss(logits[:, i], target[:, i])total_loss += dice_loss# 每个类别的平均 dice_lossreturn total_loss / C


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部