理工男的文科梦 —— LSTM深度学习写春联

笔者作为一名根正苗红的理工男,内心却常常有很多文艺青年才会有的想法,例如写首诗、做首词,甚至包括春节写副对联,空有一番愿望却胸无点墨。随着对机器学习和深度学习的了解,逐渐萌生了使用机器帮助笔者完成文艺青年的转型。:)
本文借助递归神经网络RDD的变种之一LSTM算法,对收集到的6900多条对联进行学习,训练好模型后可以由机器写出对联。

递归神经网络与LSTM

故事从人工神经网络开始,人工神经网络诞生已久。如下图所示,神经网络的基本结构由输入层、输出层和一个或多个隐含层组成。

多层神经网络

全连接的神经网络下一层神经元的输入由上一层所有神经元的输出决定,因此带来了一个严重的问题即参数数量过大导致无法训练。因此,随时神经网络的发展,衍生了一系列的变化。比较流行的有应用于图像识别领域的卷积神经网络CNN、应用于自然语言处理的递归神经网络RNN。本文应用到的LSTM算法即为RNN的一种形态。RNN解决了这样的问题:即样本出现的时间顺序对于自然语言处理、语音识别、手写体识别等应用非常重要,神经元的输出可以在下一个时间戳直接作用到自身。因此RNN很适合处理时序对结果影响较深的领域。
关于RNN和LSTM原理的说明可以移步 http://www.jianshu.com/p/9dc9f41f0b29 ,本文不多加赘言。

RNN
由LSTM作诗引发

由于LSTM算法非常适用自然语言处理领域,因此网上出现了很多应用LSTM做文字领域的尝试,例如: LSTM写诗 中使用LSTM写诗,LSTM创作歌词中使用LSTM模仿歌手风格写歌词,以及使用LSTM算法给小孩起名(是多么不靠谱的粑粑麻麻)。
因此,笔者突发想法,如果给一个足够的春联训练样本,一样可以照猫画老虎,训练一个可以写对联的文艺“机器模型”。因此,问题就分解为:找样本、写算法、训练、应用模型。

春联样本搜集和规整

借助于强大的度娘,费劲九牛之力,从网上搜集了各式春联共6900对,其中上联下联之间是用","分割区分上下联,对联之间是用"。"区分一联的结束。样式如下:
训练样本

