目标检测数据集

Note

目标检测领域没有像 MNIST 和 Fashion-MNIST 那样的小数据集。
为了快速测试目标检测模型,我们使用d2l收集并标记的一个小型数据集-香蕉数据集。

读取图像和标签

import pandas as pd
import torchvision
import os
import torch


#@save
def read_data_bananas(is_train=True):
    """读取香蕉检测数据集中的图像和标签。"""
    # 数据路径
    data_dir = "../data/banana-detection/{}".format(
        'bananas_train' if is_train else 'bananas_val')
    # 含标注信息的CSV文件
    csv_data = pd.read_csv(os.path.join(data_dir, 'label.csv'))
    csv_data = csv_data.set_index('img_name')
    images, targets = [], []
    for img_name, target in csv_data.iterrows():
        images.append(torchvision.io.read_image(
            os.path.join(data_dir, 'images', f'{img_name}')))
        # Here `target` contains (class, upper-left x, upper-left y,
        # lower-right x, lower-right y), where all the images have the same
        # banana class (index 0)
        targets.append(list(target))
    return images, torch.tensor(targets).unsqueeze(1) / 256

自定义Dataset

#@save
class BananasDataset(torch.utils.data.Dataset):
    """用于加载香蕉检测数据集"""
    def __init__(self, is_train):
        self.features, self.labels = read_data_bananas(is_train)
        print('read ' + str(len(self.features)) + (f' training examples' if
              is_train else f' validation examples'))

    def __getitem__(self, idx):
        return (self.features[idx].float(), self.labels[idx])

    def __len__(self):
        return len(self.features)

load函数

#@save
def load_data_bananas(batch_size):
    """加载香蕉检测数据集。"""
    train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
                                             batch_size, shuffle=True)
    val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
                                           batch_size)
    return train_iter, val_iter