tensorflow分类的loss函数_Tensorflow入门教程(三十三)——图像分割损失函数FocalLoss...

常见的图像分割损失函数有交叉熵,dice系数,FocalLoss等。今天我将分享图像分割FocalLoss损失函数及Tensorflow版本的复现。

1、FocalLoss介绍

FocalLoss思想出自何凯明大神的论文《Focal Loss for Dense Object Detection》,主要是为了解决one-stage目标检测中正负样本比例严重失衡的问题。

FocalLoss是在交叉熵函数的基础上进行的改进,改进的地方主要在两个地方

(1)、改进第一点如下公式所示。

a014ecdbe0869226a337f1f89be6f846.png

首先在原有交叉熵函数基础上加了一个权重因子,其中gamma>0,使得更关注于困难的、错分的样本。比如:若 gamma = 2,对于正类样本来说,如果预测结果为0.97,那么肯定是易分类的样本,权重值为0.0009,损失函数值就会很小了;对于正类样本来说,如果预测结果为0.3,那么肯定是难分类的样本,权重值为0.49,其损失函数值相对就会很大;对于负类样本来说,如果预测结果为0.8,那么肯定是难分类的样本,权重值为0.64,其损失函数值相对就会很大;对于负类样本来说,如果预测结果为0.1,那么肯定是易分类的样本,权重值为0.01,其损失函数值就会很小。而对于预测概率为0.5时,损失函数值只减少了0.25倍,所以FocalLoss减少了简单样本的影响从而更加关注于难以区分的样本。

(2)、改进第二点如下公式所示。

0bc85e8e73fd936eac6df35b8d9b3326.png

FocalLoss还引入了平衡因子alpha,用来平衡正负样本本身的比例不均匀。alpha取值范围0~1,当alpha>0.5时,可以相对增加y=1所占的比例,保证正负样本的平衡。

(3)、虽然在何凯明的试验中, 认为gamma为2是最优的,但是不代表这个参数适合其他样本,在实际应用中还需要根据实际情况调整这两个参数:alpha和gamma。

2、FocalLoss公式推导

在github上已经可以找到很多FocalLoss的实现,如下二分类的FocalLoss实现。实现其实不是很难,但是在实际训练时会出现NAN的现象。

520924544d4a41791c3ec48a4dfaee23.png

下面将简单推导一下FocalLoss函数在二分类时的函数表达式。

FocalLoss函数可以表示如下公式所示:

423718196fd861dd1f0dd66f206fe2cb.png

假设网络的最后输出采用逻辑回归函数sigmod,对于二分类问题(0和1),预测输出可以表示为:

ddb05ec6c998406e4cd260ffb8a4e223.pnga01e66149add8cde4d8d93973933467a.png

将上述公式带入FocalLoss函数中,并进行推导。

6738b048a976805404177917a017d8d4.png

3、FocalLoss代码实现

按照上面导出的表达式FocalLoss的伪代码可以表示为:

380a886c6cff5241bd12a61544c56ea6.png

其中,

19c718a6c789aff53d85c116beb86ff8.png

24a5a6a98037eb394574387d470770aa.png

fcca2f7622c063ee576bdebb6eee0093.png

从这里可以看到1-y_pred项可能为0或1,这会导致log函数值出现NAN现象,所以好需要对y_pred项进行固定范围值的截断操作。最后在TensorFlow1.8下实现了该函数。

import tensorflow as tfdef focal_loss(y_true, y_pred, alpha=0.25, gamma=2):    epsilon = 1e-5    y_pred = tf.clip_by_value(y_pred, epsilon, 1 - epsilon)    logits = tf.log(y_pred / (1 - y_pred))    weight_a = alpha * tf.pow((1 - y_pred), gamma) * y_true    weight_b = (1 - alpha) * tf.pow(y_pred, gamma) * (1 - y_true)    loss = tf.log1p(tf.exp(-logits)) * (weight_a + weight_b) + logits * weight_b    return tf.reduce_mean(loss)


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部