这些样本将会在训练阶段进行类型转换并输入给LSTM模型中。如果您也想试下本文案例,请私信我这些样本(毕竟搜集训练样本是个苦差事(: )

LSTM算法

本文使用TensorFlow进行建模,TensorFlow就无需多言,是这个领域目前最活跃的框架。写对联的算法主要工作包括:根据样本数据产生LSTM输入数据和结果值;定义LSTM的模型以及损失函数;将训练数据喂给TensorFlow用来训练模型。接下来会逐步列举本例中使用的方法。

  • 训练数据转换
    由于样本数据是一条条汉字组成的对联,这样的数据是无法交给模型训练的,因此需要对样本数据进行转换。基本思想是:
    • 将样本的所有对联加载录入,统计出所有出现的汉字,并将汉字进行编码,例如:一共有10000个汉字出现在样本中,那么对出现的汉字按 0 - 999 进行编码,每个汉字对应一个编码。
    • 对原始样本进行编码转换,生成用数字编码表示的对联集。
    • 每条对联作为一个输入序列,每批次训练batch_size条,生成输入数据xdata,输出y值为xdata+1。因为文本分析的特点是有时序性。
couplet_file ="couplet.txt"
#对联
couplets = []
with open(couplet_file,'r') as f:
    for line in f:
        try:
            content = line.replace(' ','')
            if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content:
                continue
            if len(content) < 5*3 or len(content) > 79*3:
                continue
            content = '[' + content + ']'
           # print chardet.detect(content)
            content = content.decode('utf-8')
            couplets.append(content)

        except Exception as e:
            pass

# 按字数排序
couplets = sorted(couplets,key=lambda line: len(line))
print('对联总数: %d'%(len(couplets)))
# 统计每个字出现次数
all_words = []
for couplet in couplets:
    all_words += [word for word in couplet]

counter = collections.Counter(all_words)
count_pairs = sorted(counter.items(), key=lambda x: -x[1])
words, _ = zip(*count_pairs)
words = words[:len(words)] + (' ',)
# 每个字映射为一个数字ID
word_num_map = dict(zip(words, range(len(words))))

to_num = lambda word: word_num_map.get(word, len(words))
couplets_vector = [ list(map(to_num, couplet)) for couplet in couplets]

# 每次取64首对联进行训练, 此参数可以调整
batch_size = 64
n_chunk = len(couplets_vector) // batch_size
x_batches = []
y_batches = []
for i in range(n_chunk):
    start_index = i * batch_size#起始位置
    end_index = start_index + batch_size#结束位置

    batches = couplets_vector[start_index:end_index]
    length = max(map(len,batches))#每个batches中句子的最大长度
    xdata = np.full((batch_size,length), word_num_map[' '], np.int32)
    for row in range(batch_size):
        xdata[row,:len(batches[row])] = batches[row]
    ydata = np.copy(xdata)
    ydata[:,:-1] = xdata[:,1:]
    x_batches.append(xdata)
    y_batches.append(ydata)
  • 定义LSTM模型

    • 使用TF api tf.nn.rnn_cell.BasicLSTMCell定义cell为一个128维的ht的cell。并使用MultiRNNCell 定义为两层的LSTM。
    • 对训练样本输入进行embedding化。
    • 使用tf.nn.dynamic_rnn计算输出值。(也可以通过循环step的方法,依次计算)
    • 加入softmax层。
def neural_network(rnn_size=128, num_layers=2):
    cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size, state_is_tuple=True)
    cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

    initial_state = cell.zero_state(batch_size, tf.float32)

    with tf.variable_scope('rnnlm'):
        softmax_w = tf.get_variable("softmax_w", [rnn_size, len(words)+1])
        softmax_b = tf.get_variable("softmax_b", [len(words)+1])
        with tf.device("/cpu:0"):
            embedding = tf.get_variable("embedding", [len(words)+1, rnn_size])
            inputs = tf.nn.embedding_lookup(embedding, input_data)

    outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state, scope='rnnlm')
    output = tf.reshape(outputs,[-1, rnn_size])

    logits = tf.matmul(output, softmax_w) + softmax_b
    probs = tf.nn.softmax(logits)
    return logits, last_state, probs, cell, initial_state
  • 训练阶段
    • 使用TF sequence_loss_by_example计算所有examples(假设一句话有n个单词,一个单词及单词所对应的label就是一个example,所有example就是一句话中所有单词)的加权交叉熵损失。
    • tf.gradients 计算梯度,并使用clip_by_global_norm控制梯度爆炸的问题。梯度爆炸和梯度弥散的原因一样,都是因为链式法则求导的关系,导致梯度的指数级衰减。为了避免梯度爆炸,需要对梯度进行修剪。(来自网上的解释,不明觉厉(: )
    • 定义步长,步长过大,会很可能越过最优值,步长过小则使优化的效率过低,长时间无法收敛。因此learning rate是一个需要适当调整的参数。一个小技巧是,随时训练的进行,即沿着梯度方向收敛的过程中,适当减小步长,不至于错过最优解。在代码中 0.01 * (0.97 ** epoch),learing rate基数值为0.01, 系数为0.97的epoch方,可以看出epoch越大,learing rate越小。
    • 分批次将样本数据x_batches和y_batches喂给TF进行训练。
def train_neural_network():
    logits, last_state, _, _, _ = neural_network()
    targets = tf.reshape(output_targets, [-1])
    loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example([logits], [targets], [tf.ones_like(targets, dtype=tf.float32)], len(words))
    cost = tf.reduce_mean(loss)
    learning_rate = tf.Variable(0.0, trainable=False)
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), 5)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(zip(grads, tvars))

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        saver = tf.train.Saver(tf.all_variables())

        for epoch in range(100):
            sess.run(tf.assign(learning_rate, 0.01 * (0.97 ** epoch)))
            n = 0
            for batche in range(n_chunk):
                train_loss, _ , _ = sess.run([cost, last_state, train_op], feed_dict={input_data: x_batches[n], output_targets: y_batches[n]})
                n += 1
                print(epoch, batche, train_loss)
            if epoch % 7 == 0:
                saver.save(sess, './couplet.module', global_step=epoch)
  • 训练结束 , 诗性大发

经过漫长的训练(取决于样本数和迭代次数), loss控制在1.5左右。


loss

可以看到,经过100次的迭代训练,每7次保存一次(saver.save(sess, './couplet.module', global_step=epoch)), 最后的模型保存在couplet.module-98里。

modle

在eval阶段,使用saver.restore(sess, 'couplet.module-98') 将训练好的模型加载, 因为机器算出来的依旧是上文提到的数字编码,因此需要再将数字转为汉字。

好啦,来看看机器创作的对联吧, 是不是有点意思呢?

couplet
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,602评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,442评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,878评论 0 344
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,306评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,330评论 5 373
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,071评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,382评论 3 400
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,006评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,512评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,965评论 2 325
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,094评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,732评论 4 323
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,283评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,286评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,512评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,536评论 2 354
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,828评论 2 345

推荐阅读更多精彩内容

  • 神经结构进步、GPU深度学习训练效率突破。RNN,时间序列数据有效,每个神经元通过内部组件保存输入信息。 卷积神经...
    利炳根阅读 4,728评论 0 7
  • 作者 | 武维AI前线出品| ID:ai-front 前言 自然语言处理(简称NLP),是研究计算机处理人类语言的...
    AI前线阅读 2,562评论 0 8
  • 第二个Topic讲深度学习,承接前面的《浅谈机器学习基础》。 深度学习简介 前面也提到过,机器学习的本质就是寻找最...
    我偏笑_NSNirvana阅读 15,578评论 7 49
  • 当一个人知道,另一个人其实是在乎的, 可是沟通方式不对, 距离也远, 见面太少,触摸不到。 渐渐,渐渐地, 也就拖...
    我家门口的有条溪阅读 397评论 0 0
  • 回顾自己走过的乌漆嘛嘿的路,总有人为我点灯照我前行,虽然还是在摸黑前行,但至少有那么些瞬间让我感动不已,为我点灯的...
    云中君style阅读 771评论 0 1