构建自定义PyTorch Dataset类

在深度学习项目中,数据加载和预处理是一个至关重要的步骤。PyTorch提供了灵活而强大的工具,使这一过程变得简单高效。在本
博客中,我将向你展示如何使用PyTorch构建一个自己的Dataset类,以便于加载和转换数据。

为什么需要自定义Dataset类?

在处理机器学习任务时,我们经常需要处理各种格式的数据。PyTorch的Dataset类为我们提供了一个统一的接口,通过实现该接口,我们可以轻松地使我们的数据适配于PyTorch的数据加载器(DataLoader),进而利用多线程等高级功能来提高数据加载的效率。

如何实现自定义Dataset类?

自定义的Dataset类需要继承自PyTorch的Dataset类,并至少实现以下三个方法:__init__, __len__, 和 __getitem__

Talk is cheap,show me the code.

# 导入必要的库
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

# 定义Dataset子类
class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        """
        初始化函数
        :param data: 包含特征的numpy数组或者列表
        :param labels: 包含标签的numpy数组或者列表
        :param transform: 一个可选参数,用于数据增强
        """
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """
        返回数据集中数据的总数
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        根据索引idx返回一个样本及其对应的标签
        """
        sample = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            sample = self.transform(sample)

        return sample, label

1. __init__ 方法

  在__init__方法中,我们需要准备好数据列表或者能够找到数据的文件路径。在上面的代码中传递的data和label参数可以直接是数据,也可以是其路径列表,然后在调用类对象时可以据此列表依次加载数据和标签。一般情况下,还会设置一个可选参数transform,代表对数据做的一些增强操作。以下代码以加载分类数据集为例,传递的是数据集路径和一些可选参数。在__init__方法中,将路径所指向的分类数据集的文件类别和数据进行扫描,生成一个数据路径列表self.data,并且通过Compose将各增强方法构建为一个数据增强管道self.transform

from torch.utils.data import Dataset
from torchvision.transforms import Compose, Resize, ColorJitter, RandomHorizontalFlip, RandomRotation, ToTensor, Normalize
import os
from PIL import Image

class ClassifyDataset(Dataset):  
    def __init__(self, data_path, mean=None, std=None, resize=None, transforms=None):
        super().__init__()
        self.classes = [d.name for d in os.scandir(data_path) if d.is_dir()]
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.data = []
        for idx, name in enumerate(self.classes):
            sub_path = os.path.join(data_path, name)
            img_lists = os.listdir(sub_path)
            for img in img_lists:
                self.data.append([os.path.join(sub_path, img), idx]) 
        self.transform = transforms if transforms is not None else Compose([
            Resize(size=resize if resize else (224, 224)), 
            ColorJitter(), 
            RandomHorizontalFlip(), 
            RandomRotation(15), 
            ToTensor(), 
            Normalize(mean=mean, std=std)
        ])

2. __len__ 方法

  __len__方法返回数据集中数据的总数,这个方法的实现很简单,可以直接返回__init__方法中的self.data列表的长度:

def __len__(self):
    return len(self.data)

3. __getitem__ 方法

  __getitem__方法负责迭代的从数据集中加载一组数据,并进行必要的处理(如读取图片和标签、图像增强等),最后返回处理后的图像标签数据对:

def __getitem__(self, idx):
    img_path, label = self.data[idx]
    img = Image.open(img_path).convert('RGB')
    img = self.transform(img)
    return {'image': img, 'label': label}

数据增强(Transforms)

  数据增强是数据预处理的重要部分,PyTorch通过torchvision.transforms提供了丰富的数据转换工具,在上面代码的实现中,self.transform集合了torchvisionz中实现的几种图像增强方法,并且self.transform方法可以同时传递图像和标签数据,以实现一些会影响到标签的图像变换操作。除了torchvision中实现的方法外,我们也可以自己实现图像变换方法,然后在self.transform中进行组合成需要的数据集增强pipeline,只需要保证其中各方法的输入输出格式相匹配,保证流程畅通。除此之外albumentations是一个功能强大的专用于图像增强的库,可以提供更多高级的变换方法。以下是一些transforms应用举例。

使用albumentations库

import albumentations as A
from torchvision.transforms import ToTensorV2

transform = A.Compose([
    A.RandomCrop(width=256, height=256),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.pytorch.ToTensorV2()
])

image = Image.open("image.jpg").convert('RGB')
transformed = transform(image=np.array(image))
transformed_image = transformed["image"]

自定义方法

以下一个自定义的Resizer转换为例,通过它实现对图像进行最大限度的等比例缩放,然后填充到目标大小,其中输入是图像数据,输出是经过resize和padding后变成target_sizes的图像。

import numpy as np

class Resizer(object):
    def __init__(self, target_sizes=None):
        self.target_sizes = target_sizes if target_sizes is not None else [480, 640]

    def __call__(self, image):
        image = np.array(image)
        h, w = image.shape[:2]
        scale = min(self.target_sizes[0] / h, self.target_sizes[1] / w)
        nh, nw = int(h * scale), int(w * scale)
        image_resized = np.array(Image.fromarray(image).resize((nw, nh)))

        pad_h = self.target_sizes[0] - nh
        pad_w = self.target_sizes[1] - nw
        image_paded = np.pad(image_resized, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant', constant_values=0)
        
        return image_paded
最后修改:2024 年 02 月 15 日
如果觉得我的文章对你有用,请随意赞赏