Pytorch学习记录- 使用神经网络训练Seq2Seq(论文再读)

对Pytorch的Seq2Seq这6篇论文进行精读,今天重新开始,第一篇,《Sequence to Sequence Learning with Neural Networks》Sutskever, I., O. Vinyals and Q.V. Le, Sequence to Sequence Learning with Neural Networks. 2014.
Google发表于2014年,全文链接

摘要

虽然DNN很牛逼,但是仍然无法完成从句子到句子的映射。
这篇论文提出一个通用端到端学习方法,对序列结构做出最小假设。

  • 结构,使用LSTM将输入句子映射为一个固定维度(fixed dimensionality / fixed-sized)向量。使用另一个LSTM对向量进行解码。
  • 结果,使用WMT-14数据集的英-法翻译任务
  • 模型在长句上没有遇到困难
    最后发现,在源句中颠倒单词顺序,能够提高LSTM的成绩,因为这种操作会在源语句和目标句子之间引入许多短依赖关系,这使得优化问题更容易(似乎现在NLP的GAN有一项就是调整语序)。

1. 介绍

略过

2. 模型

输入句子X \in {x_1,x_2,...,x_T}
输出句子y \in {y_1,y_2,...,y_T}
标准RNN可以通过迭代下面的公式来计算输出序列(y_1,...,y_T)
h_t=sigm(W^{hx}x_t+W^{hh}h_{t-1})
y_t=W^{yh}h_t

但是问题来了,如何应对输入和输入长度不一样的句子?并且句子具有复杂的关系
一个解决方法:使用RNN将输入句子映射到一个固定长度的向量中,然后使用另一个RNN将向量映射到目标句子。但是这样的模型在长文本中进行训练是困难的,幸好LSTM出现了。

LSTM的目标就是估算条件概率p(y_1,...y_T'|x_1,...,x_T),其中输入句子X \in {x_1,x_2,...,x_T},输出句子y \in {y_1,y_2,...,y_T'}为什么是y_T',是因为输出和输入的长度可能不一样。
p(y_1,...y_T'|x_1,...,x_T)=\prod ^{T'}_{t=1}p(y_t|v,y_1,...,y_{t-1})
在方程中,每个p分布都是使用softmax处理词汇中所有单词结果来表示,每个句子都有一个<EOS>的结束符,这样就能确定句子的长度。
实际模型在三个重要方面与上述描述不同。

  • 首先,使用了两个不同的LSTM:一个用于输入序列,另一个用于输出序列。
  • 其次,发现深LSTM明显优于浅LSTM,因此选择了具有四层的LSTM。
  • 第三,发现扭转输入句子的单词顺序是非常有价值的。

3. 模型的实现

image.png

这张图是流程图,输入德语“guten morgen”,在绿色的encoder中被编码为一个一个词,在句首和句尾增加作为标签。

  • 每一个时间步,encoder的输入是当前单词x_t和上一时间步的隐藏状态h_{t-1}
  • 每一个时间步,encoder的输出是新的隐藏状态h_t
    可以将隐藏状态当成表示句子的向量。这样公式就出来了。
    h_t=EncoderRNN(x_t,h_{t-1})
    这里的RNN可以是任何卷积结构(LSTM或是GRU)。
    当输入句子最后一个单词传入RNN后,这时的隐藏状态h_T就是上下文向量,在这里表示为h_T=z就是示意图中中间的那个z。
    有了向量z,可以开始对目标句子进行解码,生成目标语言的句子。这样decoder的公式也有了。
    s_t = \text{DecoderRNN}(y_t, s_{t-1})
    在decoder中,我们从隐藏状态转到实际单词,每一个时间步都使用s_t来进行预测
    \hat{y} _t

4. 模型代码

4.1 引入相关库

import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.datasets import TranslationDataset, Multi30k
from torchtext.data import Field, BucketIterator
import spacy
import random
import math
import time
# 设定SEED,让之后随机数生成一致
SEED=1234
random.seed(SEED)
torch.manual_seed(SEED)
# torch.backends.cudnn.benchmark = True 在程序刚开始加这条语句可以提升一点训练速度,没什么额外开销。
torch.backends.cudnn.deterministic=True
spacy_de=spacy.load('de')
spacy_en=spacy.load('en')
def tokenize_de(text):
    # 使用[::-1]将文本进行倒序排列
    return [tok.text for tok in spacy_de.tokenizer(text)][::-1]
def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]

SRC=Field(
    tokenize=tokenize_de,
    init_token='<sos>',
    eos_token='<eos>',
    lower=True
)
TRG=Field(
    tokenize=tokenize_en,
    init_token='<sos>',
    eos_token='<eos>',
    lower=True
)

