pytorch学习(十二)—迁移学习Transfer Learning

前言

在训练深度学习模型时,有时候我们没有海量的训练样本,只有少数的训练样本(比如几百个图片),几百个训练样本显然对于深度学习远远不够。这时候,我们可以使用别人预训练好的网络模型权重,在此基础上进行训练,这就引入了一个概念——迁移学习(Transfer Learning)


迁移学习

What(什么是迁移学习)

迁移学习(Transfer Learning,TL)对于人类来说,就是掌握举一反三的学习能力。比如我们学会骑自行车后,学骑摩托车就很简单了;在学会打羽毛球之后,再学打网球也就没那么难了。对于计算机而言,所谓迁移学习,就是能让现有的模型算法稍加调整即可应用于一个新的领域和功能的一项技术

How(如何进行迁移学习)

  • 首先需要选择一个预训练好的模型,需要注意的是该模型的训练过程最好与我们要进行训练的任务相似。比如我们要训练一个Cat,dog图像分类的模型,最好应该选择一个图像分类的预训练模型。

  • 针对实际任务,对网络结构进行调整。比如找到了一个预训练好的AlexNet(1000类别), 但是我们实际的任务的2分类,因此需要把最后一层的全连接输出改为2.

Why(为何要使用迁移学习)

https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest.


目的

  • 了解ResNet
  • 基于预训练好的ResNet-18, 进行一个图像二分类迁移学习

开发/测试环境

  • Ubuntu 18.04
  • pycharm
  • Anaconda3, python3.6
  • pytorch1.0, torchvision

ResNet-18

image.png

实验内容

准备数据集

  • 训练集合
  • 验证集合

数据集下载链接

下载好之后,复制到工程 /data/ 路径下


image.png

训练集合,验证集合


image.png

训练集,验证集 分别包含2个子文件夹,这是一个2分类问题。分类对象:蚂蚁,蜜蜂


image.png
  • 代码
    因为训练一个2分类的模型,数据集加载直接使用pytorch提供的API——ImageFolder最方便。原始图像为jpg格式,在制作数据集时候进行了变换transforms。 加入对GPU的支持,首先判断torch.cuda.is_available(),然后决定使用GPU or CPU
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import transforms
from torchvision import models
from torchvision.models import ResNet
import numpy as np
import matplotlib.pyplot as plt
import os
import utils


data_dir = './data/hymenoptera_data'

train_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, 'train'),
                                                 transform=transforms.Compose(
                                                     [
                                                         transforms.RandomResizedCrop(224),
                                                         transforms.RandomHorizontalFlip(),
                                                         transforms.ToTensor(),
                                                         transforms.Normalize(
                                                             mean=(0.485, 0.456, 0.406),
                                                             std=(0.229, 0.224, 0.225))
                                                     ]))

val_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, 'val'),
                                               transform=transforms.Compose(
                                                     [
                                                         transforms.RandomResizedCrop(224),
                                                         transforms.RandomHorizontalFlip(),
                                                         transforms.ToTensor(),
                                                         transforms.Normalize(
                                                             mean=(0.485, 0.456, 0.406),
                                                             std=(0.229, 0.224, 0.225))
                                                     ]))

train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=4)

# 类别名称
class_names = train_dataset.classes
print('class_names:{}'.format(class_names))

# 训练设备  CPU/GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('trian_device:{}'.format(device.type))

# 随机显示一个batch
plt.figure()
utils.imshow(next(iter(train_dataloader)))
plt.show()

获取预训练模型

torchvision.models
torchvision中包含了一些常见的预训练模型:

image.png

AlexNet, VGG, SqueezeNet, Resnet,Inception, DenseNet

此次实验采用ResNet18网络模型。
torchvision.models中包含resnet18,首先会实例化一个ResNet网络, 然后model.load_dict()加载预训练好的模型。

def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model

torchvision 默认将模型保存在/home/.torch/models路径。

image.png

预训练模型文件:


image.png
  • 代码
    加载预训练模型。需要注意的地方:修改ResNet最后一个全连接层的输出个数,二分类问题需要将输出个数改为2。
# -------------------------模型选择,优化方法, 学习率策略----------------------
model = models.resnet18(pretrained=True)

# 全连接层的输入通道in_channels个数
num_fc_in = model.fc.in_features

# 改变全连接层,2分类问题,out_features = 2
model.fc = nn.Linear(num_fc_in, 2)

# 模型迁移到CPU/GPU
model = model.to(device)

# 定义损失函数
loss_fc = nn.CrossEntropyLoss()

# 选择优化方法
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

# 学习率调整策略
# 每7个epoch调整一次
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.5)  # step_size


训练,测试网络

Epoch: 训练50个epoch
注意地方: 训练时候,需要调用model.train()将模型设置为训练模式。测试时候,调用model.eval() 将模型设置为测试模型,否则训练和测试结果不正确。

# ----------------训练过程-----------------
num_epochs = 50

