小黑的Python日记:Pytorch简单实现GoogleNet

大家好呀,我是小黑喵

核心:

1.1*1的kernel_size能够有效减少参数数量
2.Inception block

Inception结构

Inception Module

网络结构

Googlenet

忽略辅助分类:使用CIFAR10数据集

代码:

###author:xiaoheimiao
import torch
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#定义conv-bn-relu函数
def conv_relu(in_channel, out_channel, kernel, stride=1, padding=0):
    conv = nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel, stride, padding),
        nn.BatchNorm2d(out_channel, eps=1e-3),
        nn.ReLU(True),
    )
    return conv

#定义incepion结构,见inception图
class inception(nn.Module):
    def __init__(self, in_channel, out1_1, out2_1, out2_3, out3_1, out3_5,
                 out4_1):
        super(inception, self).__init__()
        self.branch1 = conv_relu(in_channel, out1_1, 1)
        self.branch2 = nn.Sequential(
            conv_relu(in_channel, out2_1, 1),
            conv_relu(out2_1, out2_3, 3, padding=1))
        self.branch3 = nn.Sequential(
            conv_relu(in_channel, out3_1, 1),
            conv_relu(out3_1, out3_5, 5, padding=2))
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            conv_relu(in_channel, out4_1, 1),
        )

    def forward(self, x):
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        b4 = self.branch4(x)
        output = torch.cat([b1, b2, b3, b4], dim=1)
        return output

# 堆叠GOOGLENET,见上表所示结构
class GOOGLENET(nn.Module):
    def __init__(self):
        super(GOOGLENET, self).__init__()
        self.features = nn.Sequential(
            conv_relu(3, 64, 7, 2, 3), nn.MaxPool2d(3, stride=2, padding=0),
            conv_relu(64, 64, 1), conv_relu(64, 192, 3, padding=1),
            nn.MaxPool2d(3, 2), inception(192, 64, 96, 128, 16, 32, 32),
            inception(256, 128, 128, 192, 32, 96, 64), nn.MaxPool2d(
                3, stride=2), inception(480, 192, 96, 208, 16, 48, 64),
            inception(512, 160, 112, 224, 24, 64, 64),
            inception(512, 128, 128, 256, 24, 64, 64),
            inception(512, 112, 144, 288, 32, 64, 64),
            inception(528, 256, 160, 320, 32, 128, 128), nn.MaxPool2d(3, 2),
            inception(832, 256, 160, 320, 32, 128, 128),
            inception(832, 384, 182, 384, 48, 128, 128), nn.AvgPool2d(2))
        self.classifier = nn.Sequential(
            nn.Linear(9216,1024),
            nn.Dropout2d(p=0.4),
            nn.Linear(1024, 10))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        out = self.classifier(x)
        return out

#训练函数
def net_train():
    net.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # 将输入传入GPU
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # 将梯度置零
        optimizer.zero_grad()

        # 前向传播-计算误差-反向传播-优化
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 计算误差并显示
        running_loss += loss.item()
        if i % 60 == 0:    # print every 60 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 60))
            running_loss = 0.0

    print('Training Epoch Finished')

#测试函数
def net_test():
    correct = 0
    total = 0
    # 关闭梯度
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
    return

#数据集函数
def net_dataloader(root, train_transform, test_transform):
    trainset = torchvision.datasets.CIFAR10(
        root, train=True, transform=train_transform, download=False)
    testset = torchvision.datasets.CIFAR10(
        root, train=False, transform=test_transform, download=False)
    trainloader = DataLoader(
        trainset, batch_size=60, shuffle=True, num_workers=2)
    testloader = DataLoader(
        testset, batch_size=8, shuffle=False, num_workers=2)
    print('Initializing Dataset...')
    return trainloader, testloader

#main
if __name__ == "__main__":
    # 创建实例并送入GPU
    net = GOOGLENET().to(device)
    # 选择误差
    criterion = nn.CrossEntropyLoss()
    # 选择优化器
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    # 数据位置
    root = './pydata/data/'
    # 数据处理
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    test_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # 创建数据loader
    trainloader, testloader = net_dataloader(root, train_transform,
                                             test_transform)
    # run
    n_epoch = 5#改变epoch
    for epoch in range(n_epoch):
        net_train()#每个epoch训练一次,测试一次
        net_test()
喵喵喵
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 196,165评论 5 462
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 82,503评论 2 373
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 143,295评论 0 325
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,589评论 1 267
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 61,439评论 5 358
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,342评论 1 273
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,749评论 3 387
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,397评论 0 255
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,700评论 1 295
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,740评论 2 313
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,523评论 1 326
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,364评论 3 314
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,755评论 3 300
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,024评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,297评论 1 251
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 41,721评论 2 342
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 40,918评论 2 336

推荐阅读更多精彩内容