train_data, valid_data, test_data=Multi30k.splits(exts=('.de','.en'),fields=(SRC,TRG))
print(f"Number of training examples: {len(train_data.examples)}")
print(f"Number of validation examples: {len(valid_data.examples)}")
print(f"Number of testing examples: {len(test_data.examples)}")
print(vars(train_data.examples[1]))
SRC.build_vocab(train_data,min_freq=2)
TRG.build_vocab(train_data,min_freq=2)
print(f"Unique tokens in source (de) vocabulary: {len(SRC.vocab)}")
print(f"Unique tokens in target (en) vocabulary: {len(TRG.vocab)}")
Number of training examples: 29000
Number of validation examples: 1014
Number of testing examples: 1000
{'src': ['.', 'antriebsradsystem', 'ein', 'bedienen', 'schutzhelmen', 'mit', 'männer', 'mehrere'], 'trg': ['several', 'men', 'in', 'hard', 'hats', 'are', 'operating', 'a', 'giant', 'pulley', 'system', '.']}
Unique tokens in source (de) vocabulary: 7855
Unique tokens in target (en) vocabulary: 5893
device=torch.device('cuda' if torch.cuda.is_available else 'cpu')
print(device)
BATCH_SIZE=128
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE, 
    device = device)

4.2 构建模型

在前面完成对数据的处理后,现在开始构建模型,我们之前也这么做了,按照教程一步步走下来,证明是可以的。
但是,如果要你再写一遍的话,你会发现依旧写不出来,问题在哪里?
我个人感觉是对模型吃的不透,包括传入数据的结构和数据的处理流程。所以,这里选择这个只有三个模块的seq2seq来研究。
模型包括三个部分,encoder、decoder和seq2seq(整合部分)。seq2seq的操作流程:

  • 使用RNN(LSTM/GRU)对输入的句子(源语)进行编码,生成独立向量
  • 独立向量就是上下文向量,可以把这个上下文向量作为输入句子的抽象表示
  • 由第二个RNN(LSTM/GRU)对独立向量进行解码,通过一次生成一个字来学习输出目标句子

实现也会分成三个模块(encoder、decoder、seq2seq)来实现。在之前我们都会按照encoder->decoder->seq2seq的顺序来做,这样复合从具体到抽象的逻辑,但是我个人感觉搞到最后seq2seq的时候一头雾水,对输入的数据结构不了解。
这次换一下,从模型训练参数->seq2seq->encoder->decoder,我们看看搞进去的数据是什么样子。

4.2.1 模型配置

INPUT_DIM = len(SRC.vocab) # 模型输入维度,输入encoder的one-hot向量维度,就是根据源语数据集搞出来的词汇表中单词个数
# print(len(SRC.vocab))
OUTPUT_DIM = len(TRG.vocab) # 模型输出维度,输入到Decoder的one-hot向量,就是根据目标语数据集搞出来的词汇表单词个数
# print(len(TRG.vocab))

ENC_EMB_DIM = 256 # encoder的嵌入层维度,将one-hot向量转为密度向量
DEC_EMB_DIM = 256 # decoder的嵌入层维度,将one-hot向量转为密度向量

HID_DIM = 512 # 隐藏层和cell状态维度
N_LAYERS = 4 # 搞一个四层的
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

在搞定数据输入之后我们可以来看一下这次处理的数据样式。使用enumerate函数来枚举train_data和train_iterator。

对于一个可迭代的(iterable)/可遍历的对象(如列表、字符串),enumerate将其组成一个索引序列,利用它可以同时获得索引和值。enumerate多用于在for循环中得到计数

可以看到train_data和train_iterator中所包含的数据是两种类型(list和tensor)
在之前的处理中,我们使用Multi30k的splits对SRC和TRG进行了训练、测试、验证集划分处理,生成train_data等三个数据集。
使用BucketIterator.splits对train_data进行处理

BucketIterator是torchtext最强大的功能之一。它会自动将输入序列进行shuffle并做bucket。
这个功能强大的原因是——正如我前面提到的——我们需要填充输入序列使得长度相同才能批处理。

这里介绍一个小技巧,在anaconda中,经常会因为打印信息长度问题只保留头尾,可以使用set_printoptions方法,这个方法在pandas、numpy、torch中都有。
tensor可以使用shape来查看tensor的行数和列数,这里是28行、128列,也就是说batch_size就是列数,目标句子长度max_len就是行数

torch.set_printoptions(threshold = 1e6)
for i , batch in enumerate(train_data):
    if i <1:
        print(i)
        src=batch.src
        trg=batch.trg
        print(type(src))
        print(src)
#         print(trg)
    else: break
