FB等提出全新卷积操作OctConv,速度接近理论极限

引言

论文地址
这篇论文是周一时带我的大佬(现在瑞士读博士,据说还在nips上面发过文章😶,瑟瑟发抖)发给我一个一篇链接文章,博客是计划周五就要写出来的,但是由于要将maxnet的代码迁移到pytorch的resnet上面花费了一些时间。至今没见过这位大佬,我这位本科大白只是每周一阅读他发的论文和相关demo代码,改写或者迁移到现在的工业图像分类上。有想一起学习的可以加qq:1678354579进行讨论。
下面的内容由于时间有限,主要以代码实现为主。才疏学浅,如果那些错误还请大佬多多指正!

摘要

在自然图像中,信息总是在不同频率中表达的,其中高频信号一般包含丰富的细节而低频信号一般包含整体的结构。类似地,卷积层的输出特征图同样可以被看作是混合了不同频域的信息。在这项工作中,我们提出了如何根据频域去分解信息混合的特征图,并设计了一个新颖的八度卷积(Octave Convolution,OctConv)操作来保存和处理那些在较低空间分辨率下变化“较慢”(Slower)的特征图,从而减少存储和计算开销。与现有多尺度(multi-scale)方法不同的是,八度卷积被制定为一种单个通用的即插即用卷积单元,可以直接替换普通(vanilla)卷积而不需要对现有网络有任何调整。它同时也是对一些表明有着更好拓扑(topologies)或者减少通道冗余的方法的补充,并且与这些方法正交(orthogonal)。通过简单地用八度卷积替换普通卷积,我们在实验中发现我们在减少存储和计算开销的同时,还能持续提高图像和视频识别任务的准确率。一个使用八度卷积的ResNet-152网络能够在ImageNet上达到82.9%的Top-1分类准确率,而其浮点计算量仅仅只有22.2G(Giga)。

  • 总结下来就是:自然界的图像中高频的信息表示细腻而丰富的细节,低频表示整体的轮廓和布局。八度卷积最大的优点就是节省存储空间的运算力,而且有怎么如此强的功能只需要改动网络中卷积部分即可实现即插即用的功能!我的代码能力一般,大概花了一天左右的时间改写了octconv版的resnet,后期经过改动能够适应三种卷积的增强版
  • 加一句,关于低频和高频个人觉得可能搞美术的人更能理解。比如像画人物一样,大致的轮廓是差不多的,不经常改变为低频。具体的细节,一颦一动每个人都不一样为高频。本人为工科宅男一枚,献丑了😂

原理浅谈

关于详细的原理,大家可以参考论文和一片中文博客。我这里更深的理解也是来源这篇博客,推荐大家去看看。
这里我主要从个人代码理解和实现的角度来聊一聊原理,说白了就是数学公式看的有点蒙逼。代码和公式相结合能够理解更深入。
传统的图像卷积是每一个卷积核为[kernel_size,kernel_size,in_channels],通过一系列相乘相加操作后得出特征图的一个像素点。如果是BP网络这一步就已经结束了,但是卷积网络会利用stride进行移动相同的卷积核得出下一个像素点。就这样按照步长在图像的宽高进行移动,得出一个通道的特征图,那如果我想要out_channels个通道的特征图。我只需要out_channels个卷积和就可以了,所以卷积的参数维度就是[kernel_size,kernel_size,in_channels,out_channels]。后期人们在消除特征图的冗余,人们又提出了grop_conv和depth_wise的卷积,对应的网络就是现在的resenxt和mobilenet。关于冗余的理解之前看过一本书上讲解是过多的输出通道,卷积核很大概率存在相似性,那么输出的特征图就会存在线性相关(简单说就是特征图的一个向量可以由另一个向量线性表示)。这部分如果大家有感到不太懂的,自动google关键字。或者加我私聊,欢迎骚扰!

好像有点扯远了,,,,现在开始进入重点啦!!八度卷积是在分辨率的维度提出低频的信息在传统的卷积中也存在冗余,通过将特征图分离成低频信息(低分辨率),高频信息(高分辨率)的达到节省存储和算力。大概估算一下,如果每一个特征图的一半为低频信息,那么他的分辨率降低为原始特征图的1/2,存储会卷积运算会减少1/4。
下采样刚才我们降低冗余是通过降低低频信息的分辨率,那么现在的问题是如何进行分辨率的降低呢?卷积网络中有两种下采样的方式,一种是池化(pool),一种是步长为2的卷积。论文的实验是说池化的方式会更有效

消融实验

将八度卷积嵌入到resnet中发现stride=2的卷积下采样并没有降低可训练的参数,而pool的下采样方式则数十倍的降低了参数量。具体的数值当时没有保存,应该会降低的更过。pool我们好理解,因为pool本来并没有可训练卷积,而stride=2的卷积下采样本质是将原始的卷积核分解成四份(中间卷积)或者两份(开始和结尾卷积),所以他的可训练参数是不会减少的。
八度卷积路线图
第一层卷积:输入图像默认全部为高频信息,故alpha_int=0,alpha_out=
在这里插入图片描述

