TensorFlow——本地加载fashion-mnist数据集
基本概念
Fashion MNIST:Fashion MNIST 旨在临时替代经典 MNIST 数据集,后者常被用作计算机视觉机器学习程序的“Hello, World”。MNIST 数据集包含手写数字(0、1、2 等)的图像,其格式与您将使用的衣物图像的格式相同。
问题描述
在动手写深度学习的TensorFlow实现版本中,需要用到数据集Fashion MNIST,如果直接用TensorFlow导入数据集:
from tensorflow.keras.datasets import fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
就会报错,下载数据集时会显示服务器连接超时,可能因为服务器在国内被墙了。
解决方案
1、下载
https://github.com/zalandoresearch/fashion-mnist

下载完成后解压放在./data/fashion/文件夹下

2、导入
接下导入数据集:
import mnist_readerx_train, y_train = mnist_reader.load_mnist('data/fashion', kind='train')
x_test, y_test = mnist_reader.load_mnist('data/fashion', kind='t10k')
注意:
mnist_reader是GitHub上该项目里面的一个文件,不要以为是某个库
代码:
def load_mnist(path, kind='train'):import osimport gzipimport numpy as np"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 784)return images, labels
3、测试

注意:
mnist_reader.py和fashion_mnist.load_data的结果并不相同,会影响后续操作
修改版本
def load_mnist(path, kind='train'):import osimport gzipimport numpy as np"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 28, 28) # 关键点return images, labelsdef load_data(path):train_images, train_labels = load_mnist(path, kind='train')test_images, test_labels = load_mnist(path, kind='t10k')return (train_images, train_labels), (test_images, test_labels)def load_data():return load_data('data/fashion')
参考文章
使用matplotlib.pyplot.imshow() 显示图像时出现“TypeError: Invalid dimensions for image data”的问题
如何加载mnist和fashion-mnist数据集
Fashion MNIST的下载与导入
Tensorflow学习第1课——从本地加载MNIST以及FashionMNIST数据
TensorFlow——[基本图像分类]fashion-mnist及mnist_reader.py运行错误[TypeError: Invalid dimensions for image data]
本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