for i , batch in enumerate(train_iterator):
    if i <1:
        print(i)
        src=batch.src
        trg=batch.trg
        print(type(src))
        print(src.shape)
        print(src)
        print(src.shape[0])
        print(src.shape[1])
#         print(trg)
    else: break
0
<class 'list'>
['.', 'büsche', 'vieler', 'nähe', 'der', 'in', 'freien', 'im', 'sind', 'männer', 'weiße', 'junge', 'zwei']
0
<class 'torch.Tensor'>
torch.Size([29, 128])
tensor([[   2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2,    2,    2,    2],
        [   4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    4,   29,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    4],
        [   0, 3810,   65,  344,  762,   90, 4062,  368,  681,    0, 5343, 2391,
          118,   92,  507, 2615,   63,   34,   21, 1522,    0, 1640,  715,   63,
          837, 2779,  117,   72,  647,   23,    0,  215,   90,   21,  344,   54,
         3034,    0,   29, 7436,    5,   48,  299,  638,  297, 1114,  235,  933,
         1692, 6879,   48, 3367,  693,  122,  476,  235, 1534,   72,  286,    0,
          375,  301,   81,  141, 4703, 1716, 1803, 7475, 1101,    0,  271,   60,
           84,   72, 1674,   29,    0,  181,  113,  123, 1188,   27,  581, 3386,
         3119, 4317,  328,  228,  186, 1089, 1921,   90,   21,    0,   11, 6836,
          547, 4583,   80,   90, 5438,  916, 4559,  481,   21,   21,  932, 3389,
         2248, 2044,  233, 3826,  740,  230,    0,  389,   80,  343,  420,   34,
         1628,  316, 2475,   12,  129,   58, 1738,   63],
        [  82,  189,  105,    0,  197,   17,  102,   33,  248,  282,  149,    8,
           20,  525,   14,    6,  126, 1576,  379,   14,    8,    6,  118,  254,
         2700,   19, 2690,   34,    5,  839,   18,   19,   17, 6882, 7429, 1049,
          419,  168,  629,  529,  248, 2316,    5,   14,   86,  151,   39, 1462,
           11, 1956,  144,    8,  116,   24,   10,  304, 4360,   34,    5,    0,
          139,   24,   15,    6, 1217,  334, 1020,  555,   75,    7,   61,  117,
          144, 1488,   17, 2611,   34,   67,   14,   20,   11,  882,   12,   28,
            8,  506,  638, 1947,  405,   46,   19,   17, 1957,   19,  352, 4268,
           55, 2151,   22,   17,   19,    8,  139, 1900,  498,  679,   77,    6,
         7691,    6,   21,   19,    0,   14,   10,   17, 1841,  178,  726,   14,
         1452,   14,  423, 4963,   38, 1649, 2151,  126],
        [   6,   82,  428,   22,   11,    7,   19,    7,    5, 2527, 3704,   42,
         1102,   59,    7,   12,  314,   14,    6,    7,   42,   12,   20,  562,
         1995,   61,  574,   17,    7, 6937,   61,   61,   58,    0,   22,   28,
           14, 2442,   27,   37,    5, 1116,   37,   21,   20,  547,   42,    6,
          295,    6,   28,   12,    5,   12,  107,  110,   19,   17,  184,   14,
            0,   12,    7,    7,  151,   64,  919,    8,  553,    0,   87,  176,
           28,    8,   12, 1142,   17,  330,   22,   73,   13,   19,   53, 3950,
           75, 2254,   14,    8,    6,   19,  608,    7,    8,   68,   44,  172,
         1704,    5,   36,    7,   85,   85,  260,    7,  170,   19,   42,   21,
            5,   21, 3005,  238,    5,    7, 3936,   85,   43,    6,   14,    7,
            6,   12,   15,    5,   10,  170,    6,  115],
        [  12,   14,  658,    9,  293, 4525,    7,  146,   15,   19, 5272, 4891,
           19,   38,  143,   86,    6,   12,   21, 1476,   74,    0,   57, 3439,
            6,  138,    6,    0,  183,   11,    9,    9,  433,   10,  108,  183,
           27, 1888,  200,  520,   58,   39,  300,    0,  116,  304,  347,   22,
            6,   11, 1668,  293,   27, 3404,    6,   15,   15,   60,   10,    7,
           10,  164, 7504,   23, 1939,   17,    7,  415,   14,   14,   24,   98,
         4148,  124,   13,  248,   99,    8, 2224,   26, 1184,  534,   10,    0,
           85,   14,   21,   91,   12,    7,  435,  119,   16,   10, 1252,   41,
           29,   12,   14, 2072, 2110,  823,   20,   31,  627,  335, 1564,  183,
           17,  159,   33,   40,  166,   38,  101,  619,   11,   12,   21, 1191,
          126,  729,   22,  136,  140,   10,   12,    8],
        [  53,   11,    6,  154,   19,  139,  856,    9,    9,  728,   15,  127,
           68,   13, 2853,   20,    7,   29,    0,  127,    9,    6,   80,   14,
            7,   11,   21,   11,   13,   16,   48,   48,    6,  260,   10,   10,
           23,  687, 4227,    0, 7181,    9, 4629,   19,  610, 3597, 1296,  108,
           21,  101,    8,   24,   29,   57,   21,    7,    9, 1451, 2031,   13,
          966,    5,   19,  220,  193,    9, 2680,    0,   11,   21,   12,  325,
          529,   10, 4404,   20,   35,  205,   48, 1182,    6, 7725,   23,   12,
         1448,    7,   23,   55, 3018,   75,    5,  231,   94,   90,   33,  101,
         1590,   23,   27,   27,   23,  335,  898,   41,    0, 3590,   19,  120,
            9,   40, 1462,  206,   15,   40,   86,   14,  449,   74,   23,  127,
         3127,    5,  108,   15,   31,  107,   23,   37],
        [  30,  721,    7,    0,  509,  598,   10, 2428,   26,   24,   10,   11,
           89,    5,   11,   16,  752,  381,   19,    0,   84,   11,   43,   22,
         6062,   10,   62,  137,    5,    8, 1123,  358,   12,    0,  906,   29,
          185,    5,    6,    7,    6,  195,  199,   68,    5,   33,    6, 1037,
          442,   15, 1895,   11,   32, 1348,   31,   31,  285,   11,    8,    5,
           83,   37,   37,   87,  205,   73, 2431,   11,   15,  279,  431,   35,
           86,  213,    6,   17,    9,   39,   10, 2682,   47,   10,  185,  464,
           10,   93,  307,   17,    6, 1029,    3,   37,   14,   15,   12,   35,
            5,   93, 5150, 2835,   68,   13, 4833,   22,    7,  102,   85,   51,
         5178,  341,   24,   11,    9,   71,   20,   59,    6,   25,   37,   12,
           11,  288,   10,   12,    8,   19,  696,   13],
        [  43,    6,   13,  357,   10,    9,  178,    0,    5,   11,  115,  261,
          251,    3,   31,    8,    0,   14, 1129,   25, 4351,  212,    3,   69,
          167,   89,  618,   25,    3,   10,  577,  285,  845,   11,    8, 1413,
           45,    3,   21,   25,   11,    6,    6,  503,   35,  565,   12,    6,
         2099,   10,  733,  359, 2845,   18,   13,   54,  251,  450,   21,    3,
           25,  370, 2300,    6,   22,   76,   24,   79,    9,   45,   15,    9,
           20,  474,  759,    9,  896,    9,   13,  644,   15,   82,  246, 1221,
          438,   16,   73,    9,    7,  616,    1,   10,  283,    7,   10,    9,
            3, 3397,  705, 1140,   26,    5,    0,   36,   30,    6,   29,   20,
            6,    6,    0,   13,  246,    6,    9,   38,    7,   26,   79,   53,
          824,    9,   87,  488,   10,   64,   25,    5],
        [   3,    7,    5,   11,   63,   37,   33,   11,    3,  145,   17,   80,
           11,    1,   13,    9,    5,    7,   30,  177,    0,   16,    1,   13,
           30,   82,    5,   18,    1,   13,   10,   10,   25,   16,   58,   51,
           18,    1,   15,   66,  824,   12,    7,    5,    9,   49,   54,    7,
           11, 2796,   21, 1623,    5,    3,    5,   22,   11,   50,   23,    1,
           66,   10,   50,   12,  221,   85,   12,  203, 3137,   18,    3,  140,
           35,   14,   16, 7620,  765,    0,   19,    7,    9,  769,    6,   15,
          664,    8,  196,   16,   62,   46,    1,  164,   16,    0,  458,  104,
            1,   12,   12,   14,   70,    3,    5,    8,   18,   11,  618,   13,
           12,    7,  208,   96,    6,    7, 3321, 4103,   29,   43,   82,  544,
         1461,  731,    6,   25,   81,   53,  196,    3],
        [   1,  167,    3,   34,  254,  129, 1506,   17,    1,   35,    7,   43,
           26,    1,    5,    0,   10,   13,   18,    5, 7582,    8,    1,   96,
           18,   14,    3,    3,    1,   96,    0,   40,   66,  101,   83,   14,
            3,    1,    9,    5,   14,   31,   16,    3,  262,    5,  600,   13,
           52,  112,    9,   11,    3,    1,    3,   36,  490,    6,  287,    1,
            5,  260,   10, 1051,   24,  129,   74,    6, 1105,    3,    1,   81,
            9,   11,  274,    8,   12,   75,   12,    3,   13,    7,    7,   12,
           52,    3,    0,    8,   13,   11,    1,    6,    8,   48,  137,    6,
            1, 2016,  544,   12,    5,    1,  270,    3,    3,   13,    5,   15,
           25, 7595,   38,    5,    7,  959,    5,    5,   16,    3,    6,   18,
           14, 6750,   12,   43,   15,   54,  199,    1],
        [   1,   25,    1,    8,   69,    8, 3929,    9,    1,    9,  170,    3,
            5,    1,    3,   14,   25,    5,    3,    3,   11,    3,    1,    5,
            3,    7,    1,    1,    1,    5,   15,   46,    5,    8,   32,    7,
            1,    1,   40,    3,    7,  906,    8,    1,   16,   10,    3,    5,
         4016,   14,  104,   25,    1,    1,    1,    8,    5,    7,   13,    1,
            3, 1975,   40,    5,   59,   38,   73,    7,   15,    1,    1,   15,
           34,  263,    8,    3,   57,   12,   69,    1,  130,  488,   41,    0,
           65,    1,    3,    3,    5,   13,    1,   11,    3, 1797,   54,   27,
            1,  225,    3,  261,    3,    1,   26,    1,    1,    5,    3,    3,
           18,  967,   65,    3,   52,    5,  670,    3,    8,    1,    7,    3,
            7,   11,  141,    3,    7,  216,   39,    1],
        [   1,    5,    1,   74,   32,   15,    8,   54,    1,   90,   30,    1,
            3,    1,    1,   21,   66,    3,    1,    1,   17,    1,    1,    3,
            1,   13,    1,    1,    1,    3,    9,    6,    3,   95,    5,   13,
            1,    1, 1844,    1, 2784,   15,    3,    1,  196,   16,    1,    3,
          175,   11,   78,   66,    1,    1,    1,    3,    3,   16,    5,    1,
            1,   11,   51,    3,    9,   40,    3,   13,    9,    1,    1,    7,
           17,    0,    3,    1,   41,  347,    9,    1,    5, 1860,   76,   22,
          103,    1,    1,    1,    3,    5,    1,   38,    1,    5,   76,   29,
            1,   11,    1,   16,    1,    1,    5,    1,    1,    3,    1,    1,
            3,    8,   36,    1,   45,    3,   20,    1,    3,    1,  164,    1,
           60,   13,    6,    1,  461,    3,    3,    1],
        [   1,   10,    1,   54,  114,    9,    3,   22,    1, 1427,   18,    1,
            1,    1,    1, 1117,    5,    1,    1,    1,    9,    1,    1,    1,
            1,    5,    1,    1,    1,    1,   13,    7,    1,  250,    3,    5,
            1,    1,  202,    1,   13,    7,    1,    1,  232,    8,    1,    1,
            7,   14,    6,    5,    1,    1,    1,    1,    1,    8,    3,    1,
            1,   13,    6,    1,  318,  586,    1,    5,  298,    1,    1,  898,
          539,    5,    1,    1,   76,   65,   48,    1,    3,  105,    3,   36,
            3,    1,    1,    1,    1,   10,    1,   13,    1,    3,    3,   16,
            1,   26,    1,    8,    1,    1,    3,    1,    1,    1,    1,    1,
            1,    3,    8,    1,   43,    1,    9,    1,    1,    1,    5,    1,
          116,    5,   59,    1,    5,    1,    1,    1],
        [   1,   80,    1,    3,    5,   13,    1,   36,    1,  844,    3,    1,
            1,    1,    1,  670,    3,    1,    1,    1, 1415,    1,    1,    1,
            1,    3,    1,    1,    1,    1,    5,   39,    1,   82,    1,    3,
            1,    1,   11,    1,  865,   13,    1,    1,    8,    3,    1,    1,
           80,    9,   27,    3,    1,    1,    1,    1,    1,    3,    1,    1,
            1,    5,    7,    1,    5,    7,    1,    3,    5,    1,    1,    6,
          694,   61,    1,    1,    3,   18, 1791,    1,    1,    3,    1,    8,
            1,    1,    1,    1,    1,   16,    1,    5,    1,    1,    1,    8,
            1,   70,    1,   10,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    3,    1,    3,    1,  740,    1,    1,    1,    3,    1,
            5,    3,   62,    1,   93,    1,    1,    1],
        [   1,   18,    1,    1,    3,    5,    1,    8,    1,  180,    1,    1,
            1,    1,    1,   20,    1,    1,    1,    1, 4252,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    3,    9,    1,    6,    1,    1,
            1,    1,   13,    1,    5,    5,    1,    1,   10,    1,    1,    1,
          103,  332,   29,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    3,   16,    1,   38,   13,    1,    1,    3,    1,    1,   11,
           10, 2045,    1,    1,    1,    3, 2040,    1,    1,    1,    1,    3,
            1,    1,    1,    1,    1,    8,    1,    3,    1,    1,    1,    3,
            1,    5,    1,   13,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    0,    1,    1,    1,    1,    1,
            3,    1,   13,    1,    8,    1,    1,    1],
        [   1,    3,    1,    1,    1,    3,    1,    3,    1,  457,    1,    1,
            1,    1,    1,   13,    1,    1,    1,    1, 2479,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,   49,    1,   10,    1,    1,
            1,    1,   96,    1,    3,    3,    1,    1,   13,    1,    1,    1,
           76,  199,  113,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,  890,    1, 3946,    5,    1,    1,    1,    1,    1,   23,
          219,  734,    1,    1,    1,    1,    7,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    3,    1,    1,    1,    1,    1,    1,
            1,    3,    1,    5,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    5,    1,    1,    1,    1,    1,
            1,    1,    5,    1,    9,    1,    1,    1],
        [   1,    1,    1,    1,    1,    1,    1,    1,    1,    7,    1,    1,
            1,    1,    1,    5,    1,    1,    1,    1,    7,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    5,    1,  260,    1,    1,
            1,    1,    5,    1,    1,    1,    1,    1,  229,    1,    1,    1,
            3,  868,    8,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    8,    1,    8,    3,    1,    1,    1,    1,    1,  712,
           11,    5,    1,    1,    1,    1,  812,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    3,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,  166,    1,    1,    1,    1,    1,
            1,    1,    3,    1,  549,    1,    1,    1],
        [   1,    1,    1,    1,    1,    1,    1,    1,    1,   13,    1,    1,
            1,    1,    1,    3,    1,    1,    1,    1,   45,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    3,    1,  742,    1,    1,
            1,    1,    3,    1,    1,    1,    1,    1,  234,    1,    1,    1,
            1,   30,    3,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    3,    1,    3,    1,    1,    1,    1,    1,    1,   49,
           16,    3,    1,    1,    1,    1,   15,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,   15,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    6,    1,    1,    1],
        [   1,    1,    1,    1,    1,    1,    1,    1,    1,    5,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,   18,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    0,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    5,    1,    1,    1,
            1,   18,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    5,
            8,    1,    1,    1,    1,    1,    9,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    9,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    7,    1,    1,    1],
        [   1,    1,    1,    1,    1,    1,    1,    1,    1,    3,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    3,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,   11,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    3,    1,    1,    1,
            1,    3,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    3,
            3,    1,    1,    1,    1,    1,   13,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,   32,    1,    1,    1,    1,    1,
            1,    1,    1,    1,  862,    1,    1,    1],
        [   1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,   16,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    5,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    6,    1,    1,    1,    1,    1,
            1,    1,    1,    1,   53,    1,    1,    1],
        [   1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    8,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    3,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,   47,    1,    1,    1,    1,    1,
            1,    1,    1,    1,   41,    1,    1,    1],
        [   1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    3,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,   29,    1,    1,    1,    1,    1,
            1,    1,    1,    1, 1021,    1,    1,    1],
        [   1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1, 3938,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    3,    1,    1,    1],
        [   1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    7,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1],
        [   1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,   13,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1],
        [   1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    5,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1],
        [   1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    3,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1]], device='cuda:0')
