【NLP保姆级教程】手把手带你fastText文本分类(附代码)

本文首发于微信公众号:NewBeeNLP


写在前面

继续NLP保姆级教程系列,今天的教程是基于FAIR的Bag of Tricks for Efficient Text Classification[1]。也就是我们常说的fastText。

最让人欣喜的这篇论文配套提供了fasttext工具包。这个工具包代码质量非常高,论文结果一键还原,目前已经是包装地非常专业了,这是fastText官网和其github代码库,以及提供了python接口,可以直接通过pip安装。这样准确率高又快的模型绝对是实战利器。

为了更好地理解fasttext原理,我们现在直接复现来一遍,但是代码中仅仅实现了最简单的基于单词的词向量求平均,并未使用b-gram的词向量,所以自己实现的文本分类效果会低于facebook开源的库。

论文概览

❝We can train fastText on more than one billion words in less than ten minutes using a standard multicore CPU, and classify half a million sentences among 312K classes in less than a minute.

首先引用论文中的一段话来看看作者们是怎么评价fasttext模型的表现的。

这篇论文的模型非常之简单,之前了解过word2vec的同学可以发现这跟CBOW的模型框架非常相似。

对应上面这个模型,比如输入是一句话,到就是这句话的单词或者是n-gram。每一个都对应一个向量,然后对这些向量取平均就得到了文本向量,然后用这个平均向量取预测标签。当类别不多的时候,就是最简单的softmax;当标签数量巨大的时候,就要用到「hierarchical softmax」了。

