Loading the Stanford_Cars Dataset in a Few-Shot Setting

Contents

    • Preparing the data
    • Splitting the dataset by label
    • Loading into a DataLoader
    • Full code

Preparing the data

Download the three archives from the dataset page, https://ai.stanford.edu/~jkrause/cars/car_dataset.html, and extract them.
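Before going further, it can help to confirm that the archives were extracted where the rest of the article expects them. The sketch below is a minimal check, not part of the original code: the helper name `missing_dirs` is hypothetical, and the three folder names are taken from the paths built in the label-splitting code in this article.

```python
import os.path as osp

# Folder names taken from the path-building code in this article
EXPECTED_DIRS = ("cars_train", "cars_test", "car_devkit")


def missing_dirs(dataset_path):
    """Return the expected folders that are absent under dataset_path."""
    return [d for d in EXPECTED_DIRS if not osp.isdir(osp.join(dataset_path, d))]
```

Calling `missing_dirs(dataset_path)` returns an empty list when all three folders are in place.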

Splitting the dataset by label

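import os
import os.path as osp
import shutil

import numpy as np
import scipy.io as scio

# This directory contains three folders: cars_train, cars_test and car_devkit
dataset_path = "XXX"
train_img_path = osp.join(dataset_path, "cars_train")
test_img_path = osp.join(dataset_path, "cars_test")
label_path = osp.join(dataset_path, "car_devkit", "devkit")
train_label_path = osp.join(label_path, "cars_train_annos.mat")
test_label_path = osp.join(label_path, "cars_test_annos.mat")

img = 0  # running count of images copied into class folders


def devide_dataset_by_label(path):
    global img
    # Pick the image folder that matches the annotation file
    is_test = "test" in osp.basename(path)
    img_dir = test_img_path if is_test else train_img_path
    prefix = "test_" if is_test else "train_"
    data = scio.loadmat(path)["annotations"].squeeze()
    # Assumption: each annotation has a "class" field. If the test annotation
    # file lacks labels, point test_label_path at a labeled version instead.
    labels = np.array([int(anno["class"].item()) for anno in data])
    for number, label in enumerate(labels):
        class_dir = osp.join(dataset_path, "classes", str(label))
        os.makedirs(class_dir, exist_ok=True)
        # This relies on images being named 00001.jpg, 00002.jpg, ... in annotation order
        name = str(number + 1).zfill(5) + ".jpg"
        # Prefix the split so train and test files with the same name do not collide
        shutil.copy(osp.join(img_dir, name), osp.join(class_dir, prefix + name))
        img += 1


if __name__ == '__main__':
    devide_dataset_by_label(train_label_path)
    devide_dataset_by_label(test_label_path)
    print(f"divided {img} images")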

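In total, 16,185 images are sorted into class folders, which matches the official figure for the dataset (8,144 training images plus 8,041 test images).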
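The total can also be checked against what actually landed on disk, rather than against the loop counter. The following is a minimal sketch under the assumption that every class folder sits directly under `classes` and holds only `.jpg` files; the helper name `count_images_per_class` is hypothetical.

```python
import os
import os.path as osp


def count_images_per_class(classes_path):
    """Map each class folder name to the number of .jpg files it contains."""
    counts = {}
    for name in sorted(os.listdir(classes_path)):
        class_dir = osp.join(classes_path, name)
        if not osp.isdir(class_dir):
            continue  # skip stray files such as .DS_Store
        counts[name] = sum(
            1 for f in os.listdir(class_dir) if f.lower().endswith(".jpg")
        )
    return counts
```

Summing the values, for example `sum(count_images_per_class(osp.join(dataset_path, "classes")).values())`, gives the number of files on disk. A total below the printed count would suggest that some copies failed or overwrote each other.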

Loading into a DataLoader

import os
from functools import partial

import torch

# convert_dict, load_class_images, extract_episode, compose, ListDataset,
# TransformDataset and EpisodicBatchSampler are not defined in this article;
# they are assumed to come from the full code.


def get_Stanford_Cars_dataloader(mode="train", way=5, shot=2, query=10):
    classes_path = osp.abspath(osp.join(dataset_path, "classes"))
    # Build the per-class folders on first use
    if not osp.exists(classes_path):
        devide_dataset_by_label(train_label_path)
        devide_dataset_by_label(test_label_path)
        print(f"divided {img} images")

    # Collect class folders, skipping macOS .DS_Store entries;
    # sorting keeps the split the same from run to run
    class_list = [osp.join(classes_path, name)
                  for name in sorted(os.listdir(classes_path))
                  if "DS_Store" not in name]
    # Keep only classes with enough images for the support and query sets
    class_names = [c for c in class_list if len(os.listdir(c)) >= query + shot]

    # Split the classes 60% / 20% / 20% into train / val / test
    n_train = int(len(class_names) * 0.6)
    n_val = int(len(class_names) * 0.8)
    splits = {
        "train": class_names[:n_train],
        "val": class_names[n_train:n_val],
        "test": class_names[n_val:],
    }
    if mode not in splits:
        raise ValueError(f"unknown mode: {mode}")
    lists = splits[mode]

    transforms = compose([
        partial(convert_dict, "class"),
        partial(load_class_images, 64),
        partial(extract_episode, shot, query),
    ])

    episode = int(len(lists) / way) + 1  # +1 so that leftover classes are not skipped
    # Wrap the list of class folders in a ListDataset, then apply the transforms
    ds = TransformDataset(ListDataset(lists), transforms)
    # Each episode draws `way` classes; one pass yields `episode` episodes
    sampler = EpisodicBatchSampler(len(ds), way, episode)
    dataloader = torch.utils.data.DataLoader(ds, batch_sampler=sampler, num_workers=0)
    return dataloader
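Two pieces of the function above are easy to check in isolation: the 60/20/20 class split and the episode count. The sketch below reproduces that arithmetic in plain Python and adds a stand-in for episodic sampling. It is an illustration only: `split_classes`, `episode_count` and `sample_episodes` are hypothetical names, and the sampler assumes that each episode simply draws `way` distinct class indices at random, which may differ from the `EpisodicBatchSampler` used in the full code.

```python
import random


def split_classes(class_names, train_end=0.6, val_end=0.8):
    """Split a list of classes into train / val / test by position."""
    n_train = int(len(class_names) * train_end)
    n_val = int(len(class_names) * val_end)
    return class_names[:n_train], class_names[n_train:n_val], class_names[n_val:]


def episode_count(n_classes, way):
    """Number of episodes per pass, as computed in the article's code."""
    return int(n_classes / way) + 1


def sample_episodes(n_classes, way, n_episodes, seed=None):
    """Yield n_episodes lists, each holding `way` distinct class indices."""
    rng = random.Random(seed)
    for _ in range(n_episodes):
        yield rng.sample(range(n_classes), way)
```

Stanford Cars has 196 classes. If all of them pass the image-count filter, the split gives 117 training, 39 validation and 40 test classes, and the training loader with `way=5` runs 24 episodes per pass. Because classes are drawn at random, the extra episode makes it more likely, but does not guarantee, that every class is seen in a pass.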

Full code

The full code is shared on Gitee (码云).