29
128

4.2.2 Seq2Seq

主要功能:

  • 接收输入/源句子
  • 使用Encoder生成上下文向量
  • 使用Decoder生成预测输出/目标句子 再看一下整体的模型


    image.png

确定encoder和decoder每一层的数目、隐藏层。
下面是实现代码

# Seq2Seq
class Seq2Seq(nn.Module):
    def __init__(self, encoder,decoder,device):
        super(Seq2Seq,self).__init__()
        self.encoder=encoder
        self.decoder=decoder
        self.device=device
        assert encoder.hid_dim==decoder.hid_dim, "Hidden dimensions of encoder and decoder must be equal!"
        assert encoder.n_layers==decoder.n_layers, "Num_Layers of encoder and decoder must be equal!"
        
    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        # src = [src sent len, batch size]
        # trg = [trg sent len, batch size]
        # teacher_forcing_ratio是使用教师强制的概率
        # 例如。如果teacher_forcing_ratio是0.75,我们75%的时间使用groundtruth输入
        batch_size=trg.shape[1]
        max_len=trg.shape[0]
        trg_vocab_size=self.decoder.output_dim
        
        # 创建输出张量,存储我们所有的预测
        outputs=torch.zeros(max_len,batch_size,trg_vocab_size).to(self.device)
        # 输入源语到encoder,然后获取最终的隐藏和单元状态
        hidden, cell=self.encoder(src)
        # decoder第一个输入的是句子的最开始的token,也就是那个<sos>标记,
        input=trg[0,:]
        # max_len就是行数
        for t in range(1, max_len):
            # 将输入,先前隐藏和前一个单元状态传递给Decoder
            # 接收预测,来自Decoder下一个隐藏状态和下一个单元状态
            output, hidden, cell=self.decoder(input,hidden,cell)
            # 将我们的预测,输出放在我们的预测张量中
            outputs[t]=output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.max(1)[1]
            input = (trg[t] if teacher_force else top1)
        return outputs

