vectornet代码复现

数据:
先来看下丢到模型里面的x,下面是直接将x当作散点图可视化,每个polyline用不同的颜色表示,红线是需要预测的agent的历史轨迹


x可视化

下面是官方的api可视化


image.png

模型结构:

class HGNN(nn.Module):
    def forward(self, data):
        time_step_len = int(data[0].time_step_len[0]) #83
        valid_lens = data[0].valid_len # 78
        sub_graph_out = self.subgraph(data)
        x = sub_graph_out.x.view(-1, time_step_len, self.polyline_vec_shape) 
        out = self.self_atten_layer(x, valid_lens)
        pred = self.traj_pred_mlp(out[:, [0]].squeeze(1))
        return pred

核心代码就四行:

1. sub_graph_out = self.subgraph(data)

2. x = sub_graph_out.x.view(-1, time_step_len, self.polyline_vec_shape)

3. out = self.self_atten_layer(x, valid_lens)

4. pred = self.traj_pred_mlp(out[:, [0]].squeeze(1))

首先看1
subGraph的forward如下

class SubGraph(nn.Module):
    """
    Subgraph that computes all vectors in a polyline, and get a polyline-level feature
    """
    def __init__(self, in_channels, num_subgraph_layres=3, hidden_unit=64):
        super(SubGraph, self).__init__()
        self.num_subgraph_layres = num_subgraph_layres
        self.layer_seq = nn.Sequential()
        for i in range(num_subgraph_layres):
            self.layer_seq.add_module(
                f'glp_{i}', GraphLayerProp(in_channels, hidden_unit))
            in_channels *= 2

    def forward(self, sub_data):
        x, edge_index = sub_data.x, sub_data.edge_index # x 8310,8 edge_index 2,66852
        for name, layer in self.layer_seq.named_modules():
            if isinstance(layer, GraphLayerProp):
                x = layer(x, edge_index)
        sub_data.x = x # 8310,64
        out_data = max_pool(sub_data.cluster, sub_data) # 1162,64
        assert out_data.x.shape[0] % int(sub_data.time_step_len[0]) == 0
        out_data.x = out_data.x / out_data.x.norm(dim=0)
        return out_data

subgraph的核心代码有三步

1.1

 for name, layer in self.layer_seq.named_modules():
            if isinstance(layer, GraphLayerProp):
                x = layer(x, edge_index)

1.2 out_data = max_pool(sub_data.cluster, sub_data)

1.3 out_data.x = out_data.x / out_data.x.norm(dim=0)

先来看1.1
subgraph的forward中首先过了三层GraphLayerProp

for name, layer in self.layer_seq.named_modules():
            if isinstance(layer, GraphLayerProp):
                x = layer(x, edge_index)

self.layer_seq.named_modules()如下:

