import torch
import torch.nn as nn
lstm = nn.LSTM(
input_size=10,
hidden_size=20,
num_layers=1,
batch_first=True
)
input = torch.randn(3, 5, 10) # batch_size=3, seq_len=5, num_features=10
h0 = torch.randn(1, 3, 20)
c0 = torch.randn(1, 3, 20)
output, (h, c) = lstm(input, (h0, c0))
"""
h和c都是三维张量,其中第一维度表示该LSTM层的层数num_layers,默认为1
output是三维张量
output[:, -1, :] 与 h[-1, :, :]是一样的
当多个LSTM层叠加时,它们之间的数据传递用每一层的output
最后一个LSTM层与全连接层相连时,采用最后一层的h[-1, :, :]作为全连接层的输入
"""
torch中LSTM层的理解与记录
©著作权归作者所有,转载或内容合作请联系作者
- 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
- 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
- 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
推荐阅读更多精彩内容
- 一次上《记梁任公先生的一次演讲》这课时,有同学问我,为什么梁启超讲到《桃花扇》会痛哭流涕而不能自已。那今天我就这个...
- 这8种学生永远拿不到高分!早看早受益! 下面是一位资深班主任总结了8种成绩提不上去的原因,分别对应8类孩子,如果你...