本文共 1799 字,大约阅读时间需要 5 分钟。
本文是学习 的笔记,具体解释请参考原文。
使用到的包主要是torchvision
,它主要由以下几部分构成:
torchvision.datasets
:一些加载数据的函数及常用的数据集接口;torchvision.models
:包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;torchvision.transforms
:常用的图片变换,例如裁剪、旋转等;torchvision.utils
:其他的一些有用的方法。1、下载数据
当不使用transform=torchvision.transforms.ToTensor()
时,获取到的数据是尺寸为(H×W×C)且数据位于[0, 255]之间的 PIL 图像或者数据类型为 unit8 的 Numpy 数组。 该语句将上述类型的数据,转换为尺寸为(C×H×W)且数据类型为 torch.float32 且位于[0.0, 1.0]之间的 Tensor 。
import torchvisionmnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=torchvision.transforms.ToTensor())mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=torchvision.transforms.ToTensor())
2、获取数据标签
下载完数据后,需要能够找到数据对应的标签。以下函数可以将数值标签转成相应的文本标签。
def get_fashion_mnist_labels(labels): text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] return [text_labels[int(i)] for i in labels]
torch.utils
中data
的一个方法DataLoader
能够很方便的读取 batch_size
大小的数据,三个常用的三个参数分别是dataset、batch_size、shuffle(是否不按顺序读取数据)
import torch.utils.data as Databatch_size = 256train_iter = Data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)test_iter = Data.DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=True)
下面定义一个可以在一行里画出多张图像和对应标签的函数。
# 本函数已保存在d2lzh包中方便以后使用def show_fashion_mnist(images, labels): d2l.use_svg_display() # 这里的_表示我们忽略(不使用)的变量 _, figs = plt.subplots(1, len(images), figsize=(12, 12)) for f, img, lbl in zip(figs, images, labels): f.imshow(img.view((28, 28)).numpy()) f.set_title(lbl) f.axes.get_xaxis().set_visible(False) f.axes.get_yaxis().set_visible(False) plt.show()
转载地址:http://apqqf.baihongyu.com/