for epoch in range(num_epochs):

    running_loss = 0.0
    exp_lr_scheduler.step()

    for i, sample_batch in enumerate(train_dataloader):
        inputs = sample_batch[0]
        labels = sample_batch[1]

        model.train()

        # GPU/CPU
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # foward
        outputs = model(inputs)

        # loss
        loss = loss_fc(outputs, labels)

        # loss求导,反向
        loss.backward()

        # 优化
        optimizer.step()

        #
        running_loss += loss.item()

        # 測試
        if i % 20 == 19:
            correct = 0
            total = 0
            model.eval()
            for images_test, labels_test in val_dataloader:
                images_test = images_test.to(device)
                labels_test = labels_test.to(device)

                outputs_test = model(images_test)
                _, prediction = torch.max(outputs_test, 1)
                correct += (torch.sum((prediction == labels_test))).item()
               # print(prediction, labels_test, correct)
                total += labels_test.size(0)
            print('[{}, {}] running_loss = {:.5f} accurcay = {:.5f}'.format(epoch + 1, i + 1, running_loss / 20,
                                                                        correct / total))
            running_loss = 0.0

        # if i % 10 == 9:
        #     print('[{}, {}] loss={:.5f}'.format(epoch+1, i+1, running_loss / 10))
        #     running_loss = 0.0

print('training finish !')
torch.save(model.state_dict(), './model/model_2.pth')

训练输出结果

image.png
image.png
image.png
image.png

随着训练次数增加,accuracy基本上是上升趋势,最终达到93%的准确率。

image.png
image.png
image.png

完整代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import transforms
from torchvision import models
from torchvision.models import ResNet
import numpy as np
import matplotlib.pyplot as plt
import os
import utils


data_dir = './data/hymenoptera_data'

train_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, 'train'),
                                                 transform=transforms.Compose(
                                                     [
                                                         transforms.RandomResizedCrop(224),
                                                         transforms.RandomHorizontalFlip(),
                                                         transforms.ToTensor(),
                                                         transforms.Normalize(
                                                             mean=(0.485, 0.456, 0.406),
                                                             std=(0.229, 0.224, 0.225))
                                                     ]))

val_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, 'val'),
                                               transform=transforms.Compose(
                                                     [
                                                         transforms.RandomResizedCrop(224),
                                                         transforms.RandomHorizontalFlip(),
                                                         transforms.ToTensor(),
                                                         transforms.Normalize(
                                                             mean=(0.485, 0.456, 0.406),
                                                             std=(0.229, 0.224, 0.225))
                                                     ]))

train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=4)

# 类别名称
class_names = train_dataset.classes
print('class_names:{}'.format(class_names))

# 训练设备  CPU/GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('trian_device:{}'.format(device.type))

# 随机显示一个batch
#plt.figure()
#utils.imshow(next(iter(train_dataloader)))
#plt.show()

# -------------------------模型选择,优化方法, 学习率策略----------------------
model = models.resnet18(pretrained=True)

# 全连接层的输入通道in_channels个数
num_fc_in = model.fc.in_features

# 改变全连接层,2分类问题,out_features = 2
model.fc = nn.Linear(num_fc_in, 2)

# 模型迁移到CPU/GPU
model = model.to(device)

# 定义损失函数
loss_fc = nn.CrossEntropyLoss()

# 选择优化方法
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

# 学习率调整策略
# 每7个epoch调整一次
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.5)  # step_size


# ----------------训练过程-----------------
num_epochs = 50

for epoch in range(num_epochs):

    running_loss = 0.0
    exp_lr_scheduler.step()

    for i, sample_batch in enumerate(train_dataloader):
        inputs = sample_batch[0]
        labels = sample_batch[1]

        model.train()

        # GPU/CPU
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # foward
        outputs = model(inputs)

        # loss
        loss = loss_fc(outputs, labels)

        # loss求导,反向
        loss.backward()

        # 优化
        optimizer.step()

        #
        running_loss += loss.item()

        # 測試
        if i % 20 == 19:
            correct = 0
            total = 0
            model.eval()
            for images_test, labels_test in val_dataloader:
                images_test = images_test.to(device)
                labels_test = labels_test.to(device)

                outputs_test = model(images_test)
                _, prediction = torch.max(outputs_test, 1)
                correct += (torch.sum((prediction == labels_test))).item()
               # print(prediction, labels_test, correct)
                total += labels_test.size(0)
            print('[{}, {}] running_loss = {:.5f} accurcay = {:.5f}'.format(epoch + 1, i + 1, running_loss / 20,
                                                                        correct / total))
            running_loss = 0.0

        # if i % 10 == 9:
        #     print('[{}, {}] loss={:.5f}'.format(epoch+1, i+1, running_loss / 10))
        #     running_loss = 0.0

print('training finish !')
torch.save(model.state_dict(), './model/model_2.pth')


End

参考:
https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
https://blog.csdn.net/sunqiande88/article/details/80100891

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,053评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,527评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,779评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,685评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,699评论 5 366
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,609评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,989评论 3 396
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,654评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,890评论 1 298
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,634评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,716评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,394评论 4 319
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,976评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,950评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,191评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 44,849评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,458评论 2 342

推荐阅读更多精彩内容