Pytorch的数据加载主要依赖torch.utils.data.Dataset和torch.utils.data.DataLoader两个模块,可以完成如下格式的傻瓜式加载。
train_dataset = CustomDataset(train_data_path)
train_loader = torch.utils.data.DataLoader(train_dataset)
1 Dataset
阅读源码后,我们可以指导,继承该方法实现3个方法:
● init():主要是数据格式的转换,还有一部分处理
● getItem():主要是从数据集里面获取数据项的item和label。
● lens():返回数据的个数
2 DataLoader
提供对Dataset的操作,操作如下:‘
torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)
● dataset: 加载torch.utils.data.Dataset对象数据
● batch_size: 每个batch的大小
● shuffle:是否对数据进行打乱
● drop_last:是否对无法整除的最后一个datasize进行丢弃
● num_workers:表示加载的时候子进程数
3 案例
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
# 自定义数据集类
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx] # 根据索引获取样本
label = self.labels[idx] # 根据索引获取标签
return sample, label
# 创建数据集实例
data = [1, 2, 3, 4, 5]
labels = [0, 1, 0, 1, 0]
X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)
train_dataset = CustomDataset(X_train, y_train)
test_dataset = CustomDataset(X_test, y_test)
# 创建数据加载器
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)
# 迭代加载训练集数据
print("训练集")
for batch in train_dataloader:
samples, labels = batch
print(samples, labels)
# 迭代加载测试集数据
print("测试集")
for batch in test_dataloader:
samples, labels = batch
print(samples, labels)
输出:
训练集
tensor([4, 1]) tensor([1, 0])
tensor([3, 5]) tensor([0, 0])
测试集
tensor([2]) tensor([1])