可视化学习笔记11-pytorch-GradCAM可视化自己的网络

续可视化笔记2-pytorch 可视化卷积网络中间特征层的基础上使用CAM方法可视化网络对待测对象关注的位置。

1.定义GradCAM类

注意:代码中需要改的3个地方已经用注释标清,大家使用时注意修改。

class GradCAM(nn.Module):def __init__(self):super(GradCAM, self).__init__()# 获取模型的特征提取层self.feature = nn.Sequential(OrderedDict({name: layer for name, layer in model.named_children()if name not in ['avgpool', 'classifier']#改1:根据自己的网络模型架构调整。}))# 获取模型最后的平均池化层self.avgpool = model.avgpool# 获取模型的输出层self.classifier = nn.Sequential(OrderedDict([('classifier', model.classifier)#改2:模型剩什么层就写什么层(这里我的网络除了avgpool就只剩classifier)。]))# 生成梯度占位符self.gradients = None# 获取梯度的钩子函数def activations_hook(self, grad):self.gradients = graddef forward(self, x):x = self.feature(x)# 注册钩子h = x.register_hook(self.activations_hook)# 对卷积后的输出使用平均池化x = self.avgpool(x)x = x.view((1, -1))x = self.classifier(x)#改3:同2return x# 获取梯度的方法def get_activations_gradient(self):return self.gradients# 获取卷积层输出的方法def get_activations(self, x):return self.feature(x)

2.获取热力图

# 获取热力图
def get_heatmap(model, img):model.eval()img_pre = model(img)# 获取预测最高的类别pre_class = torch.argmax(img_pre, dim=-1).item()# 获取相对于模型参数的输出梯度img_pre[:, pre_class].backward()# 获取模型的梯度gradients = model.get_activations_gradient()# 计算梯度相应通道的均值mean_gradients = torch.mean(gradients, dim=[0, 2, 3])# 获取图像在相应卷积层输出的卷积特征activations = model.get_activations(input_im).detach()# 每个通道乘以相应的梯度均值for i in range(len(mean_gradients)):activations[:, i, :, :] *= mean_gradients[i]# 计算所有通道的均值输出得到热力图heatmap = torch.mean(activations, dim=1).squeeze()# 使用Relu函数作用于热力图heatmap = F.relu(heatmap)# 对热力图进行标准化heatmap /= torch.max(heatmap)heatmap = heatmap.numpy()return heatmap
cam = GradCAM()
# 获取热力图
heatmap = get_heatmap(cam, input_im)
# 可视化热力图
plt.matshow(heatmap)
plt.show()

3.显示结果

# 合并热力图和原图,并显示结果
def merge_heatmap_image(heatmap, image_path):img = cv2.imread(image_path)heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))heatmap = np.uint8(255 * heatmap)heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)grad_cam_img = heatmap * 0.7 + imggrad_cam_img = grad_cam_img / grad_cam_img.max()# 可视化图像b,g,r = cv2.split(grad_cam_img)grad_cam_img = cv2.merge([r,g,b])plt.figure(figsize=(8,8))plt.imshow(grad_cam_img)plt.axis('off')plt.savefig("./CAM/CBAM_fig2")plt.show()
merge_heatmap_image(heatmap, img_path)

通过修改img_path变量,重新运行代码,及可以得到其他图片的CAM结果。这里仅展示部分结果:
在这里插入图片描述


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部