交通场景、车道分割算法-SCNN

论文中提出了一个新颖的网络Spatial CNN,该网络在图片的行和列上做信息传递,可以有效的识别强先验结构的目标。同时论文提出了一个大型的车道检测数据集UCLane,用于进一步推动自动驾驶发展。

offical githubhttps://github.com/XingangPan/SCNN
paperSpatial As Deep: Spatial CNN for Traffic Scene Understanding

模型

本文提出的SCNN算法将传统的卷积层接层(layer-by-layer)的连接形式的转为feature map中片连片卷积(slice-by-slice)的形式,使得图中像素行和列之间能够传递信息。这特别适用于检测长距离连续形状的目标或大型目标,有着极强的空间关系但是外观线索较差的目标,例如交通线,电线杆和墙。

传统的CNN不能有效处理具有长距离连续的形状(尤其是在遮挡的情况下)。MRF/CRF+CNN的结构使用一个大卷积核来进行信息传递,但是会导致计算效率低下,并且大卷积核很难训练,如下图(a)所示;而SCNN分别在列方向与行方向使用宽卷积做了循环的信息传递,这样就增强了空间信息进而对于识别结构化对象特别有效,如下图(b)所示。


Model

D、U、R、L是四个信息传递模块。D、U沿着H方向做了从上到下和从下到上的信息传递;R、L沿着W方向做了从左到右和从右到左的信息传递。信息传递的公式如下所示,f是relu函数,每一个模块的卷积函数都共享同一个卷积核。
简单的举一个例子,假设x0h方向上的第一片特征,x1为第二片,那么x0x1的信息传递过程就是x1=x1+relu(conv2D(x0)),后面的操作就可以这样循环下去。这个操作类似于循环的残差操作,既能够加快计算效率又能传递长信息。

Message pass equation

在信息传递(Message Pass)过程中,MRF/CRF中每个像素点会直接接收其他所有像素点的信息(大卷积核实现),这其中有许多冗余计算;而SCNN在信息传递的时候并不是获取全局元素,而是顺序传递,由此简化了信息传递的结构加快了模型的运算效率,如下图所示:


Message pass

在进行车道检测时,在上述模型的基础上,在输出结果上添加了一个分支网络。这个分支网络能够直接区分不同车道标记,这样鲁棒性更好。共有4中类型的车道线。输出的概率图经过这个分支网络预测车道标记是否存在。
对于存在值大于0.5的车道标记,在对应的概率图每20行搜索以获得最高的响应位置,然后通过三次样条函数连接这些点(cubic splines),就得到了最终的预测。


road lane

这是该算法在车道分割上达成的效果,使用了UCLane数据库

road lane res

这是该算法在交通场景分割上达成的效果,使用了cityscapes数据库

traffic sense res

模型实现

这个实现与官方实现并不是完全一致,仅用来理解SCNN的网络结构,同时由于数据集较大也没有进行训练测试。想要训练使用这个模型可以下载官方的torch版本或者tf版本

首先我们将信息传递的过程封装成一个keras层,每个MessagePass层沿一个轴做两个方向的信息传递,如下所示:

class MessagePass(Layer):
    def __init__(self, output_dim,
                 axis,
                 **kwargs):
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(MessagePass, self).__init__(**kwargs)

        self.output_dim = output_dim
        self.axis = axis

    def build(self, input_shape):
        assert self.axis in [1, 2]
        assert input_shape[-1] == self.output_dim

        if self.axis == 1:
            kernel_shape = [1, 9, input_shape[-1], self.output_dim]
        if self.axis == 2:
            kernel_shape = [9, 1, input_shape[-1], self.output_dim]

        self.w1 = self.add_weight(name='one', 
                                 shape=kernel_shape,
                                 initializer='glorot_uniform',
                                 trainable=True)
        self.w2 = self.add_weight(name='two', 
                                 shape=kernel_shape,
                                 initializer='glorot_uniform',
                                 trainable=True)

        super(MessagePass, self).build(input_shape)

    def call(self, inputs, **kwargs):
        h, w = int(inputs.shape[1]), int(inputs.shape[2])
        
        if self.axis == 1:
            n = h
        if self.axis == 2:
            n = w

        feature_slice_old = []
        feature_slice_new = []

        for i in range(n):
            if self.axis == 1:
                cur_slice = K.expand_dims(inputs[:, i, :, :], axis=1)
            else:
                cur_slice = K.expand_dims(inputs[:, :, i, :], axis=2)
            feature_slice_old.append(cur_slice)

            if i == 0:
                feature_slice_new.append(cur_slice)
            else:
                tmp = K.relu(K.conv2d(feature_slice_old[i - 1], self.w1, padding='same'))
                tmp = tmp + feature_slice_old[i]
                feature_slice_new.append(tmp)

        feature_slice_old = feature_slice_new
        feature_slice_new = []

        for i in reversed(range(n)):
            if self.axis == 1:
                cur_slice = K.expand_dims(inputs[:, i, :, :], axis=1)
            else:
                cur_slice = K.expand_dims(inputs[:, :, i, :], axis=2)
            feature_slice_old.append(cur_slice)

            if i == (n - 1):
                feature_slice_new.append(cur_slice)
            else:
                tmp = K.relu(K.conv2d(feature_slice_old[i - 1], self.w2, padding='same'))
                tmp = tmp + feature_slice_old[i]
                feature_slice_new.append(tmp)

        output = K.stack(feature_slice_new, axis=self.axis)
        output = K.squeeze(output, axis=self.axis + 1)

        return output

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1], input_shape[2], self.output_dim)

我们选择DenseNet 121作为backbone网络,选取8倍缩小的那一个特征图作为输出特征,然后为这个特征接上信息传递层,如下所示:

class SCNN:
    def __init__(self, height, width, classes=5):
        self.classes = classes
        self.height = height
        self.width = width

    def backbone(self):
        model = DenseNet121(
                input_shape=(self.height, self.width, 3),
                weights=None, 
                include_top=False)

        out_conv = model.get_layer('pool3_conv').output

        return model.input, out_conv

    def build(self):
        inputs, conv_out = self.backbone()

        conv_out = Conv2D(128, (1, 1), padding='same')(conv_out)
        conv_out = BatchNormalization()(conv_out)
        conv_out = Activation('relu')(conv_out)

        conv_out = MessagePass(128, 1)(conv_out)
        conv_out = MessagePass(128, 2)(conv_out)

        conv_out = Conv2D(self.classes, (1, 1), activation='softmax', padding='same')(conv_out)
        prob_output = UpSampling2D((8, 8))(conv_out)

        # add lane existence prediction branch
        x = AveragePooling2D(strides=2)(conv_out)
        x = Flatten()(x)
        x = Dense(128, activation='relu')(x)
        existence_output = Dense(4, activation='sigmoid')(x)

        model = Model(inputs=inputs, outputs=[prob_output, existence_output])

        opt = SGD(lr=0.01, momentum=0.9, decay=0.0001)
        model.compile(
                optimizer=opt,
                loss=['categorical_crossentropy', 'binary_crossentropy'])

        return model


if __name__ == '__main__':
    model = SCNN(288, 800).build()
    print(model.summary())
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,362评论 5 477
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,330评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,247评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,560评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,580评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,569评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,929评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,587评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,840评论 1 297
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,596评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,678评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,366评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,945评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,929评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,165评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 43,271评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,403评论 2 342

推荐阅读更多精彩内容