构建自定义PyTorch Dataset类
在深度学习项目中,数据加载和预处理是一个至关重要的步骤。PyTorch提供了灵活而强大的工具,使这一过程变得简单高效。在本
博客中,我将向你展示如何使用PyTorch构建一个自己的Dataset
类,以便于加载和转换数据。
为什么需要自定义Dataset类?
在处理机器学习任务时,我们经常需要处理各种格式的数据。PyTorch的Dataset
类为我们提供了一个统一的接口,通过实现该接口,我们可以轻松地使我们的数据适配于PyTorch的数据加载器(DataLoader
),进而利用多线程等高级功能来提高数据加载的效率。
如何实现自定义Dataset类?
自定义的Dataset
类需要继承自PyTorch的Dataset
类,并至少实现以下三个方法:__init__
, __len__
, 和 __getitem__
Talk is cheap,show me the code.
# 导入必要的库
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
# 定义Dataset子类
class CustomDataset(Dataset):
def __init__(self, data, labels, transform=None):
"""
初始化函数
:param data: 包含特征的numpy数组或者列表
:param labels: 包含标签的numpy数组或者列表
:param transform: 一个可选参数,用于数据增强
"""
self.data = data
self.labels = labels
self.transform = transform
def __len__(self):
"""
返回数据集中数据的总数
"""
return len(self.data)
def __getitem__(self, idx):
"""
根据索引idx返回一个样本及其对应的标签
"""
sample = self.data[idx]
label = self.labels[idx]
if self.transform:
sample = self.transform(sample)
return sample, label
1. __init__
方法
在__init__
方法中,我们需要准备好数据列表或者能够找到数据的文件路径。在上面的代码中传递的data和label参数可以直接是数据,也可以是其路径列表,然后在调用类对象时可以据此列表依次加载数据和标签。一般情况下,还会设置一个可选参数transform
,代表对数据做的一些增强操作。以下代码以加载分类数据集为例,传递的是数据集路径和一些可选参数。在__init__
方法中,将路径所指向的分类数据集的文件类别和数据进行扫描,生成一个数据路径列表self.data
,并且通过Compose
将各增强方法构建为一个数据增强管道self.transform
。
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Resize, ColorJitter, RandomHorizontalFlip, RandomRotation, ToTensor, Normalize
import os
from PIL import Image
class ClassifyDataset(Dataset):
def __init__(self, data_path, mean=None, std=None, resize=None, transforms=None):
super().__init__()
self.classes = [d.name for d in os.scandir(data_path) if d.is_dir()]
self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
self.data = []
for idx, name in enumerate(self.classes):
sub_path = os.path.join(data_path, name)
img_lists = os.listdir(sub_path)
for img in img_lists:
self.data.append([os.path.join(sub_path, img), idx])
self.transform = transforms if transforms is not None else Compose([
Resize(size=resize if resize else (224, 224)),
ColorJitter(),
RandomHorizontalFlip(),
RandomRotation(15),
ToTensor(),
Normalize(mean=mean, std=std)
])
2. __len__
方法
__len__
方法返回数据集中数据的总数,这个方法的实现很简单,可以直接返回__init__
方法中的self.data
列表的长度:
def __len__(self):
return len(self.data)
3. __getitem__
方法
__getitem__
方法负责迭代的从数据集中加载一组数据,并进行必要的处理(如读取图片和标签、图像增强等),最后返回处理后的图像标签数据对:
def __getitem__(self, idx):
img_path, label = self.data[idx]
img = Image.open(img_path).convert('RGB')
img = self.transform(img)
return {'image': img, 'label': label}
数据增强(Transforms)
数据增强是数据预处理的重要部分,PyTorch通过torchvision.transforms
提供了丰富的数据转换工具,在上面代码的实现中,self.transform
集合了torchvisionz中实现的几种图像增强方法,并且self.transform
方法可以同时传递图像和标签数据,以实现一些会影响到标签的图像变换操作。除了torchvision中实现的方法外,我们也可以自己实现图像变换方法,然后在self.transform
中进行组合成需要的数据集增强pipeline
,只需要保证其中各方法的输入输出格式相匹配,保证流程畅通。除此之外albumentations是一个功能强大的专用于图像增强的库,可以提供更多高级的变换方法。以下是一些transforms应用举例。
使用albumentations库
import albumentations as A
from torchvision.transforms import ToTensorV2
transform = A.Compose([
A.RandomCrop(width=256, height=256),
A.HorizontalFlip(p=0.5),
A.RandomBrightnessContrast(p=0.2),
A.pytorch.ToTensorV2()
])
image = Image.open("image.jpg").convert('RGB')
transformed = transform(image=np.array(image))
transformed_image = transformed["image"]
自定义方法
以下一个自定义的Resizer
转换为例,通过它实现对图像进行最大限度的等比例缩放,然后填充到目标大小,其中输入是图像数据,输出是经过resize和padding后变成target_sizes的图像。
import numpy as np
class Resizer(object):
def __init__(self, target_sizes=None):
self.target_sizes = target_sizes if target_sizes is not None else [480, 640]
def __call__(self, image):
image = np.array(image)
h, w = image.shape[:2]
scale = min(self.target_sizes[0] / h, self.target_sizes[1] / w)
nh, nw = int(h * scale), int(w * scale)
image_resized = np.array(Image.fromarray(image).resize((nw, nh)))
pad_h = self.target_sizes[0] - nh
pad_w = self.target_sizes[1] - nw
image_paded = np.pad(image_resized, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant', constant_values=0)
return image_paded