复现一个小网络--CIFAR-10分类

dataset介绍:

CIFAR-10数据介绍

CIFAR-10^3是一个常用的彩色图片数据集,它有10个类别: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'。每张图片都是 3×32×32 ,也即3-通道彩色图片,分辨率为 32×32 。

下面我们来尝试实现对CIFAR-10数据集的分类,步骤如下:

1、使用torchvision加载并预处理CIFAR-10数据集

2、定义网络

3、定义损失函数和优化器

4、训练网络并更新网络参数

5、测试网络

CIFAR-10数据加载及预处理


import torchvision as tv

import torchvision.transforms as transforms

from torchvision.transforms import ToPILImage

show = ToPILImage() # 可以把Tensor转成Image,方便可视化


# 第一次运行程序torchvision会自动下载CIFAR-10数据集,

# 大约100M,需花费一定的时间,

# 如果已经下载有CIFAR-10,可通过root参数指定

# 定义对数据的预处理

transform = transforms.Compose([

        transforms.ToTensor(), # 转为Tensor

        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 归一化

                            ])

# 训练集

trainset = tv.datasets.CIFAR10(

                    root='/home/自己数据集的地址',

                    train=True,

                    download=True,

                    transform=transform)

trainloader = t.utils.data.DataLoader(

                    trainset,

                    batch_size=4,

                    shuffle=True,

                    num_workers=2)

# 测试集

testset = tv.datasets.CIFAR10(

                    '/home/自己数据集的地址',

                    train=False,

                    download=True,

                    transform=transform)

testloader = t.utils.data.DataLoader(

                    testset,

                    batch_size=4,

                    shuffle=False,

                    num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',

          'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

定义网络

拷贝上面的LeNet网络,修改self.conv1第一个参数为3通道,因CIFAR-10是3通道彩图。


import torch.nn as nn

import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):

        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(3, 6, 5)

        self.conv2 = nn.Conv2d(6, 16, 5) 

        self.fc1  = nn.Linear(16*5*5, 120) 

        self.fc2  = nn.Linear(120, 84)

        self.fc3  = nn.Linear(84, 10)

    def forward(self, x):

        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))

        x = F.max_pool2d(F.relu(self.conv2(x)), 2)

        x = x.view(x.size()[0], -1)

        x = F.relu(self.fc1(x))

        x = F.relu(self.fc2(x))

        x = self.fc3(x)       

        return x

net = Net()

print(net)

上边我们定义的网络结果

image

定义损失函数和优化器(loss和optimizer)


from torch import optim

criterion = nn.CrossEntropyLoss() # 交叉熵损失函数

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

训练网络

所有网络的训练流程都是类似的,不断地执行如下流程:

输入数据

前向传播+反向传播

更新参数


t.set_num_threads(8)

for epoch in range(4): 

    running_loss = 0.0

    for i, data in enumerate(trainloader, 0):

        # 输入数据

        inputs, labels = data

        inputs, labels = Variable(inputs), Variable(labels)

        # 梯度清零

        optimizer.zero_grad()

        # forward + backward

        outputs = net(inputs)

        loss = criterion(outputs, labels)

        loss.backward() 

        # 更新参数

        optimizer.step()

        # 打印log信息

        running_loss += loss.item()

        if i % 2000 == 1999: # 每2000个batch打印一下训练状态

            print('[%d, %5d] loss: %.3f' \

                  % (epoch+1, i+1, running_loss / 2000))

            running_loss = 0.0

print('Finished Training')

image
image

以上结果对比,我们可以看到增加迭代的轮数是有利于网络的训练,直到loss值趋于一个平缓的值,训练完成。

来看看网络有没有效果。将测试图片输入到网络中,计算它的label,然后与实际的label进行比较。


dataiter = iter(testloader)

images, labels = dataiter.next() # 一个batch返回4张图片

print('实际的label: ', ' '.join(\

            '%08s'%classes[labels[j]] for j in range(4)))

show(tv.utils.make_grid(images / 2 - 0.5)).resize((400,100))

image

接着计算网络预测的label:


# 计算图片在每个类别上的分数

outputs = net(Variable(images))

# 得分最高的那个类

_, predicted = t.max(outputs.data, 1)

print('预测结果: ', ' '.join('%5s'\

            % classes[predicted[j]] for j in range(4)))···

预测结果: cat ship ship ship

已经可以看出效果,准确率50%,但这只是一部分的图片,再来看看在整个测试集上的效果。

···correct = 0 # 预测正确的图片数

total = 0 # 总共的图片数

for data in testloader:

    images, labels = data

    outputs = net(Variable(images))

    _, predicted = t.max(outputs.data, 1)

    total += labels.size(0)

    correct += (predicted == labels).sum()

print('10000张测试集中的准确率为: %d %%' % (100 * correct / total))

预测结果:'cat car car plane'

已经可以看出效果,准确率50%,但这只是一部分的图片,再来看看在整个测试集上的效果


correct = 0 # 预测正确的图片数

total = 0 # 总共的图片数

for data in testloader:

    images, labels = data

    outputs = net(Variable(images))

    _, predicted = t.max(outputs.data, 1)

    total += labels.size(0)

    correct += (predicted == labels).sum()

print('10000张测试集中的准确率为: %d %%' % (100 * correct / total))

image

训练的准确率远比随机猜测(准确率10%)要好很多,所以该网络还是学到了一些东西。

参考书籍:深度学习框架 pytorch入门与实践 陈云

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,053评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,527评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,779评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,685评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,699评论 5 366
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,609评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,989评论 3 396
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,654评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,890评论 1 298
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,634评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,716评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,394评论 4 319
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,976评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,950评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,191评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 44,849评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,458评论 2 342