可以看到在Seq2Seq这个模块的作用是整合encoder和decoder,将两者的数据打通,batch_size的大小就是在句子中一个词转为向量的长度"shape[1]",而max_len就是这个句子所包含的词的个数。这个在以后的模型中也会很重要,每次我都会注意一下。

4.2.3 encoder

请注意,我们只将第一层的隐藏状态作为输入传递给第二层,而不是单元状态。

image.png

下面重点来了,encoder有哪些参数其实在最开始的参数设置里面就可以看到这些。

  • input_dim输入encoder的one-hot向量维度,这个和输入词汇大小一致
  • emb_dim嵌入层的维度,这一层将one-hot向量转为密度向量
  • hid_dim隐藏层和cell状态维度
  • n_layersRNN的层数
  • dropout是要使用的丢失量。这是一个防止过度拟合的正则化参数。

没什么特别的地方。

class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super(Encoder, self).__init__()
        self.input_dim=input_dim
        self.emb_dim=emb_dim
        self.hid_dim=hid_dim
        self.n_layers=n_layers
        self.dropout=dropout
        
        self.embedding=nn.Embedding(input_dim,emb_dim)
        self.rnn=nn.LSTM(emb_dim,hid_dim,n_layers,dropout=dropout)
        self.dropout=nn.Dropout(dropout)
    def forward(self, src):
        embedded=self.dropout(self.embedding(src))
        outputs, (hidden,cell)=self.rnn(embedded)
        return hidden, cell