(glp_0): GraphLayerProp(
    (mlp): Sequential(
      (0): Linear(in_features=8, out_features=64, bias=True)
      (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Linear(in_features=64, out_features=8, bias=True)
    )
  )
  (glp_1): GraphLayerProp(
    (mlp): Sequential(
      (0): Linear(in_features=16, out_features=64, bias=True)
      (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Linear(in_features=64, out_features=16, bias=True)
    )
  )
  (glp_2): GraphLayerProp(
    (mlp): Sequential(
      (0): Linear(in_features=32, out_features=64, bias=True)
      (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Linear(in_features=64, out_features=32, bias=True)
    )
  )

但是我们发现(3)linear的out_features 不等于下一层的in_features
因为(3)linear后面还有个contact的操作(具体看GraphLayerProp里面的update),让out_features翻倍了,实际上应该是:
(8310,8)-> (8310,16)
(8310,16)-> (8310,32)
(8310,32)-> (8310,64)
现在咱们来具体看下GraphLayerProp

class GraphLayerProp(MessagePassing):
    """
    Message Passing mechanism for infomation aggregation
    """
    def __init__(self, in_channels, hidden_unit=64, verbose=False):
        super(GraphLayerProp, self).__init__(
            aggr='max')  # MaxPooling aggragation
        self.verbose = verbose
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_unit),
            nn.LayerNorm(hidden_unit),
            nn.ReLU(),
            nn.Linear(hidden_unit, in_channels)
        )

    def forward(self, x, edge_index):
        if self.verbose:
            print(f'x before mlp: {x}')
        x = self.mlp(x)
        if self.verbose:
            print(f"x after mlp: {x}")
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        return x_j

    def update(self, aggr_out, x):
        if self.verbose:
            print(f"x after mlp: {x}")
            print(f"aggr_out: {aggr_out}")
        return torch.cat([x, aggr_out], dim=1)

GraphLayerProp中主要有三步:

1.1.1 encoder

1.1.2 aggregate

1.1.3 contact

subgraph

结合图片来看:

1.1.1 encoder:

forward中x = self.mlp(x) 先对feature做一次mlp ,即x :(8310,8) -> (8310,64) -> x (8310,8)

x = self.mlp(x)

1.1.2 aggregate:

做一次max的gnn 的aggregate

super(GraphLayerProp, self).__init__(
            aggr='max')  # MaxPooling aggragation

1.1.3 contact:

将max出来的feature 和 max前的feature 做一次concat ,所以feature维度在这翻倍

torch.cat([x, aggr_out], dim=1) 

上述1.1.1-1.1.3是一层GraphLayerProp,subgraph的forward中过了三层,即:
(8310,8)-> (8310,16)
(8310,16)-> (8310,32)
(8310,32)-> (8310,64)

现在过完三次GraphLayerProp,x : (8310,64)

1.2 out_data = max_pool(sub_data.cluster, sub_data) # 1162,64

回到1.2:对每个polyline subgraph做maxpooling

sub_data.cluster 里面类似[0,0,0,0,1,1,1,1,2,2,2,3,3....1161,1161]
这里面0000,1111,222分别是不同id的车道线、车辆等的子图,即论文中的polyline subgraphs

例如:
0,0,0,0表示id为0的子图有四个时间刻

现在将每个物体抽象成了一个64维向量,即,将所有时间刻的向量池化为一个时间刻的向量

做maxpooling 后x:(1162,64)= (14*83 ,64)
即有14个场景中,每个场景83个车道和车辆单一时刻的vector

polyline subgraphs

1.3 out_data.x = out_data.x / out_data.x.norm(dim=0)

除以均值

2 x = sub_graph_out.x.view(-1, time_step_len, self.polyline_vec_shape)

接下来reshape一下

time_step_len = 83 (83包含了1个agent,41个左车道线和41个右车道线)

x(1162,64) -> x(14,83,64)

这里14表示有14个预测场景,每个场景有83个polyline,每个polyline的feature是64维的向量

3 out = self.self_atten_layer(x, valid_lens) #14,83,64

通过self attention计算每个polyline直接的注意力,再aggregate一下。

self_atten_layer的初始化:

self.self_atten_layer = SelfAttentionLayer(
            self.polyline_vec_shape,
            global_graph_width, 
            need_scale=False) #64  64
self attention
def forward(self, x, valid_len):
        query = self.q_lin(x) # 14,83,64 
        key = self.k_lin(x)
        value = self.v_lin(x)
        scores = torch.bmm(query, key.transpose(1, 2)) # 14,83,83
        attention_weights = masked_softmax(scores, valid_len)
        return torch.bmm(attention_weights, value)

4 pred = self.traj_pred_mlp(out[:, [0]].squeeze(1)) #14,60

traj_pred_mlp的初始化

self.traj_pred_mlp = TrajPredMLP(
            global_graph_width, out_channels, traj_pred_mlp_width) # 64 60 64

最后一步直接把(14,83,64) -> (14,60)
60的向量由30个x坐标值和30个y坐标值组成,即预测的后30个时间片的轨迹坐标

class TrajPredMLP(nn.Module):
    """Predict one feature trajectory, in offset format"""

    def __init__(self, in_channels, out_channels, hidden_unit):
        super(TrajPredMLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_unit),
            nn.LayerNorm(hidden_unit),
            nn.ReLU(),
            nn.Linear(hidden_unit, out_channels)
        )

    def forward(self, x):
        return self.mlp(x)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,793评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,567评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,342评论 0 338
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,825评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,814评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,680评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,033评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,687评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 42,175评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,668评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,775评论 1 332
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,419评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,020评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,978评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,206评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,092评论 2 351
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,510评论 2 343

推荐阅读更多精彩内容