中间层卷积,特征图包含低频和高频信息,一般设置为alpha_int=alpha_out=
在这里插入图片描述

最后一层卷积,回复正常特征图,故alpha_int=,alpha_out=0
在这里插入图片描述

这里的参数设置一般为0.5,0.2。具体的参数设置会根据图像的特征丰富程度调整。
简单总结:特征图由第一层进入分为两路(低频信息和高频信息),中间层一直是两路信息,并且两路信息之间有交互,最终汇聚为一路信息输出。

具体实现代码

版本一 pool池化

# -*- coding: utf-8 -*-
# @Time    : 2019/4/22 13:29
# @Author  : ljf
import torch
import torch.nn.functional as F
from torch import nn


class OctConv2d_v1(nn.Conv2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 alpha_in=0.5,
                 alpha_out=0.5
                 ):
        """adapt first octconv , octconv and last octconv

        """
        assert alpha_in >= 0 and alpha_in <= 1, "the value of alpha_in should be in range of [0,1],but get {}".format(
            alpha_in)
        assert alpha_out >= 0 and alpha_out <= 1, "the value of alpha_in should be in range of [0,1],but get {}".format(
            alpha_out)
        super(OctConv2d_v1, self).__init__(in_channels,
                                        out_channels,
                                        dilation,
                                        groups,
                                        bias,)
        self.alpha_in = alpha_in
        self.alpha_out = alpha_out
        self.kernel_size = (1,1)
        self.stride = (1,1)
        self.avgPool = nn.AvgPool2d(kernel_size, stride, padding)
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

        self.inChannelSplitIndex = int(
            self.alpha_in * self.in_channels)
        self.outChannelSplitIndex = int(
            self.alpha_out * self.out_channels)
        # split bias
        if bias:
            self.hh_bias = self.bias[self.outChannelSplitIndex:]
            self.hl_bias = self.bias[:self.outChannelSplitIndex]
            self.ll_bias = self.bias[ :self.outChannelSplitIndex]
            self.lh_bias = self.bias[ self.outChannelSplitIndex:]
        else:
            self.hh_bias = None
            self.hl_bias = None
            self.ll_bias = None
            self.ll_bias = None

        # conv and upsample
        self.upsample = F.interpolate

    def forward(self, x):
        if not isinstance(x, tuple):
            # first octconv
            input_h = x if self.alpha_in == 0 else None
            input_l = x if self.alpha_in == 1 else None
        else:
            input_l = x[0]
            input_h = x[1]

        output = [0, 0]
        # H->H
        if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != self.in_channels:
            output_hh = F.conv2d(self.avgPool(input_h),
                                 self.weight[
                                 self.outChannelSplitIndex:,
                                 self.inChannelSplitIndex:,
                                 :, :],
                                 self.bias[self.outChannelSplitIndex:],
                                 self.kernel_size
                                 )

            output[1] += output_hh

        # H->L
        if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != self.in_channels:
            output_hl = F.conv2d(self.avgpool(self.avgPool(input_h)),
                                 self.weight[
                :self.outChannelSplitIndex,
                self.inChannelSplitIndex:,
                                     :, :],
                                 self.bias[:self.outChannelSplitIndex],
                                 self.kernel_size
                                 )

            output[0] += output_hl

        # L->L
        if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != 0:
            output_ll = F.conv2d((self.avgPool(input_l)),
                                 self.weight[
                                 :self.outChannelSplitIndex,
                                 :self.inChannelSplitIndex,
                                 :, :],
                                 self.bias[:self.outChannelSplitIndex],
                                 self.kernel_size
                                 )

            output[0] += output_ll

        # L->H
        if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != 0:
            output_lh = F.conv2d(self.avgPool(input_l),
                                 self.weight[
                                 self.outChannelSplitIndex:,
                                 :self.inChannelSplitIndex,
                                 :, :],
                                 self.bias[self.outChannelSplitIndex:],
                                 self.kernel_size
                                 )
            output_lh = self.upsample(output_lh, scale_factor=2)

            output[1] += output_lh

        if isinstance(output[0], int):
            out = output[1]
        else:
            out = tuple(output)
        return out
if __name__ == "__main__":
    input = torch.randn(1, 3, 32, 32)
    octconv1 = OctConv2d(
        in_channels=3,
        out_channels=6,
        kernel_size=3,
        padding=1,
        stride=2,
        alpha_in=0,
        alpha_out=0.3)
    octconv2 = OctConv2d(
        in_channels=6,
        out_channels=16,
        kernel_size=2,
        padding=0,
        stride=2,
        alpha_in=0.3,
        alpha_out=0.5)
    lastconv = OctConv2d(
        in_channels=16,
        out_channels=32,
        kernel_size=2,
        padding=0,
        stride=2,
        alpha_in=0.5,
        alpha_out=0)
    # bn1 = OctBN(3,3)
    # ac1 = OctAc(name="relu")
    out = octconv1(input)
    print(len(out))
    print(out[0].size())
    print(out[1].size())
    out = octconv2(out)
    print(len(out))
    print(out[0].size())
    print(out[1].size())

    out = lastconv(out)
    print(len(out))
    print(out[0].size())
    print(out[1])