4.2.4 decoder

Decoder同样也是一个LSTM。


image.png

Decoder的初始隐藏和单元状态是我们的上下文向量,它们是来自同一层的Encoder的最终隐藏和单元状态。
接下来将隐藏状态传递给Linear层,预测目标序列下一个标记应该是什么。
Decoder的参数和Encoder类似,其中output_dim是将要输入到Decoder的one-hot向量。

  • 在forward方法中,获取到了输入token、上一层的隐藏状态和单元状态。解压之后加入句子长度维度。
  • 接下来与Encoder类似,传入嵌入层并使用dropout,然后将这批嵌入式令牌传递到具有先前隐藏和单元状态的RNN。这产生了一个输出(来自RNN顶层的隐藏状态),一个新的隐藏状态(每个层一个,堆叠在彼此之上)和一个新的单元状态(每层也有一个,堆叠在彼此的顶部))。
  • 然后我们通过线性层传递输出(在除去句子长度维度之后)以接收我们的预测。然后我们返回预测,新的隐藏状态和新的单元状态。
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super(Decoder, self).__init__()
        self.emb_dim=emb_dim
        self.hid_dim=hid_dim
        self.output_dim=output_dim
        self.n_layers=n_layers
        self.dropout=dropout
        
        self.embedding=nn.Embedding(output_dim, emb_dim)
        self.rnn=nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.out=nn.Linear(hid_dim, output_dim)
        self.dropout=nn.Dropout(dropout)
        
    # 这里的hidden, cell是encoder输出的结果
    def forward(self, input, hidden, cell):
        input=input.unsqueeze(0)
        embedded=self.dropout(self.embedding(input))
        output, (hidden,cell)=self.rnn(embedded,(hidden,cell))
        prediction=self.out(output.squeeze(0))
        # 这里输出的prediction就是预测数据
        return prediction, hidden, cell

