为 sketch 数据集生成 txt 列表文件，并自定义 Dataset

线稿上色的数据集:
dataset link:https://pan.baidu.com/s/1Abm7V6J2uNOy5U6nvsRSlg
key:eepv

txt文件生成
（原文此处为一张示意图，图片未保留）

import os
import glob


def Create_Txt(data_name, data_path, data_class, txt_path, ratio=0.01):
    """Write train/val/test list files for a paired image dataset.

    File names are discovered as ``<data_path>/<data_name>/<data_class[0]>/*.png``.
    For every selected name, one line is written that contains the matching
    path under each folder of ``data_class``, separated by single spaces, e.g.
    ``<root>/img/001.png <root>/label/001.png``.

    The selected samples are split 70% / 5% / 25% into ``train.txt``,
    ``val.txt`` and ``test.txt``. The last split takes the remainder, so the
    truncation done by ``int()`` never drops a sample.

    Args:
        data_name: dataset sub-folder name; also used as the output sub-folder.
        data_path: root folder that contains ``data_name``.
        data_class: folder names inside the dataset. The first one is used to
            discover file names. The others are expected (not checked) to hold
            files with the same names.
        txt_path: root output folder. Lists are written to
            ``<txt_path>/<data_name>/{train,val,test}.txt``.
        ratio: fraction of the discovered images to use (default 0.01).
    """
    # Absolute paths of the dataset folder and of the output folder.
    data_path = os.path.join(data_path, data_name)
    txt_path = os.path.join(txt_path, data_name)

    # Build the pattern with os.path.join so it works on every OS; a literal
    # "\*.png" only matches on Windows. Sort so the split is reproducible
    # regardless of the file system's listing order.
    imgs_path = sorted(glob.glob(os.path.join(data_path, data_class[0], "*.png")))
    num_data = int(len(imgs_path) * ratio)

    txt_class = ["train.txt", "val.txt", "test.txt"]
    txt_class_ratio = [0.7, 0.05, 0.25]
    os.makedirs(txt_path, exist_ok=True)

    last = len(txt_class) - 1
    start = 0
    for i, (txt_name, split_ratio) in enumerate(zip(txt_class, txt_class_ratio)):
        end = num_data if i == last else start + int(num_data * split_ratio)
        # `with` guarantees each list file is flushed and closed.
        with open(os.path.join(txt_path, txt_name), mode="w") as txt:
            for img in imgs_path[start:end]:
                name = os.path.basename(img)
                paths = (data_path + "/" + cls + "/" + name for cls in data_class)
                txt.write(" ".join(paths) + "\n")
        start = end


if __name__ == '__main__':
    current_path = os.getcwd()
    data_name = "sketch"
    data_path = current_path + "/data"
    data_class = ["img", "label"]
    txt_path = current_path + "/list"
    Create_Txt(data_name, data_path, data_class, txt_path)

使用 txt 列表文件读入数据可以减少内存占用：列表中只保存图片路径，图片在取样时才从磁盘读取。因此，有时自定义数据集的加载方式是非常必要的。
（原文此处为一张示意图，图片未保留）

自定义Dataset

from torch.utils.data import Dataset
import os
import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
import torchvision


def cuda(*args):
    """Call ``.cuda()`` on every argument.

    Returns a generator, so the result is meant to be unpacked, e.g.
    ``images, labels = cuda(images, labels)``.
    """
    return (item.cuda() for item in args)


class Sketch(Dataset):
    """Sketch-colorization dataset backed by a txt list file.

    Each non-empty line of the list file holds whitespace-separated paths:
    ``<img path> <label path>``. Only the paths are kept in memory; images are
    read from disk when a sample is requested.

    Args:
        list_path: folder that contains ``train.txt``, ``val.txt`` and ``test.txt``.
        mode: ``"train"`` or ``"test"``; any other value selects ``val.txt``.
    """

    def __init__(self, list_path, mode="train"):
        self.mode = mode
        if mode == "train":
            list_path = os.path.join(list_path, "train.txt")
        elif mode == "test":
            list_path = os.path.join(list_path, "test.txt")
        else:
            list_path = os.path.join(list_path, "val.txt")
        # Read the list once and close the file. Blank lines are skipped so an
        # empty line cannot break the path unpacking in read_files.
        with open(list_path) as f:
            self.img_list = [line.split() for line in f if line.strip()]
        # One dict per sample: {"img", "label" (when available), "name"}.
        self.files = self.read_files()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        return self.load_item(index)

    def read_files(self):
        """Convert the parsed list lines into per-sample dicts."""
        files = []
        if self.mode == "test":
            for item in self.img_list:
                sample = {
                    "img": item[0],
                    "name": os.path.splitext(os.path.basename(item[0]))[0],
                }
                # Keep the second column when present so load_item can read it.
                if len(item) > 1:
                    sample["label"] = item[1]
                files.append(sample)
        else:
            for item in self.img_list:
                image_path, label_path = item
                files.append({
                    "img": image_path,
                    "label": label_path,
                    "name": os.path.splitext(os.path.basename(label_path))[0],
                })
        return files

    def load_item(self, index):
        """Return ``(image, label, name)``, or ``(label, name)`` in test mode.

        ``F.to_tensor`` does not rescale float32 arrays, so tensor values stay
        in the 0-255 range read from disk.
        """
        item = self.files[index]
        name = item["name"]
        if self.mode == "test":
            # Test mode only returns the grayscale tensor, so the color image
            # is not read. For one-column test lists fall back to the single
            # listed path (assumption: that path is the sketch — confirm).
            label = self.read_image(item.get("label", item["img"]), cv2.IMREAD_GRAYSCALE)
            return F.to_tensor(label), name
        image = self.read_image(item["img"], cv2.IMREAD_COLOR)
        label = self.read_image(item["label"], cv2.IMREAD_GRAYSCALE)
        return F.to_tensor(image), F.to_tensor(label), name

    def read_image(self, img_path, read_mode):
        """Read an image as a float32 HxWxC array.

        Color images are converted from OpenCV's BGR to RGB. For any other
        read mode a trailing channel axis is added.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``img_path``.
        """
        raw = cv2.imread(img_path, read_mode)
        if raw is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"cannot read image: {img_path}")
        image = raw.astype(np.float32)
        if read_mode == cv2.IMREAD_COLOR:
            # BGR -> RGB
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            # Add 3rd dimension to grayscale
            image = image[:, :, np.newaxis]
        return image


if __name__ == '__main__':
    current_path = os.getcwd()
    txt_path_train = current_path + "/list/sketch"
    # Named `dataset` so it does not shadow torch's Dataset class.
    dataset = Sketch(txt_path_train, mode="train")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    use_cuda = torch.cuda.is_available()
    for index, items in enumerate(dataloader):
        images, labels, name = items
        if use_cuda:
            images, labels = cuda(images, labels)
        # The tensors hold 0-255 floats, but ToPILImage expects floats in [0, 1].
        torchvision.transforms.ToPILImage()(images[0].cpu() / 255).show()
        torchvision.transforms.ToPILImage()(labels[0].cpu() / 255).show()


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部