版本二 stride=2的卷积

# -*- coding: utf-8 -*-
# @Time    : 2019/4/22 10:35
# @Author  : ljf
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class OctConv2d_v2(nn.Conv2d):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
            alpha_in=0.5,
            alpha_out=0.5,):
        assert alpha_in >= 0 and alpha_in <= 1
        assert alpha_out >= 0 and alpha_out <= 1
        super(OctConv2d_v2, self).__init__(in_channels, out_channels,
                                           kernel_size, stride, padding,
                                           dilation, groups, bias)
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.alpha_in = alpha_in
        self.alpha_out = alpha_out
        self.inChannelSplitIndex = math.floor(
            self.alpha_in * self.in_channels)
        self.outChannelSplitIndex = math.floor(
            self.alpha_out * self.out_channels)
        if bias:
            self.hh_bias = self.bias[self.outChannelSplitIndex:]
            self.hl_bias = self.bias[:self.outChannelSplitIndex]
            self.ll_bias = self.bias[ :self.outChannelSplitIndex]
            self.lh_bias = self.bias[ self.outChannelSplitIndex:]
        else:
            self.hh_bias = None
            self.hl_bias = None
            self.ll_bias = None
            self.lh_bias = None
    def forward(self, input):
        if not isinstance(input, tuple):
            assert self.alpha_in == 0 or self.alpha_in == 1
            inputLow = input if self.alpha_in == 1 else None
            inputHigh = input if self.alpha_in == 0 else None
        else:
            inputLow = input[0]
            inputHigh = input[1]

        output = [0, 0]
        # H->H
        if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != self.in_channels:
            outputH2H = F.conv2d(
                inputHigh,
                self.weight[
                    self.outChannelSplitIndex:,
                    self.inChannelSplitIndex:,
                    :,
                    :],
                self.hh_bias,
                self.stride,
                self.padding,
                self.dilation,
                self.groups)
            output[1] += outputH2H

        # H->L
        if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != self.in_channels:
            outputH2L = F.conv2d(
                self.avgpool(inputHigh),
                self.weight[
                    :self.outChannelSplitIndex,
                    self.inChannelSplitIndex:,
                    :,
                    :],
                self.hl_bias,
                self.stride,
                self.padding,
                self.dilation,
                self.groups)
            output[0] += outputH2L

        # L->L
        if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != 0:
            outputL2L = F.conv2d(
                inputLow,
                self.weight[
                    :self.outChannelSplitIndex,
                    :self.inChannelSplitIndex,
                    :,
                    :],
                self.ll_bias,
                self.stride,
                self.padding,
                self.dilation,
                self.groups)
            output[0] += outputL2L

        # L->H
        if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != 0:
            outputL2H = F.conv2d(
                F.interpolate(inputLow, scale_factor=2),
                self.weight[
                    self.outChannelSplitIndex:,
                    :self.inChannelSplitIndex,
                    :,
                    :],
                self.lh_bias,
                self.stride,
                self.padding,
                self.dilation,
                self.groups)
            output[1] += outputL2H
        if isinstance(output[0],int):
            out = output[1]
        else:
            out = tuple(output)
        return out


if __name__ == "__main__":
    input = torch.randn(1, 3, 32, 32)
    octconv1 = OctConv2d(in_channels=3,
                         out_channels=6,
                         kernel_size=3,
                         stride=2,
                         padding=1,
                         dilation=1,
                         groups=1,
                         bias=True,
                         alpha_in=0.,
                         alpha_out=0.25)
    octconv2 = OctConv2d(in_channels=6,
                         out_channels=16,
                         kernel_size=3,
                         stride=1,
                         padding=1,
                         dilation=1,
                         groups=1,
                         bias=True,
                         alpha_in=0.25,
                         alpha_out=0.5)
    out = octconv1(input)
    print(len(out))
    print(out[0].shape)
    print(out[1].size())

    out = octconv2(out)
    print(len(out))
    print(out[0].size())
    print(out[1].size())

github地址

功力有限,还请各位多多包涵,多多指证。
参考文章:https://mp.weixin.qq.com/s?__biz=MzUyMjE2MTE0Mw==&mid=2247487810&idx=1&sn=1428510ec154a24a9e779d82f693930d&chksm=f9d14fdacea6c6cc42a630e57726c1789a54dc8e31616bd747fb2c35f41dbbd86f2c2a0b8998&mpshare=1&scene=23&srcid=#rd

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,602评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,442评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,878评论 0 344
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,306评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,330评论 5 373
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,071评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,382评论 3 400
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,006评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,512评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,965评论 2 325
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,094评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,732评论 4 323
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,283评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,286评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,512评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,536评论 2 354
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,828评论 2 345

推荐阅读更多精彩内容