4.3 训练模型

enc=Encoder(INPUT_DIM,ENC_EMB_DIM,HID_DIM,N_LAYERS,ENC_DROPOUT)
dec=Decoder(OUTPUT_DIM,DEC_EMB_DIM,HID_DIM,N_LAYERS,DEC_DROPOUT)
model=Seq2Seq(enc,dec,device).to(device)
model
Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(7855, 256)
    (rnn): LSTM(256, 512, num_layers=4, dropout=0.5)
    (dropout): Dropout(p=0.5)
  )
  (decoder): Decoder(
    (embedding): Embedding(5893, 256)
    (rnn): LSTM(256, 512, num_layers=4, dropout=0.5)
    (out): Linear(in_features=512, out_features=5893, bias=True)
    (dropout): Dropout(p=0.5)
  )
)
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)
        
model.apply(init_weights)
Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(7855, 256)
    (rnn): LSTM(256, 512, num_layers=4, dropout=0.5)
    (dropout): Dropout(p=0.5)
  )
  (decoder): Decoder(
    (embedding): Embedding(5893, 256)
    (rnn): LSTM(256, 512, num_layers=4, dropout=0.5)
    (out): Linear(in_features=512, out_features=5893, bias=True)
    (dropout): Dropout(p=0.5)
  )
)
def count_parameters(model):
    # pytorch.numel返回矩阵内所有元素个数
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')
The model has 22,304,005 trainable parameters
optimizer = optim.Adam(model.parameters())
# stoi允许访问包含单词及其索引的字典。
PAD_IDX = TRG.vocab.stoi['<pad>']
criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX)
def train(model, iterator, optimizer, criterion, clip):
    
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg
        
        optimizer.zero_grad()
        output = model(src, trg)
        
        #trg = [trg sent len, batch size]
        #output = [trg sent len, batch size, output dim]
        
        output = output[1:].view(-1, output.shape[-1])
        trg = trg[1:].view(-1)
        
        #trg = [(trg sent len - 1) * batch size]
        #output = [(trg sent len - 1) * batch size, output dim]
        
        loss = criterion(output, trg)
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)
def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, batch in enumerate(iterator):

            src = batch.src
            trg = batch.trg

            output = model(src, trg, 0) #turn off teacher forcing

            #trg = [trg sent len, batch size]
            #output = [trg sent len, batch size, output dim]

            output = output[1:].view(-1, output.shape[-1])
            trg = trg[1:].view(-1)

            #trg = [(trg sent len - 1) * batch size]
            #output = [(trg sent len - 1) * batch size, output dim]

            loss = criterion(output, trg)
            
            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs
