Task02 PyTorch进阶训练技巧

参考链接:https://github.com/datawhalechina/thorough-pytorch

本task注重于pytorch在实际使用中的一些操作~较为实用

1.自定义损失函数

PyTorch在torch.nn模块为我们提供了许多常用的损失函数,比如:MSELoss,L1Loss,BCELoss...... 但是随着深度学习的发展,出现了越来越多的非官方提供的Loss,比如DiceLoss,HuberLoss,SobolevLoss...... 这些Loss Function专门针对一些非通用的模型,PyTorch不能将他们全部添加到库中去,因此这些损失函数的实现则需要我们通过自定义损失函数来实现。另外,在科学研究中,我们往往会提出全新的损失函数来提升模型的表现,这时我们既无法使用PyTorch自带的损失函数,也没有相关的博客供参考,此时自己实现损失函数就显得更为重要了。

1.1 以函数方式定义

损失函数定义

1.2 以类方式定义

虽然以函数定义的方式很简单,但是以类方式定义更加常用,在以类方式定义损失函数时,我们如果看每一个损失函数的继承关系我们就可以发现Loss函数部分继承自_loss, 部分继承自_WeightedLoss, 而_WeightedLoss继承自_loss, _loss继承自 nn.Module。我们可以将其当作神经网络的一层来对待,同样地,我们的损失函数类就需要继承自nn.Module类,下面以DiceLoss为例。

Dice Loss函数

除此之外,常见的损失函数还有BCE-Dice Loss,Jaccard/Intersection over Union (IoU) Loss,Focal Loss等。

BCE-Dice Loss
Jaccard/Intersection over Union (IoU)
Focal Loss

Focal loss主要是为了解决one-stage目标检测中正负样本比例严重失衡的问题。该损失函数降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘。是 Kaiming 大神团队在论文Focal Loss for Dense Object Detection 提出来的损失函数,利用它改善了图像物体检测的效果。

2.动态调整学习率

学习率的选择是深度学习中一个困扰人们许久的问题,学习速率设置过小,会极大降低收敛速度,增加训练时间;学习率太大,可能导致参数在最优解两侧来回振荡。但是当我们选定了一个合适的学习率后,经过许多轮的训练后,可能会出现准确率震荡或loss不再下降等情况,说明当前学习率已不能满足模型调优的需求。此时我们就可以通过一个适当的学习率衰减策略来改善这种现象,提高我们的精度。这种设置方式在PyTorch中被称为scheduler。

2.1 使用官方scheduler

在训练神经网络的过程中,学习率是最重要的超参数之一,作为当前较为流行的深度学习框架,PyTorch已经在torch.optim.lr_scheduler为我们封装好了一些动态调整学习率的方法供我们使用,如下面列出的这些scheduler。

官方scheduler

关于如何使用这些动态调整学习率的策略,PyTorch官方也很人性化的给出了使用实例代码,我们在使用官方给出的torch.optim.lr_scheduler时,需要将scheduler.step()放在optimizer.step()后面进行使用。

使用官方API

2.2 自定义scheduler

虽然PyTorch官方给我们提供了许多的API,但是在实验中也有可能碰到需要我们自己定义学习率调整策略的情况,而我们的方法是自定义函数adjust_learning_rate来改变param_group中lr的值,在下面的叙述中会给出一个简单的实现。

假设我们现在正在做实验,需要学习率每30轮下降为原来的1/10,假设已有的官方API中没有符合我们需求的,那就需要自定义函数来实现学习率的改变。

自定义scheduler

3.模型微调

参考链接:https://blog.csdn.net/qq_42250789/article/details/108832004

微调定义: 给定预训练模型(Pre_trained model),基于模型进行微调(Fine Tune)。相对于从头开始训练(Training a model from scatch),微调为你省去大量计算资源和计算时间,提高了计算效率,甚至提高准确率。

预训练模型:(1) 预训练模型就是已经用数据集训练好了的模型。(2) 现在我们常用的预训练模型就是他人用常用模型,比如VGG16/19,Resnet等模型,并用大型数据集来做训练集,比如Imagenet, COCO等训练好的模型参数。 (3)  正常情况下,我们常用的VGG16/19等网络已经是他人调试好的优秀网络,我们无需再修改其网络结构。

不同数据集下使用微调:

数据集1 - 数据量少,但数据相似度非常高 - 在这种情况下,我们所做的只是修改最后几层或最终的softmax图层的输出类别。

数据集2 - 数据量少,数据相似度低 - 在这种情况下,我们可以冻结预训练模型的初始层(比如k层),并再次训练剩余的(n-k)层。由于新数据集的相似度较低,因此根据新数据集对较高层进行重新训练具有重要意义。

