在解决机器学习问题的时候,人们花了大量精力准备数据。准备好的数据需要进行一些预处理,比如裁剪,旋转等,然后将数据进行加载,加载时希望能够并行处理,设置batch,对数据打乱等,而pytorch提供了许多工具来让载入数据更简单,并尽量让你的代码的可读性更高。下面就针对这一部分内容进行介绍。
数据加载器
torch.utils.data里面定义了数据加载的接口。通过下面的例子进行讲解
import torch
import torch.utils.data as Data
BATCH_SIZE = 5
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)
#数据加载
torch_dataset = Data.TensorDataset(x,y)
loader = Data.DataLoader(
dataset = torch_dataset,
batch_size = BATCH_SIZE,
shuffle = True,
num_workers = 2
)
for epoch in range(3):
for step,(batch_x,batch_y) in enumerate(loader):
print("Epoch",epoch,'|Step:',step,'|batch x:',batch_x.numpy(),'|batch y:',batch_y.numpy())
上面的例子中,torch.utils.data.DataLoader就是对数据加载的接口函数,通过设置相应的参数,实现对数据的加载,并且可以实现并行处理。
当然,通常做视觉方面的小伙伴,更多的是要对图像数据进行处理,因此,另外一个接口函数就更加重要了。
import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='xxx'
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
batch_size=4, shuffle=True,
num_workers=4)
其中,transforms.Compose中可以设置需要对图像数据进行的预处理操作,包括归一化,随机裁剪,旋转等等。之后通过datasets.ImageFolder函数对目标位置中的图像数据进行预处理。最后再通过上面讲过的torch.utils.data.DataLoader函数进行加载。此时,一个完整的数据加载过程就实现了。
参考链接:https://pytorch.apachecn.org/docs/1.0/#/