本篇文章介绍了使用PyTorch在MNIST数据集上训练MLP和CNN,并记录自己实现过程中的若干问题。
加载MNIST数据集
PyTorch中提供了MNIST,CIFAR,COCO等常用数据集的加载方法。MNIST
是torchvision.datasets
包中的一个类,负责根据传入的参数加载数据集。如果自己之前没有下载过该数据集,可以将download
参数设置为True
,会自动下载数据集并解包。如果之前已经下载好了,只需将其路径通过root
传入即可。
在加载图像后,我们常常需要对图像进行若干预处理。比如减去RGB通道的均值,或者裁剪或翻转图像实现augmentation等,这些操作可以在torchvision.transforms
包中找到对应的操作。在下面的代码中,通过使用transforms.Compose()
,我们构造了对数据进行预处理的复合操作序列,ToTensor
负责将PIL图像转换为Tensor数据(RGB通道从[0, 255]
范围变为[0, 1]
), Normalize
负责对图像进行规范化。这里需要注意,虽然MNIST中图像都是灰度图像,通道数均为1,但是仍要传入tuple
。
之后,我们通过DataLoader
返回一个数据集上的可迭代对象。一会我们通过for
循环,就可以遍历数据集了。
1 | import torch |
网络构建
在进行网络构建时,主要通过torch.nn
包中的已经实现好的卷积层、池化层等进行搭建。例如下面的代码展示了一个具有一个隐含层的MLP网络。nn.Linear
负责构建全连接层,需要提供输入和输出的通道数,也就是y = wx+b
中x
和y
的维度。
1 | class MLPNet(nn.Module): |
由于PyTorch可以实现自动求导,所以我们只需实现forward
过程即可。这里由于池化层和非线性变换都没有参数,所以使用了nn.functionals
中的对应操作实现。通过看文档,可以发现,一般nn
里面的各种层,都会在nn.functionals
里面有其对应。例如卷积层的对应实现,如下所示,需要传入卷积核的权重。
1 | # With square kernels and equal stride |
同样地,我们可以实现LeNet的结构如下。
1 | class LeNet(nn.Module): |
训练与测试
在训练时,我们首先应确定优化方法。这里我们使用带动量的SGD
方法。下面代码中的optim.SGD
初始化需要接受网络中待优化的Parameter
列表(或是迭代器),以及学习率lr
,动量momentum
。
1 | optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) |
接下来,我们只需要遍历数据集,同时在每次迭代中清空待优化参数的梯度,前向计算,反向传播以及优化器的迭代求解即可。
1 | ## training |
当优化完毕后,需要保存模型。这里官方文档给出了推荐的方法,如下所示:1
2
3torch.save(model.state_dict(), PATH) #保存网络参数
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH)) #读取网络参数
该博客的完整代码可以见:PyTorch MNIST demo。