训练图像分类

Note

图像分类型任务训练时的步骤都是类似的,只是模型和数据集不同。
所以我们可以先定义好训练图像分类的函数,模型和数据集作其参数,这样会很方便。

一些辅助函数和类

import torch
import d2l


#@save
def try_gpu():
    """尽量使用gpu"""
    return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
#@save
class Accumulator:
    """累计n个数据"""
    
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
#@save
def correct_predictions(y_hat, y):
    """
    :param y_hat: (n_samples, n_categories)
    :param y: (n_samples, )
    :return: 正确预测的个数
    """
    y_hat = y_hat.argmax(axis=1)  # across columns
    is_correct = y_hat.type(y.dtype) == y
    return float(is_correct.type(y.dtype).sum())
#@save
def accuracy(net, data_iter, device):
    """
    :param net: 模型
    :param data_iter: 图像分类数据集
    :param device: 尽量使用GPU
    :return: 模型的准确率,这里使用了Accumulator和correct_predictions
    """
    net.eval()  # Set the model to evaluation mode
    metric = d2l.Accumulator(2)  # No. of correct predictions, no. of predictions
    # 预测时需no_grad
    with torch.no_grad():
        for X, y in data_iter:
            X, y = X.to(device), y.to(device)
            # y.numel()表示y中的数据数
            metric.add(d2l.correct_predictions(net(X), y), y.numel())
    return metric[0] / metric[1]

动画

为了让我们的训练过程更加直观,我们实现一个展示训练过程中各项数据动态变化的类

from IPython import display
import matplotlib.pyplot as plt


#@save
def use_svg_display():
    """使用svg格式"""
    display.set_matplotlib_formats('svg')
#@save
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """设置坐标轴"""
    # 设置坐标标签
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    # 设置比例尺,{`linear`, `log`, ...}
    axes.set_xscale(xscale)
    axes.set_yscale(yscale)
    # 设置x轴和y轴的显示范围
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)
    # 加上图例、网格
    if legend:
        axes.legend(legend)
    axes.grid()
#@save
class Animator:
    """动态画折线图"""

    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5)):
        """参数都是 matplotlib 画图的参数"""
        # 使用svg格式
        d2l.use_svg_display()
        # 获得画布和坐标轴
        self.fig, self.axes = plt.subplots(figsize=figsize)
        # config_axes() 即 d2l.set_axes(self.axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.config_axes = lambda: d2l.set_axes(self.axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        """Add multiple data points into the figure"""
        if not hasattr(y, "__len__"):
            y = [y]
        # Total n curves
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        # initialization
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        # 添加数据
        for i, (a, b) in enumerate(zip(x, y)):
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes.cla()  # 清除子图目前状态,防止重叠
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes.plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        # 不是多图而是动态
        display.clear_output(wait=True)

训练图像分类的函数

分类问题的损失函数CrossEntropyLoss:

\[\mbox{loss}(x, class) = -\mbox{log}\left(\frac{\mbox{exp}(x[class])}{\sum_{j}\mbox{exp}(x[j])}\right) = -x[class] + \log\left ({\sum_{j}\exp({x[j]})}\right )\]
#@save
def train_image_classifier(net, train_iter, test_iter, learning_rate, num_epochs):
    """
    训练图像分类器,记录数据并打印
    e.g. training FashionMNIST
    """
    device = d2l.try_gpu()
    # 需模型和数据均转向device
    net.to(device=device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    # 记录误差和、正确预测样本数、总样本数
    metric = d2l.Accumulator(3)
    # 画训练误差、训练准确率、测试准确率
    animator = d2l.Animator(xlabel="epoch", xlim=[1, num_epochs], ylim=[0,1],
                            legend=["train_loss", "train_acc", "test_acc"])
    for epoch in range(num_epochs):
        net.train()  # 因为计算accuracy会使net转向eval模式
        metric.reset()
        for x, y in train_iter:
            # Compute prediction error
            x, y = x.to(device), y.to(device)
            y_hat = net(x)
            loss = loss_fn(y_hat, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # 记录数据
            metric.add(float(loss) * len(y), d2l.correct_predictions(y_hat, y), y.numel())
        # 画图
        animator.add(epoch + 1, 
                     (metric[0] / metric[2], metric[1] / metric[2], d2l.accuracy(net, test_iter, device)))
    # 打印最终的数据
    print(f"loss {animator.Y[0][-1]:.3f}, "
          f"train acc {animator.Y[1][-1]:3f}, "
          f"test acc {animator.Y[2][-1]: 3f}")