模型真的很简单,也没什么可以说的了。下面提一下论文中的两个tricks:

  • 「hierarchical softmax」
    类别数较多时,通过构建一个霍夫曼编码树来加速softmax layer的计算,和之前word2vec中的trick相同
  • 「N-gram features」
    只用unigram的话会丢掉word order信息,所以通过加入N-gram features进行补充用hashing来减少N-gram的存储
  • 看了论文的实验部分,如此简单的模型竟然能取得这么好的效果 !

    但是也有人指出论文中选取的数据集都是对句子词序不是很敏感的数据集,所以得到文中的试验结果并不奇怪。

    代码实现

    看完阉割版代码大家记得去看看源码噢~跟之前系列的一样,定义一个fastTextModel类,然后写网络框架,输入输出placeholder,损失,训练步骤等。

    class fastTextModel(BaseModel):
    """
    A simple implementation of fasttext for text classification
    """
    def __init__(self, sequence_length, num_classes, vocab_size,
    embedding_size, learning_rate, decay_steps, decay_rate,
    l2_reg_lambda, is_training=True,
    initializer=tf.random_normal_initializer(stddev=0.1)):
    self.vocab_size = vocab_size
    self.embedding_size = embedding_size
    self.num_classes = num_classes
    self.sequence_length = sequence_length
    self.learning_rate = learning_rate
    self.decay_steps = decay_steps
    self.decay_rate = decay_rate
    self.is_training = is_training
    self.l2_reg_lambda = l2_reg_lambda
    self.initializer = initializer

    self.input_x = tf.placeholder(tf.int32, [None, self.sequence_length], name='input_x')
    self.input_y = tf.placeholder(tf.int32, [None, self.num_classes], name='input_y')

    self.global_step = tf.Variable(0, trainable=False, name='global_step')
    self.instantiate_weight()
    self.logits = self.inference()
    self.loss_val = self.loss()
    self.train_op = self.train()

    self.predictions = tf.argmax(self.logits, axis=1, name='predictions')
    correct_prediction = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
    self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'), name='accuracy')

    def instantiate_weight(self):
    with tf.name_scope('weights'):
    self.Embedding = tf.get_variable('Embedding', shape=[self.vocab_size, self.embedding_size],
    initializer=self.initializer)
    self.W_projection = tf.get_variable('W_projection', shape=[self.embedding_size, self.num_classes],
    initializer=self.initializer)
    self.b_projection = tf.get_variable('b_projection', shape=[self.num_classes])


    def inference(self):
    """
    1. word embedding
    2. average embedding
    3. linear classifier
    :return:
    """
    # embedding layer
    with tf.name_scope('embedding'):
    words_embedding = tf.nn.embedding_lookup(self.Embedding, self.input_x)
    self.average_embedding = tf.reduce_mean(words_embedding, axis=1)

    logits = tf.matmul(self.average_embedding, self.W_projection) +self.b_projection

    return logits


    def loss(self):
    # loss
    with tf.name_scope('loss'):
    losses = tf.nn.softmax_cross_entropy_with_logits(labels=self.input_y, logits=self.logits)
    data_loss = tf.reduce_mean(losses)
    l2_loss = tf.add_n([tf.nn.l2_loss(cand_var) for cand_var in tf.trainable_variables()
    if 'bias' not in cand_var.name]) * self.l2_reg_lambda
    data_loss += l2_loss * self.l2_reg_lambda
    return data_loss

    def train(self):
    with tf.name_scope('train'):
    learning_rate = tf.train.exponential_decay(self.learning_rate, self.global_step,
    self.decay_steps, self.decay_rate,
    staircase=True)

    train_op = tf.contrib.layers.optimize_loss(self.loss_val, global_step=self.global_step,
    learning_rate=learning_rate, optimizer='Adam')

    return train_op
    def prepocess():
    """
    For load and process data
    :return:
    """
    print("Loading data...")
    x_text, y = data_process.load_data_and_labels(FLAGS.positive_data_file, FLAGS.negative_data_file)
    # bulid vocabulary
    max_document_length = max(len(x.split(' ')) for x in x_text)
    vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
    x = np.array(list(vocab_processor.fit_transform(x_text)))

    # shuffle
    np.random.seed(10)
    shuffle_indices = np.random.permutation(np.arange(len(y)))
    x_shuffled = x[shuffle_indices]
    y_shuffled = y[shuffle_indices]

    # split train/test dataset
    dev_sample_index = -1 * int(FLAGS.dev_sample_percentage * float(len(y)))
    x_train, x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:]
    y_train, y_dev = y_shuffled[:dev_sample_index], y_shuffled[dev_sample_index:]
    del x, y, x_shuffled, y_shuffled

    print('Vocabulary Size: {:d}'.format(len(vocab_processor.vocabulary_)))
    print('Train/Dev split: {:d}/{:d}'.format(len(y_train), len(y_dev)))
    return x_train, y_train, vocab_processor, x_dev, y_dev


    def train(x_train, y_train, vocab_processor, x_dev, y_dev):
    with tf.Graph().as_default():
    session_conf = tf.ConfigProto(
    # allows TensorFlow to fall back on a device with a certain operation implemented
    allow_soft_placement= FLAGS.allow_soft_placement,
    # allows TensorFlow log on which devices (CPU or GPU) it places operations
    log_device_placement=FLAGS.log_device_placement
    )
    sess = tf.Session(config=session_conf)
    with sess.as_default():
    # initialize cnn
    fasttext = fastTextModel(sequence_length=x_train.shape[1],
    num_classes=y_train.shape[1],
    vocab_size=len(vocab_processor.vocabulary_),
    embedding_size=FLAGS.embedding_size,
    l2_reg_lambda=FLAGS.l2_reg_lambda,
    is_training=True,
    learning_rate=FLAGS.learning_rate,
    decay_steps=FLAGS.decay_steps,
    decay_rate=FLAGS.decay_rate
    )

    # output dir for models and summaries
    timestamp = str(time.time())
    out_dir = os.path.abspath(os.path.join(os.path.curdir, 'run', timestamp))
    if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print('Writing to {} \n'.format(out_dir))

    # checkpoint dir. checkpointing – saving the parameters of your model to restore them later on.
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, FLAGS.ckpt_dir))
    checkpoint_prefix = os.path.join(checkpoint_dir, 'model')
    if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

    # Write vocabulary
    vocab_processor.save(os.path.join(out_dir, 'vocab'))

    # Initialize all
    sess.run(tf.global_variables_initializer())


    def train_step(x_batch, y_batch):
    """
    A single training step
    :param x_batch:
    :param y_batch:
    :return:
    """
    feed_dict = {
    fasttext.input_x: x_batch,
    fasttext.input_y: y_batch,
    }
    _, step, loss, accuracy = sess.run(
    [fasttext.train_op, fasttext.global_step, fasttext.loss_val, fasttext.accuracy],
    feed_dict=feed_dict
    )
    time_str = datetime.datetime.now().isoformat()
    print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))

    def dev_step(x_batch, y_batch):
    """
    Evaluate model on a dev set
    Disable dropout
    :param x_batch:
    :param y_batch:
    :param writer:
    :return:
    """
    feed_dict = {
    fasttext.input_x: x_batch,
    fasttext.input_y: y_batch,
    }
    step, loss, accuracy = sess.run(
    [fasttext.global_step, fasttext.loss_val, fasttext.accuracy],
    feed_dict=feed_dict
    )
    time_str = datetime.datetime.now().isoformat()
    print("dev results:{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))

    # generate batches
    batches = data_process.batch_iter(list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)
    # training loop
    for batch in batches:
    x_batch, y_batch = zip(*batch)
    train_step(x_batch, y_batch)
    current_step = tf.train.global_step(sess, fasttext.global_step)
    if current_step % FLAGS.validate_every == 0:
    print('\n Evaluation:')
    dev_step(x_dev, y_dev)
    print('')

    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
    print('Save model checkpoint to {} \n'.format(path))

    def main(argv=None):
    x_train, y_train, vocab_processor, x_dev, y_dev = prepocess()
    train(x_train, y_train, vocab_processor, x_dev, y_dev)

    if __name__ == '__main__':
    tf.app.run()

    对啦,我这里使用的数据集还是之前训练CNN时的那一份

    「完整代码可以在公众号后台回复"ft"获取。」


    本文参考资料

    [1]Bag of Tricks for Efficient Text Classification: https://arxiv.org/abs/1607.01759


      【NLP保姆级教程】手把手带你RNN文本分类(附代码)

      【NLP保姆级教程】手把手带你CNN文本分类(附代码)

      BERT源码分析(PART III)

    本文首发于微信公众号:NewBeeNLP

    ©著作权归作者所有,转载或内容合作请联系作者
    • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
      沈念sama阅读 206,968评论 6 482
    • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
      沈念sama阅读 88,601评论 2 382
    • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
      开封第一讲书人阅读 153,220评论 0 344
    • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
      开封第一讲书人阅读 55,416评论 1 279
    • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
      茶点故事阅读 64,425评论 5 374
    • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
      开封第一讲书人阅读 49,144评论 1 285
    • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
      沈念sama阅读 38,432评论 3 401
    • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
      开封第一讲书人阅读 37,088评论 0 261
    • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
      沈念sama阅读 43,586评论 1 300
    • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
      茶点故事阅读 36,028评论 2 325
    • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
      茶点故事阅读 38,137评论 1 334
    • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
      沈念sama阅读 33,783评论 4 324
    • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
      茶点故事阅读 39,343评论 3 307
    • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
      开封第一讲书人阅读 30,333评论 0 19
    • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
      开封第一讲书人阅读 31,559评论 1 262
    • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
      沈念sama阅读 45,595评论 2 355
    • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
      茶点故事阅读 42,901评论 2 345

    推荐阅读更多精彩内容