MTCNN网络的训练

MTCNN三个网络分别训练,其中置信度和偏置是用不同的样本进行训练,置信度用正样本和负样本进行训练,偏移用正样本和部分样本进行训练:

import torch
from torch import nn
from torch import optim
from DataSet import MyDataSet
from Net import Pnet,Rnet,Onet
import DataSet
from torch.utils import dataclass Train:def __init__(self,p_textpath,n_textpath,t_textpath,p_imgpath,n_imgpath,t_imgpath,net):self.p_textpath = p_textpathself.p_imgpath = p_imgpathself.n_textpath = n_textpathself.n_imgpath = n_imgpathself.t_textpath = t_textpathself.t_imgpath = t_imgpathself.net = net#创建训练数据集dataset = MyDataSet(p_textpath,n_textpath,t_textpath,p_imgpath,n_imgpath,t_imgpath)self.dataloader = data.DataLoader(dataset,batch_size=10,shuffle=True)def train(self):if self.net == 'pnet':net = Pnet()elif self.net == 'rnet':net = Rnet()elif self.net == 'onet':net = Onet()optimizer = optim.Adam(net.parameters())conf_loss_fun = nn.BCELoss()off_loos_fun = nn.MSELoss()for epoch in range(1000):imgdata, conf, offset = DataSet.GetIter(self.dataloader)confidence,offset_out = net(imgdata)#置信度的损失需要正负样本#获得置信度小于2的掩码conn_mask = torch.lt(conf,2)#得到符合条件的置信度conf_ = conf[conn_mask]confidence_ = confidence[conn_mask]#偏移的损失需要正样本和部分样本#得到置信度大于0的掩码off_mask = torch.gt(conf,0)#得到符合条件的偏移offset = offset[off_mask[:,0]]offset_out = offset_out[off_mask[:,0]]conf_loss = conf_loss_fun(confidence_,conf_)off_loss = off_loos_fun(offset_out,offset)loss = conf_loss + off_lossoptimizer.zero_grad()loss.backward()optimizer.step()print(loss)train=Train(DataSet.p_48txtpath,DataSet.n_48txtpath,DataSet.t_48txtpath,DataSet.p_48imgpath,DataSet.n_48imgpath,DataSet.t_48imgpath,'onet')
train.train()

 


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部