PYTORCH:DenseNet做新冠肺炎CT照片是否确诊分类
完整项目代码:https://github.com/SPECTRELWF/pytorch-cnn-study
DenseNet网络结构

DenseNet是清华大学黄高教授发表在CVPR 2017上的工作,在ResNet提出的第二年提出,并获得了当年的最佳论文奖。
数据集描述
数据集使用的是来自格物钛的一个公开数据集,数据集下载地址:https://gas.graviti.cn/dataset/data-decorators/COVID_CT
里面包含715张图片,包含确诊和未确诊的,比例大概一比一,图像是处理过的CT图像。

网络结构
使用PyTorch的torchvision里面提供的densenet121(),未使用预训练模型。在其1000维输出后面再加上两层全连接层(1000→512→2):
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/9 4:57 PM
import torchvision
import torch.nn as nn


class my_densenet(nn.Module):
    """DenseNet-121 backbone with a two-layer linear head producing 2 logits.

    The backbone is created with ``pretrained=False``, i.e. no pretrained
    weights are loaded. Its 1000-wide output is passed through ``fc2``
    (1000 -> 512) and ``fc3`` (512 -> 2).

    Attribute names (``backbone``, ``fc2``, ``fc3``) are the keys of the
    saved ``state_dict``; renaming them makes existing checkpoints unloadable.
    """

    def __init__(self):
        super(my_densenet, self).__init__()
        self.backbone = torchvision.models.densenet121(pretrained=False)
        self.fc2 = nn.Linear(1000, 512)
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        """Return the 2-class logits for the input batch ``x``."""
        x = self.backbone(x)
        # No activation is applied between fc2 and fc3, so the head is a
        # composition of two linear maps.
        x = self.fc2(x)
        x = self.fc3(x)
        return x
train:
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/9 4:48 PM
import os

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.utils.data as data
from torch.utils.data import DataLoader
from dataload.COVID_Dataload import COVID
from densenet import my_densenet
from torch import nn, optim
# Imported up front so a missing helper fails before training, not after it.
from utils import plot_curve

# Named train_transform so the `transforms` module is not shadowed.
train_transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

batch_size = 32
train_set = COVID(transformer=train_transform, train=True)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
epochs = 200
lr = 1e-4

# .to(device) works for both CPU and CUDA; .cuda(device) fails on CPU.
net = my_densenet().to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

# Per-batch loss history, plotted once training finishes.
train_loss = []

# torch.save does not create missing directories.
checkpoint_dir = 'model/no_pretrain'
os.makedirs(checkpoint_dir, exist_ok=True)

net.train()
for epoch in range(epochs):
    for batch_idx, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        pred = net(x)
        loss = loss_func(pred, y)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss.append(batch_loss)
        print("epoch:%d , batch:%d , loss:%.3f" % (epoch, batch_idx, batch_loss))

    # One checkpoint per epoch, numbered from 1.
    torch.save(net.state_dict(), checkpoint_dir + '/epoch' + str(epoch + 1) + '.pth')

plot_curve(train_loss)
训练loss

test:
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/4 1:29 PM
import torch
import torchvision
from dataload.COVID_Dataload import COVID
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from densenet import my_densenet

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Evaluation uses no random augmentation, so repeated runs give the same
# accuracy.
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
])

test_dataset = COVID(train=False, transformer=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


def predict(checkpoint_path='model/no_pretrain/epoch200.pth'):
    """Evaluate a saved checkpoint on the test split and print its accuracy.

    Args:
        checkpoint_path: Path to a ``state_dict`` file for ``my_densenet``.
            The default points at the directory the training script writes
            its per-epoch checkpoints to.

    Returns:
        The fraction of test samples whose arg-max prediction equals the
        label.
    """
    net = my_densenet().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    net.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # Inference mode for BatchNorm/Dropout layers.
    net.eval()
    print(net)

    total_correct = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            print(x.shape)
            y = y.to(device)
            print('y', y)
            out = net(x)
            pred = out.argmax(dim=1)
            print('pred', pred)
            correct = pred.eq(y).sum().float().item()
            total_correct += correct

    total_num = len(test_loader.dataset)
    acc = total_correct / total_num
    print("test acc:", acc)
    return acc


predict()

predict
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/4 2:38 PM
# Read an image file and display it with the predicted class.
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from densenet import my_densenet

# No random augmentation at inference time, so the same file always yields
# the same prediction.
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
])

file_name = input("输入要预测的文件名:")
# convert() returns a new in-memory image, so the file handle can be closed.
with Image.open(file_name) as opened:
    show_img = opened.convert("RGB")
img = transform(show_img)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Add the batch dimension the network expects.
img = img.to(device).unsqueeze(0)

net = my_densenet().to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
net.load_state_dict(torch.load(r'model/no_pretrain/epoch200.pth', map_location=device))
# Inference mode for BatchNorm/Dropout layers.
net.eval()

with torch.no_grad():
    pred = net(img)
print(pred)

pred_class = pred.argmax(dim=1).cpu().numpy()[0]
print(pred_class)

# Class index 0 is reported as "no_covid", any other index as "covid".
res = 'pred:no_covid' if pred_class == 0 else 'pred:covid'

plt.figure("Predict")
plt.imshow(show_img)
plt.axis("off")
plt.title(res)
plt.show()

本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!
