Dynamic Routing Between Capsules

Sabour S, Frosst N, Hinton G E, et al. Dynamic Routing Between Capsules[C]. neural information processing systems, 2017: 3856-3866.

虽然11年就提出了capsule的概念, 但是走入人们视线的应该还是这篇文章吧. 虽然现阶段, capsule没有体现出什么优势. 不过, capsule相较于传统的CNN融入了很多先验知识, 更能够拟合人类的视觉系统(我不知), 或许有一天它会大放异彩.

主要内容

在这里插入图片描述

直接从这个结构图讲起吧.

  1. Input: 1 x 28 x 28 的图片 经过 9 x 9的卷积核(stride=1, padding=0, out_channels=256)作用;
  2. 256 x 20 x 20的特征图, 经过primarycaps作用(9 x 9 的卷积核(strde=2, padding=0, out_channels=256);
  3. (32 x 8) x 6 x 6的特征图, 理解为32 x 6 x 6 x 8 = 1152 x 8, 即1152个胶囊, 每个胶囊由一个8D的向量表示u_{i}; (这个地方要不要squash, 大部分实现都是要的.)
  4. 接下来digitcaps中有10个caps(对应10个类别), 1152caps和10个caps一一对应, 分别用i, j表示, 前一层的caps为后一层提供输入, 输入为
    \hat{u}_{j|i} = W_{ij}u_i,
    可见, 应当有1152 x 10个W_{ij}\in \mathbb{R}^{16\times 8}, 其中16是输出胶囊的维度. 最后10个caps的输出为
    s_j= \sum_{i}c_{ij}\hat{u}_{j|i}, v_j= \frac{\|s\|_j^2}{1 + \|s_j\|^2} \frac{s_j}{\|s_j\|}.

其中c_{ij}是通过一个路由算法决定的, v_j, 即最后的输入如此定义是出于一种直觉, 即保持原始输出(s)的方向, 同时让v的长度表示一个概率(这一步称为squash).

首先初始化b_{ij}=0 (这里在程序实现的时候有一个考量, 是每一次都要初始化吗, 我看大部分的实现都是如此的).

在这里插入图片描述

上面的Eq.3就是
\tag{3} c_{ij}=\frac{\exp(b_{ij})}{\sum_{k}\exp(b_{ik})}.

另外\hat{\mu}_{j|i} \cdot v_j=\hat{\mu}_{j|i}^Tv_j是一种cos相似度度量.

损失函数

损失函数采用的是margin loss:
\tag{4} L_k = T_k \max(0, m^+ - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\|-m^-)^2.

m^+, m^-通常取0.9和0.1, \lambda通常取0.5.

代码

我的代码, 在sgd下可以训练(但是准确率只有98), 在adam下就死翘翘了, 所以代码肯定是有问题, 但是我实在是找不出来了, 这里有很多实现的汇总.



"""
Sabour S., Frosst N., Hinton G. Dynamic Routing Between Capsules.
Neural Information Processing Systems, pp. 3856-3866, 2017.
https://arxiv.org/pdf/1710.09829.pdf
The implement below refers to https://github.com/adambielski/CapsNet-pytorch.
"""


import torch
import torch.nn as nn
import torch.nn.functional as F



def squash(s):
    temp = s.norm(dim=-1, keepdim=True)
    return (temp / (1. + temp ** 2)) * s


class PrimaryCaps(nn.Module):

    def __init__(
        self, in_channel, out_entities, 
        out_dims, kernel_size, stride, padding
    ):
        super(PrimaryCaps, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_entities * out_dims, 
                            kernel_size, stride, padding)
        self.out_entities = out_entities
        self.out_dims = out_dims

    def forward(self, inputs):
        conv_outs = self.conv(inputs).permute(0, 2, 3, 1).contiguous()
        outs = conv_outs.view(conv_outs.size(0), -1, self.out_dims)
        return squash(outs)


class AgreeRouting(nn.Module):

    def __init__(self, in_caps, out_caps, out_dims, iterations=3):
        super(AgreeRouting, self).__init__()

        self.in_caps = in_caps
        self.out_caps = out_caps
        self.out_dims = out_dims
        self.iterations = iterations

    @staticmethod
    def softmax(inputs, dim=-1):
        return F.softmax(inputs, dim=dim)

    def forward(self, inputs):
        # inputs N x in_caps x out_caps x out_dims
        b = torch.zeros(inputs.size(0), self.in_caps, self.out_caps).to(inputs.device)
        for r in range(self.iterations):
            c = self.softmax(b) # N x in_caps x out_caps !!!!!!!!!
            s = (c.unsqueeze(-1) * inputs).sum(dim=1) # N x out_caps x out_dims
            v = squash(s) # N x out_caps x out_dims
            b = b + (v.unsqueeze(dim=1) * inputs).sum(dim=-1)
        return v



class CapsLayer(nn.Module):

    def __init__(self, in_caps, in_dims, out_caps, out_dims, routing):
        super(CapsLayer, self).__init__()
        self.in_caps = in_caps
        self.in_dims = in_dims
        self.routing = routing
        self.weights = nn.Parameter(torch.rand(in_caps, out_caps, in_dims, out_dims))
        nn.init.kaiming_uniform_(self.weights)

    def forward(self, inputs):
        # inputs: N x in_caps x in_dims
        inputs = inputs.view(inputs.size(0), self.in_caps, 1, 1, self.in_dims)
        u_pres = (inputs @ self.weights).squeeze() # N x in_caps x out_caps x out_dims
        outs = self.routing(u_pres) # N x out_caps x out_dims

        return outs




class CapsNet(nn.Module):

    def __init__(self):
        super(CapsNet, self).__init__()

        # N x 1 x 28 x 28
        self.conv = nn.Conv2d(1, 256, 9, 1, padding=0) # N x (32 * 8) x 20 x 20
        self.primarycaps = PrimaryCaps(256, 32, 8, 9, 2, 0) # N x (6 x 6 x 32) x 8
        routing = AgreeRouting(32 * 6 * 6, 10, 8, 3)
        self.digitlayer = CapsLayer(32 * 6 * 6, 8, 10, 16, routing)


    def forward(self, inputs):
        conv_outs = F.relu(self.conv(inputs))
        pri_outs = self.primarycaps(conv_outs)
        outs = self.digitlayer(pri_outs)
        probs = outs.norm(dim=-1)
        return probs
        


if __name__ == "__main__":

    x = torch.randn(4, 1, 28 ,28)
    capsnet = CapsNet()
    print(capsnet(x))


def margin_loss(logits, labels, m=0.9, leverage=0.5, adverage=True):
    # outs: N x num_classes x dim
    # labels: N
    temp1 = F.relu(m - logits) ** 2
    temp2 = F.relu(logits + m - 1) ** 2
    T = F.one_hot(labels.long(), logits.size(-1))
    loss = (temp1 * T + leverage * temp2 * (1 - T)).sum()
    if adverage:
        loss = loss / logits.size(0)
    # Another implement is using scatter_
    # T = torch.zero(logits.size()).long()
    # T.scatter_(dim=1, index=labels.view(-1, 1), 1.).cuda() if cuda()
    return loss

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,378评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,356评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,702评论 0 342
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,259评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,263评论 5 371
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,036评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,349评论 3 400
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,979评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,469评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,938评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,059评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,703评论 4 323
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,257评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,262评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,485评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,501评论 2 354
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,792评论 2 345