import torch
from torch import nn
import torch.nn.functional as f
from torch.autograd import Variable
# Convolution geometry shared by every cell: an odd kernel combined with
# padding of KERNEL_SIZE // 2 keeps the gate maps at the input's spatial size,
# which the recurrence needs so that the states can be reused step after step.
KERNEL_SIZE = 3
PADDING = KERNEL_SIZE // 2


class ConvLSTMCell(nn.Module):
    """
    Convolutional LSTM cell.

    A single 2-D convolution over the channel-wise concatenation of the input
    and the previous hidden state produces the four gate pre-activations
    (input, remember/forget, output, candidate cell).

    Args:
        input_size: number of channels of the input tensor.
        hidden_size: number of channels of the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super(ConvLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # One convolution emits all four gates stacked along the channel axis.
        self.Gates = nn.Conv2d(
            input_size + hidden_size,
            4 * hidden_size,
            KERNEL_SIZE,
            padding=PADDING,
        )

    def forward(self, input_, prev_state=None):
        """
        Advance the cell by one time step.

        Args:
            input_: tensor laid out as [batch, channel, height, width].
            prev_state: ``(hidden, cell)`` tuple from the previous step, or
                ``None`` to start from an all-zero state.

        Returns:
            Tuple ``(hidden, cell)`` holding the new hidden and cell states,
            each with ``hidden_size`` channels.
        """
        # get batch and spatial sizes
        batch_size = input_.size(0)
        spatial_size = input_.size()[2:]

        # generate empty prev_state, if None is provided
        if prev_state is None:
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            # new_zeros inherits dtype and device from the input, so the cell
            # also works on GPU / non-float32 inputs without a device or
            # dtype mismatch in the concatenation below.
            prev_state = (
                input_.new_zeros(state_size),
                input_.new_zeros(state_size),
            )

        prev_hidden, prev_cell = prev_state

        # data size is [batch, channel, height, width]
        stacked_inputs = torch.cat((input_, prev_hidden), 1)
        gates = self.Gates(stacked_inputs)

        # chunk across channel dimension
        in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)

        # sigmoid squashes the three multiplicative gates into (0, 1)
        in_gate = torch.sigmoid(in_gate)
        remember_gate = torch.sigmoid(remember_gate)
        out_gate = torch.sigmoid(out_gate)

        # tanh bounds the candidate cell values to (-1, 1)
        cell_gate = torch.tanh(cell_gate)

        # compute current cell and hidden state
        cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
        hidden = out_gate * torch.tanh(cell)

        return hidden, cell
def _main():
    """
    Run some basic tests on the API.

    Builds a ConvLSTMCell, unrolls it over a random input sequence, and fits
    it to random targets with plain gradient descent, printing the summed MSE
    loss of every epoch followed by the tensor sizes involved.
    """
    # define batch_size, channels, height, width
    b, c, h, w = 1, 3, 4, 8
    d = 5           # hidden state size
    lr = 1e-1       # learning rate
    T = 6           # sequence length
    max_epoch = 20  # number of epochs

    # set manual seed
    torch.manual_seed(0)

    print('Instantiate model')
    model = ConvLSTMCell(c, d)
    print(repr(model))

    print('Create input and target Variables')
    # Plain tensors are enough: autograd tracks them without a Variable wrapper.
    x = torch.rand(T, b, c, h, w)
    y = torch.randn(T, b, d, h, w)

    print('Create a MSE criterion')
    loss_fn = nn.MSELoss()

    print('Run for', max_epoch, 'iterations')
    for epoch in range(max_epoch):
        state = None
        loss = 0
        # Unroll the cell over the leading (time) dimension of x and y.
        for x_t, y_t in zip(x, y):
            state = model(x_t, state)
            loss += loss_fn(state[0], y_t)

        # The loss is a 0-dim tensor; .item() extracts the Python float
        # (indexing it with [0] raises IndexError on current PyTorch).
        print(' > Epoch {:2d} loss: {:.3f}'.format((epoch + 1), loss.item()))

        # zero grad parameters
        model.zero_grad()

        # compute new grad parameters through time!
        loss.backward()

        # learning_rate step against the gradient; done under no_grad so the
        # in-place update is not itself recorded by autograd
        with torch.no_grad():
            for p in model.parameters():
                p.sub_(p.grad * lr)

    print('Input size:', list(x.size()))
    print('Target size:', list(y.size()))
    print('Last hidden state size:', list(state[0].size()))


# Only run the smoke test when executed as a script, not on import.
if __name__ == '__main__':
    _main()
PyTorch实现卷积LSTM核(ConvLSTMCell)
最后编辑于 :
©著作权归作者所有,转载或内容合作请联系作者
- 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
- 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
- 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
推荐阅读更多精彩内容
- CNN从2012年的AlexNet发展至今,科学家们发明出各种各样的CNN模型,一个比一个深,一个比一个准确,一个...
- 姓名:周雪宁 学号:1702110196 转载:https://mp.weixin.qq.com/s/4-9SHF...
- 1×N 和 N×1 的卷积核主要是为了发现宽的特征和高的特征 1×1的卷积核,因为在实验中发现很多特征并没有激发状...
- CNN神经网络算法是常用的模式识别算法,该算法通过卷积运算将图片特征存储到多个卷积核中,卷积核通过算法的反向传输一...
- Question: 从 NIN 到 GoogLeNet、MSRA Net 都是用了这个,为什么呢? 发现很多网络使用了...