N_EPOCHS = 2
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
#     valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
    best_valid_loss = valid_loss
    torch.save(model.state_dict(), 'tut1-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
>Epoch: 01 | Time: 1m 57s
    Train Loss: 4.981 | Train PPL: 145.638
     Val. Loss: 4.938 |  Val. PPL: 139.557
Epoch: 02 | Time: 1m 57s
    Train Loss: 4.690 | Train PPL: 108.829
     Val. Loss: 4.950 |  Val. PPL: 141.169
Epoch: 03 | Time: 1m 56s
    Train Loss: 4.421 | Train PPL:  83.212
     Val. Loss: 4.642 |  Val. PPL: 103.731
Epoch: 04 | Time: 1m 57s
    Train Loss: 4.187 | Train PPL:  65.833
     Val. Loss: 4.560 |  Val. PPL:  95.608
Epoch: 05 | Time: 1m 57s
    Train Loss: 4.045 | Train PPL:  57.138
     Val. Loss: 4.429 |  Val. PPL:  83.808
Epoch: 06 | Time: 1m 56s
    Train Loss: 3.939 | Train PPL:  51.373
     Val. Loss: 4.400 |  Val. PPL:  81.460
Epoch: 07 | Time: 1m 56s
    Train Loss: 3.862 | Train PPL:  47.579
     Val. Loss: 4.370 |  Val. PPL:  79.046
Epoch: 08 | Time: 1m 57s
    Train Loss: 3.755 | Train PPL:  42.738
     Val. Loss: 4.369 |  Val. PPL:  78.992
Epoch: 09 | Time: 1m 56s
    Train Loss: 3.672 | Train PPL:  39.322
     Val. Loss: 4.223 |  Val. PPL:  68.248
Epoch: 10 | Time: 1m 57s
    Train Loss: 3.622 | Train PPL:  37.402
     Val. Loss: 4.201 |  Val. PPL:  66.773
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 196,165评论 5 462
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 82,503评论 2 373
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 143,295评论 0 325
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,589评论 1 267
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 61,439评论 5 358
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,342评论 1 273
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,749评论 3 387
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,397评论 0 255
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,700评论 1 295
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,740评论 2 313
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,523评论 1 326
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,364评论 3 314
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,755评论 3 300
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,024评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,297评论 1 251
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 41,721评论 2 342
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 40,918评论 2 336

推荐阅读更多精彩内容