博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
PyTorch—— 图像分类数据集(Fashion-MNIST)
阅读量:2091 次
发布时间:2019-04-29

本文共 1799 字,大约阅读时间需要 5 分钟。

PyTorch—— 图像分类数据集(Fashion-MNIST)

本文是学习 的笔记,具体解释请参考原文。

0、前言

使用到的包主要是torchvision,它主要由以下几部分构成:

  1. torchvision.datasets:一些加载数据的函数及常用的数据集接口;
  2. torchvision.models:包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
  3. torchvision.transforms:常用的图片变换,例如裁剪、旋转等;
  4. torchvision.utils:其他的一些有用的方法。

一、获取数据集

1、下载数据

当不使用transform=torchvision.transforms.ToTensor()时,获取到的数据是尺寸为(H×W×C)且数据位于[0, 255]之间的 PIL 图像或者数据类型为 unit8 的 Numpy 数组

该语句将上述类型的数据,转换为尺寸为(C×H×W)且数据类型为 torch.float32 且位于[0.0, 1.0]之间的 Tensor

import torchvisionmnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=torchvision.transforms.ToTensor())mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=torchvision.transforms.ToTensor())

2、获取数据标签

下载完数据后,需要能够找到数据对应的标签。以下函数可以将数值标签转成相应的文本标签。

def get_fashion_mnist_labels(labels):	text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']	return [text_labels[int(i)] for i in labels]

二、读取小批量数据

torch.utilsdata的一个方法DataLoader能够很方便的读取 batch_size 大小的数据,三个常用的三个参数分别是dataset、batch_size、shuffle(是否不按顺序读取数据)

import torch.utils.data as Databatch_size = 256train_iter = Data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)test_iter = Data.DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=True)

三、画图

下面定义一个可以在一行里画出多张图像和对应标签的函数。

# 本函数已保存在d2lzh包中方便以后使用def show_fashion_mnist(images, labels):    d2l.use_svg_display()    # 这里的_表示我们忽略(不使用)的变量    _, figs = plt.subplots(1, len(images), figsize=(12, 12))    for f, img, lbl in zip(figs, images, labels):        f.imshow(img.view((28, 28)).numpy())        f.set_title(lbl)        f.axes.get_xaxis().set_visible(False)        f.axes.get_yaxis().set_visible(False)    plt.show()

转载地址:http://apqqf.baihongyu.com/

你可能感兴趣的文章
CentOS7下使用YUM安装MySQL5.6
查看>>
JVM内存空间
查看>>
Docker 守护进程+远程连接+安全访问+启动冲突解决办法 (完整收藏版)
查看>>
从零写分布式RPC框架 系列 2.0 (4)使用BeanPostProcessor实现自定义@RpcReference注解注入
查看>>
Java 设计模式 轻读汇总版
查看>>
Paxos学习笔记及图解
查看>>
深入解析Spring使用枚举接收参数和返回值机制并提供自定义最佳实践
查看>>
数据序列化框架——Kryo
查看>>
布隆过滤器(BloomFilter)——应用(三)
查看>>
MPP架构数据库优化总结——华为LibrA(MPPDB、GuassDB)与GreenPlum
查看>>
Spark代码可读性与性能优化——示例七(构建聚合器,以用于复杂聚合)
查看>>
Spark代码可读性与性能优化——示例八(一个业务逻辑,多种解决方式)
查看>>
简单理解 HTTPS
查看>>
简单理解 NAT
查看>>
RPC框架——Thrift简单示例
查看>>
RPC框架——gRPC简单示例
查看>>
JVM对象头的简单记录
查看>>
从Java代码到Java堆——理解并优化你的应用的内存使用量
查看>>
Redis持久化与过期机制
查看>>
关于在网络中使用BIO、NIO、AIO的示例
查看>>