数据集3  - 数据量大,数据相似度低 - 在这种情况下,由于我们有一个大的数据集,我们的神经网络训练将会很有效。但是,由于我们的数据与用于训练我们的预训练模型的数据相比有很大不同。使用预训练模型进行的预测不会有效。因此,最好根据你的数据从头开始训练神经网络(Training from scatch)

数据集4  - 数据量大,数据相似度高 - 这是理想情况。在这种情况下,预训练模型应该是最有效的。使用模型的最好方法是保留模型的体系结构和模型的初始权重。然后,我们可以使用在预先训练的模型中的权重来重新训练该模型。

3.1 模型微调的流程

在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型。

创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。

为目标模型添加一个输出⼤小为⽬标数据集类别个数的输出层,并随机初始化该层的模型参数。

在目标数据集上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。

3.2 使用已有模型结构

这里我们以torchvision中的常见模型为例,列出了如何在图像分类任务中使用PyTorch提供的常见模型结构和参数。对于其他任务和网络结构,使用方式是类似的:

实例化网格:

实例化网格

传递pretrained参数:通过True或者False来决定是否使用预训练好的权重,在默认状态下pretrained = False,意味着我们不使用预训练得到的权重,当pretrained = True,意味着我们将使用在一些数据集上预训练得到的权重。

传递pretrained参数

注意事项:

通常PyTorch模型的扩展为.pt或.pth,程序运行时会首先检查默认路径中是否有已经下载的模型权重,一旦权重被下载,下次加载就不需要下载了。

一般情况下预训练模型的下载会比较慢,我们可以直接通过迅雷或者其他方式去 这里 查看自己的模型里面model_urls,然后手动下载,预训练模型的权重在Linux和Mac的默认下载路径是用户根目录下的.cache文件夹。在Windows下就是C:\Users\<username>\.cache\torch\hub\checkpoint。我们可以通过使用 torch.utils.model_zoo.load_url()设置权重的下载地址。

如果觉得麻烦,还可以将自己的权重下载下来放到同文件夹下,然后再将参数加载网络。

self.model=models.resnet50(pretrained=False)

self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))

如果中途强行停止下载的话,一定要去对应路径下将权重文件删除干净,要不然可能会报错。

3.3 训练特定层

在默认情况下,参数的属性.requires_grad = True,如果我们从头开始训练或微调不需要注意这里。但如果我们正在提取特征并且只想为新初始化的层计算梯度,其他参数不进行改变。那我们就需要通过设置requires_grad = False来冻结部分层。在PyTorch官方中提供了这样一个例程。

在下面我们仍旧使用resnet18为例的将1000类改为4类,但是仅改变最后一层的模型参数,不改变特征提取的模型参数;注意我们先冻结模型参数的梯度,再对模型输出部分的全连接层进行修改,这样修改后的全连接层的参数就是可计算梯度的。

之后在训练过程中,model仍会进行梯度回传,但是参数更新则只会发生在fc层。通过设定参数的requires_grad属性,我们完成了指定训练模型的特定层的目标,这对实现模型微调非常重要。

4. 半精度训练

我们提到PyTorch时候,总会想到要用硬件设备GPU的支持,也就是“卡”。GPU的性能主要分为两部分:算力和显存,前者决定了显卡计算的速度,后者则决定了显卡可以同时放入多少数据用于计算。在可以使用的显存数量一定的情况下,每次训练能够加载的数据更多(也就是batch size更大),则也可以提高训练效率。另外,有时候数据本身也比较大(比如3D图像、视频等),显存较小的情况下可能甚至batch size为1的情况都无法实现。因此,合理使用显存也就显得十分重要。

我们观察PyTorch默认的浮点数存储方式用的是torch.float32,小数点后位数更多固然能保证数据的精确性,但绝大多数场景其实并不需要这么精确,只保留一半的信息也不会影响结果,也就是使用torch.float16格式。由于数位减了一半,因此被称为“半精度”。

显然半精度能够减少显存占用,使得显卡可以同时加载更多数据进行计算。本节会介绍如何在PyTorch中设置使用半精度计算。

4.1 半精度训练的设置

在PyTorch中使用autocast配置半精度训练,同时需要在下面三处加以设置:

注意:

半精度训练主要适用于数据本身的size比较大(比如说3D图像、视频等)。当数据本身的size并不大时(比如手写数字MNIST数据集的图片尺寸只有28*28),使用半精度训练则可能不会带来显著的提升。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 205,033评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,725评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,473评论 0 338
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,846评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,848评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,691评论 1 282
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,053评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,700评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 42,856评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,676评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,787评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,430评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,034评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,990评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,218评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,174评论 2 352
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,526评论 2 343

推荐阅读更多精彩内容