TensorFlow——本地加载fashion-mnist数据集

基本概念

Fashion MNIST:Fashion MNIST 旨在临时替代经典 MNIST 数据集,后者常被用作计算机视觉机器学习程序的“Hello, World”。MNIST 数据集包含手写数字(0、1、2 等)的图像,其格式与您将使用的衣物图像的格式相同。

问题描述

在动手写深度学习的TensorFlow实现版本中,需要用到数据集Fashion MNIST,如果直接用TensorFlow导入数据集:

from tensorflow.keras.datasets import fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

就会报错,下载数据集时会显示服务器连接超时,可能因为服务器在国内被墙了。

解决方案

1、下载

https://github.com/zalandoresearch/fashion-mnist 

 

下载完成后解压放在./data/fashion/文件夹下

 

2、导入 

接下导入数据集:

import mnist_readerx_train, y_train = mnist_reader.load_mnist('data/fashion', kind='train')
x_test, y_test = mnist_reader.load_mnist('data/fashion', kind='t10k')

注意:

mnist_reader是GitHub上该项目里面的一个文件,不要以为是某个库

 

代码: 

def load_mnist(path, kind='train'):import osimport gzipimport numpy as np"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 784)return images, labels

 3、测试

注意:

mnist_reader.pyfashion_mnist.load_data的结果并不相同,会影响后续操作

修改版本

def load_mnist(path, kind='train'):import osimport gzipimport numpy as np"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 28, 28)  # 关键点return images, labelsdef load_data(path):train_images, train_labels = load_mnist(path, kind='train')test_images, test_labels = load_mnist(path, kind='t10k')return (train_images, train_labels), (test_images, test_labels)def load_data():return load_data('data/fashion')

参考文章

使用matplotlib.pyplot.imshow() 显示图像时出现“TypeError: Invalid dimensions for image data”的问题

如何加载mnist和fashion-mnist数据集

Fashion MNIST的下载与导入

Tensorflow学习第1课——从本地加载MNIST以及FashionMNIST数据

TensorFlow——[基本图像分类]fashion-mnist及mnist_reader.py运行错误[TypeError: Invalid dimensions for image data]


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部