1.Pytorch内置的Dataset
Pytorch中内置了许多数据集,我们可以从torchvision
库中进行导入。比如,我们可以导入Fashion-MNIST数据集
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
但如果torchvision
库中没有该数据集,我们需要自己构建一个。
其中一个方法就是把构建好的数据集使用torch.utils.data.TensorDataset()
封装以下,然后再传入torch.utils.data.DataLoader
trainloader = torch.utils.data.DataLoader(training_data, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)
但是如果自己写一个类的话会高达上一些,嘻嘻。下面看看如何自己构建一个Dataset Class。
2.Build Custom Dataset
构建一个Custom Dataset需要继承``三个函数__init__
, __len__
, 和 __getitem__
。
-
__init__
: 对类进行初始化 -
__len__
: 使该类可以返回dataset样本数量 -
__getitem__
: 给定一个idx
,从数据集中导入并返回一个样本
下面我们来看看该如何构建Custom Dataset:
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file) # load label
self.img_dir = img_dir
self.transform = transform # transformation
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels) # 返回sample的个数
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path) # load idx-th image
label = self.img_labels.iloc[idx, 1] # load idx-th label
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
注意:同时,__len__
控制着产生样本的总个数。例如,如果总共有20个样本,我们希望20个样本全都放入dataloader中,则:
def __len(self):
return 20
但如果我们只希望有20个样本中的15个放入到dataloader中,则:
def __len(self):
return 15
但值得注意的是,return
返回的数不能大于样本的总个数,即要小于等于20。并且,当返回的数小于总样本个数的时候,是取索引的前几个数,最后的几个数不会被放入dataloader中。比如return 15
,是将data[:15]个数放入dataloader,而后5个数要舍去。可以用如下代码验证:
>>> data = np.arange(15).reshape(5,3)
>>> print(data)
array([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[12, 13, 14]])
>>> class Data(Dataset):
... def __init__(self, data) -> None:
... super(Data, self).__init__()
... self.data = data
... def __len__(self):
... return 4
... def __getitem__(self, index):
... out = self.data[index]
... return torch.from_numpy(out)
>>> loader = DataLoader(Data(data), batch_size=4, shuffle=True)
>>> for i, x in enumerate(loader):
... print(i, x)
0 tensor([[ 3, 4, 5],
[ 9, 10, 11],
[ 0, 1, 2],
[ 6, 7, 8]])
可以发现,无论如何都不会输出[12, 13, 14]
。
Reference:
Pytorch official tutorial
Writing custom datasets dataloaders and transforms