Pytorch读取npy数据格式,编写dataset模块,可配合Dataloader进行使用

        在训练模型前,最重要的部分就是制作好数据集,有些情况下,由于图片数据过多,然后存储很不方便,我们就需要将数据制作成npy类型的数据格式。npy数据格式是一个四维的数组[N,H,W, C],其中N代表数据集的总数,H, W,C分别代表每一张图片对应的长、宽、以及通道数。

数据制作好之后,就是如何加载数据问题,TF中加载数据相对比较容易,但是Pytorch中,我们一般都是将数据制作成dataset,再传入Dataloader进行加载,因此就需要继承Dataset的类,然后编写读取npy的数据格式。Dataset中,我们需要定义三个函数。

一、__init__(self,data) 函数

主要是用来加载npy数据的,也可以加载数据预处理的函数,比如将数据转化为tensor之类的操作

 def __init__(self, data):self.data = np.load(data) #加载npy数据self.transforms = transform #转为tensor形式

二、__len__(self)函数

这个函数就是用来返回数据的总个数

 def __len__(self):return self.data.shape[0] #返回数据的总个数

三、 __getitem__(self,index)函数

这个是最要的函数,类似一个for循环,从头开始,每次读取一个保存在npy里面的数据,然后进行处理后,可以同时返回训练数据,以及对应的标签

    def __getitem__(self, index):hdct= self.data[index, :, :, :]  # 读取每一个npy的数据hdct = np.squeeze(hdct)  # 删掉一维的数据,就是把通道数这个维度删除ldct = 2.5 * skimage.util.random_noise(hdct * (0.4 / 255), mode='poisson', seed=None) * 255 #加poisson噪声hdct=Image.fromarray(np.uint8(hdct)) #转成image的形式ldct=Image.fromarray(np.uint8(ldct)) #转成image的形式hdct= self.transforms(hdct)  #转为tensor形式ldct= self.transforms(ldct)  #转为tensor形式return ldct,hdct #返回数据还有标签

完整的代码如下:

import torch
import numpy as np
import skimage
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
torch.manual_seed(1)  # reproducibletransform = transforms.Compose([transforms.ToTensor(),  # 将图片转换为Tensor,归一化至[0,1]
])
'''NPY数据格式'''
class MyDataset(Dataset):def __init__(self, data):self.data = np.load(data) #加载npy数据self.transforms = transform #转为tensor形式def __getitem__(self, index):hdct= self.data[index, :, :, :]  # 读取每一个npy的数据hdct = np.squeeze(hdct)  # 删掉一维的数据,就是把通道数这个维度删除ldct = 2.5 * skimage.util.random_noise(hdct * (0.4 / 255), mode='poisson', seed=None) * 255 #加poisson噪声hdct=Image.fromarray(np.uint8(hdct)) #转成image的形式ldct=Image.fromarray(np.uint8(ldct)) #转成image的形式hdct= self.transforms(hdct)  #转为tensor形式ldct= self.transforms(ldct)  #转为tensor形式return ldct,hdct #返回数据还有标签def __len__(self):return self.data.shape[0] #返回数据的总个数def main():dataset=MyDataset('.\data_npy\img_covid_poisson_glay_clean_BATCH_64_PATS_100.npy')data= DataLoader(dataset, batch_size=64, shuffle=True, pin_memory=True)if __name__ == '__main__':main()

 


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部