Data Loading for Semi-Supervised Learning
Image Classification Datasets
Semi-supervised data loading: set the label of every sample that is to be treated as unlabeled to -1, so that the cross-entropy loss can simply be told to ignore label -1:
class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL)
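A minimal, self-contained sketch of what this buys us (the logits and targets below are made up for illustration): samples whose target equals -1 contribute nothing to the loss.

import torch
import torch.nn as nn

NO_LABEL = -1
class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL)

logits = torch.randn(4, 10)                          # 4 samples, 10 classes
targets = torch.tensor([3, NO_LABEL, 7, NO_LABEL])   # 2 labeled, 2 unlabeled samples
loss = class_criterion(logits, targets)              # only the 2 labeled samples contribute

The loading code below follows this convention: it builds the dataset, splits the training indices into labeled and unlabeled subsets, and rewrites the unlabeled targets to -1.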
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from functools import reduce
from operator import __or__
from torch.utils.data.sampler import Sampler
import itertools
import numpy as np
import torch  # needed below for torch.from_numpy


def load_data(path, args, NO_LABEL=-1):
    # Per-dataset channel means / stds for normalization
    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'svhn':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    elif args.dataset == 'mnist':
        mean = (0.5, )
        std = (0.5, )
    elif args.dataset == 'stl10':
        assert False, 'stl10 support is not finished'
    elif args.dataset == 'imagenet':
        assert False, 'imagenet support is not finished'
    else:
        assert False, "Unknown dataset: {}".format(args.dataset)

    # Training / test transforms; CIFAR uses TransformTwice so that every
    # image yields two differently augmented views (for consistency losses)
    if args.dataset == 'svhn':
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    elif args.dataset == 'mnist':
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
    else:
        train_transform = TransformTwice(transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]))
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    # Build the train / test datasets
    if args.dataset == 'cifar10':
        train_data = datasets.CIFAR10(path, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR10(path, train=False, transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = datasets.CIFAR100(path, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR100(path, train=False, transform=test_transform, download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = datasets.SVHN(path, split='train', transform=train_transform, download=True)
        test_data = datasets.SVHN(path, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'mnist':
        train_data = datasets.MNIST(path, train=True, transform=train_transform, download=True)
        test_data = datasets.MNIST(path, train=False, transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = datasets.STL10(path, split='train', transform=train_transform, download=True)
        test_data = datasets.STL10(path, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        assert False, 'imagenet support is not finished'
    else:
        assert False, 'Unsupported dataset: {}'.format(args.dataset)

    # Split the training indices into labeled / unlabeled subsets
    labeled_idxs, unlabeled_idxs = spilt_l_u(args.dataset, train_data, args.num_labels,
                                             classes=num_classes)

    # if args.labeled_batch_size:
    #     batch_sampler = TwoStreamBatchSampler(
    #         unlabeled_idxs, labeled_idxs, args.batch_size, args.labeled_batch_size)
    # else:
    #     assert False, "labeled batch size {}".format(args.labeled_batch_size)

    # Overwrite the targets of unlabeled samples with NO_LABEL (-1) so the
    # cross-entropy loss can ignore them via ignore_index
    if args.dataset == 'svhn':
        train_data.labels = np.array(train_data.labels)
        train_data.labels[unlabeled_idxs] = NO_LABEL
    else:
        train_data.targets = np.array(train_data.targets)
        train_data.targets[unlabeled_idxs] = NO_LABEL

    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True,
                              drop_last=True)
    eval_loader = DataLoader(test_data,
                             batch_size=args.eval_batch_size,
                             shuffle=False,
                             num_workers=args.workers,
                             pin_memory=True,
                             drop_last=False)
    return train_loader, eval_loader
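A rough usage sketch, assuming an args namespace carrying the fields the function reads (dataset, num_labels, batch_size, eval_batch_size, workers); the values below are arbitrary, and the call downloads CIFAR-10 into ./data on first use:

from argparse import Namespace

args = Namespace(dataset='cifar10', num_labels=4000, batch_size=128,
                 eval_batch_size=128, workers=4)
train_loader, eval_loader = load_data('./data', args)

for (x1, x2), y in train_loader:   # TransformTwice yields two augmented views per image
    labeled = y != -1              # -1 marks the unlabeled samples in the batch
    print(x1.shape, x2.shape, labeled.sum().item(), 'labeled samples in this batch')
    break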
def spilt_l_u(dataset, train_data, num_labels, num_val=400, classes=10):
    # Pull the ground-truth labels out of the dataset object as a numpy array
    if dataset == 'mnist':
        labels = train_data.targets.numpy()
    elif dataset == 'svhn':
        labels = train_data.labels
    else:
        labels = np.array(train_data.targets)
    v = num_val  # reserved for a validation split, currently unused
    n = int(num_labels / classes)  # labeled samples kept per class
    (indices,) = np.where(reduce(__or__, [labels == i for i in np.arange(classes)]))
    # Ensure uniform distribution of labels
    np.random.shuffle(indices)
    # Keep the first n shuffled indices of each class as labeled, the rest as unlabeled
    indices_train = np.hstack(
        [list(filter(lambda idx: labels[idx] == i, indices))[:n] for i in range(classes)])
    indices_unlabelled = np.hstack(
        [list(filter(lambda idx: labels[idx] == i, indices))[n:] for i in range(classes)])
    indices_train = torch.from_numpy(indices_train)
    indices_unlabelled = torch.from_numpy(indices_unlabelled)
    return indices_train, indices_unlabelled


class TransformTwice:
    """Apply the same transform twice, returning two augmented views of one image."""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        out1 = self.transform(inp)
        out2 = self.transform(inp)
        return out1, out2


class TwoStreamBatchSampler(Sampler):
    """Labeled + unlabeled data in a batch.

    Iterates two sets of indices. An 'epoch' is one iteration through the
    primary indices. During the epoch, the secondary indices are iterated
    through as many times as needed.
    """
    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size

        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        primary_iter = iterate_once(self.primary_indices)
        secondary_iter = iterate_eternally(self.secondary_indices)
        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch)
            in zip(grouper(primary_iter, self.primary_batch_size),
                   grouper(secondary_iter, self.secondary_batch_size))
        )

    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size


def iterate_once(iterable):
    # One shuffled pass over the indices
    return np.random.permutation(iterable)


def iterate_eternally(indices):
    # Endless stream of freshly reshuffled indices
    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)
    return itertools.chain.from_iterable(infinite_shuffles())


def grouper(iterable, n):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3) --> ABC DEF
    args = [iter(iterable)] * n
    return zip(*args)
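The commented-out block inside load_data shows where TwoStreamBatchSampler is meant to plug in: instead of plain shuffling, each batch draws a fixed number of labeled and unlabeled indices. A small sketch with made-up index sets and batch sizes:

# Hypothetical index sets: 46000 unlabeled and 4000 labeled training samples
unlabeled_idxs = list(range(4000, 50000))
labeled_idxs = list(range(4000))

# 128-sample batches containing 32 labeled and 96 unlabeled indices each
batch_sampler = TwoStreamBatchSampler(unlabeled_idxs, labeled_idxs,
                                      batch_size=128, secondary_batch_size=32)

# It would then be passed to the DataLoader as, e.g.:
# train_loader = DataLoader(train_data, batch_sampler=batch_sampler,
#                           num_workers=args.workers, pin_memory=True)

for batch in batch_sampler:
    print(len(batch))   # 128: 96 unlabeled indices followed by 32 labeled ones
    break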
References
- https://blog.csdn.net/Z609834342/article/details/106863690
