Bert系列(三)——源码解读之Pre-train

pre-train是迁移学习的基础,虽然Google已经发布了各种预训练好的模型,而且因为资源消耗巨大,自己再预训练也不现实(在Google Cloud TPU v2 上训练BERT-Base要花费近500刀,耗时达到两周。在GPU上可想而知只会更贵),但是学习bert的预训练方法可以为我们弄懂整个bert的运行流程提供莫大的帮助。预训练涉及到的模块有点多,所以这也将会是一篇长文,在能简略的地方我尽量简略,还是那句话,我的文章只能是起到一个导读的作用,如果想摸清里面的各种细节还是要自己把源码过一遍的。

pre-train涉及到的模块分为以下三个,我将为大家一一介绍:

1.tokenization.py

2.create_pretraining_data.py

3.run_pretraining.py

其中tokenization是对原始句子内容的解析,分为BasicTokenizer和WordpieceTokenizer两个,不只是在预训练中,在fine-tune和推断过程同样要用到它;create_pretraining_data顾名思义就是将原始语料转换成适合模型预训练的输入数据;run_pretraining就是预训练的执行代码了。

一、tokenization.py

1、BasicTokenizer

class BasicTokenizer(object):
  """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

  def __init__(self, do_lower_case=True):
    self.do_lower_case = do_lower_case

  def tokenize(self, text):
    """Tokenizes a piece of text."""
    text = convert_to_unicode(text)
    text = self._clean_text(text)

    text = self._tokenize_chinese_chars(text)

    orig_tokens = whitespace_tokenize(text)
    split_tokens = []
    for token in orig_tokens:
      if self.do_lower_case:
        token = token.lower()
        token = self._run_strip_accents(token)
      split_tokens.extend(self._run_split_on_punc(token))

    output_tokens = whitespace_tokenize(" ".join(split_tokens))
    return output_tokens

  def _run_strip_accents(self, text):
    """Strips accents from a piece of text."""
    text = unicodedata.normalize("NFD", text)
    output = []
    for char in text:
      cat = unicodedata.category(char)
      if cat == "Mn":
        continue
      output.append(char)
    return "".join(output)

  def _run_split_on_punc(self, text):
    """Splits punctuation on a piece of text."""
    chars = list(text)
    i = 0
    start_new_word = True
    output = []
    while i < len(chars):
      char = chars[i]
      if _is_punctuation(char):
        output.append([char])
        start_new_word = True
      else:
        if start_new_word:
          output.append([])
        start_new_word = False
        output[-1].append(char)
      i += 1

    return ["".join(x) for x in output]

  def _tokenize_chinese_chars(self, text):
    """Adds whitespace around any CJK character."""
    output = []
    for char in text:
      cp = ord(char)
      if self._is_chinese_char(cp):
        output.append(" ")
        output.append(char)
        output.append(" ")
      else:
        output.append(char)
    return "".join(output)

  def _is_chinese_char(self, cp):
    """Checks whether CP is the codepoint of a CJK character."""
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
        (cp >= 0x3400 and cp <= 0x4DBF) or  #
        (cp >= 0x20000 and cp <= 0x2A6DF) or  #
        (cp >= 0x2A700 and cp <= 0x2B73F) or  #
        (cp >= 0x2B740 and cp <= 0x2B81F) or  #
        (cp >= 0x2B820 and cp <= 0x2CEAF) or
        (cp >= 0xF900 and cp <= 0xFAFF) or  #
        (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
      return True
    return False

  def _clean_text(self, text):
    """Performs invalid character removal and whitespace cleanup on text."""
    output = []
    for char in text:
      cp = ord(char)
      if cp == 0 or cp == 0xfffd or _is_control(char):
        continue
      if _is_whitespace(char):
        output.append(" ")
      else:
        output.append(char)
    return "".join(output)

BasicTokenizer的主要是进行unicode转换、标点符号分割、小写转换、中文字符分割、去除重音符号等操作,最后返回的是关于词的数组(中文是字的数组)

2、WordpieceTokenizer

class WordpieceTokenizer(object):
  """Runs WordPiece tokenziation."""

  def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
    self.vocab = vocab
    self.unk_token = unk_token
    self.max_input_chars_per_word = max_input_chars_per_word

  def tokenize(self, text):
    text = convert_to_unicode(text)
    output_tokens = []
    for token in whitespace_tokenize(text):
      chars = list(token)
      if len(chars) > self.max_input_chars_per_word:
        output_tokens.append(self.unk_token)
        continue
      is_bad = False
      start = 0
      sub_tokens = []
      while start < len(chars):
        end = len(chars)
        cur_substr = None
        while start < end:
          substr = "".join(chars[start:end])
          if start > 0:
            substr = "##" + substr
          if substr in self.vocab:
            cur_substr = substr
            break
          end -= 1
        if cur_substr is None:
          is_bad = True
          break
        sub_tokens.append(cur_substr)
        start = end

      if is_bad:
        output_tokens.append(self.unk_token)
      else:
        output_tokens.extend(sub_tokens)
    return output_tokens

WordpieceTokenizer的目的是将合成词分解成类似词根一样的词片。例如将"unwanted"分解成["un", "##want", "##ed"]这么做的目的是防止因为词的过于生僻没有被收录进词典最后只能以[UNK]代替的局面,因为英语当中这样的合成词非常多,词典不可能全部收录。

3、FullTokenizer

class FullTokenizer(object):
  """Runs end-to-end tokenziation."""

  def __init__(self, vocab_file, do_lower_case=True):
    self.vocab = load_vocab(vocab_file)
    self.inv_vocab = {v: k for k, v in self.vocab.items()}
    self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
    self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

  def tokenize(self, text):
    split_tokens = []
    for token in self.basic_tokenizer.tokenize(text):
      for sub_token in self.wordpiece_tokenizer.tokenize(token):
        split_tokens.append(sub_token)

    return split_tokens

  def convert_tokens_to_ids(self, tokens):
    return convert_by_vocab(self.vocab, tokens)

  def convert_ids_to_tokens(self, ids):
    return convert_by_vocab(self.inv_vocab, ids)

FullTokenizer的作用就很显而易见了,对一个文本段进行以上两种解析,最后返回词(字)的数组,同时还提供token到id的索引以及id到token的索引。这里的token可以理解为文本段处理过后的最小单元。

二、create_pretraining_data.py

1、配置

flags.DEFINE_string("input_file", None,
                    "Input raw text file (or comma-separated list of files).")
flags.DEFINE_string(
    "output_file", None,
    "Output TF example file (or comma-separated list of files).")
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
flags.DEFINE_integer("max_predictions_per_seq", 20,
                     "Maximum number of masked LM predictions per sequence.")
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
flags.DEFINE_integer(
    "dupe_factor", 10,
    "Number of times to duplicate the input data (with different masks).")
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
flags.DEFINE_float(
    "short_seq_prob", 0.1,
    "Probability of creating sequences which are shorter than the "
    "maximum length.")

配置input_file、output_file分别代表输入的源语料文件和处理过的预料文件地址;

do_lower_case:是否全部转为小写字母,是否转换成小写字母的意义在Bert系列(一)——demo运行里面已经说过了。

dupe_factor:默认重复10次,目的是可以生成不同情况的masks;

short_seq_prob:构造长度小于指定"max_seq_length"的样本比例。因为在fine-tune过程里面输入的target_seq_length是可变的(小于等于max_seq_length),那么为了防止过拟合也需要在pre-train的过程当中构造一些短的样本。

2、main入口

def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.gfile.Glob(input_pattern))

  tf.logging.info("*** Reading from input files ***")
  for input_file in input_files:
    tf.logging.info("  %s", input_file)

  rng = random.Random(FLAGS.random_seed)
  instances = create_training_instances(
      input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
      FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
      rng)

  output_files = FLAGS.output_file.split(",")
  tf.logging.info("*** Writing to output files ***")
  for output_file in output_files:
    tf.logging.info("  %s", output_file)

  write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
                                  FLAGS.max_predictions_per_seq, output_files)

从入口开始看,步骤很简单:1)构造tokenizer ;2)构造instances ;3)保存instances

3、构造instances

def create_training_instances(input_files, tokenizer, max_seq_length,
                              dupe_factor, short_seq_prob, masked_lm_prob,
                              max_predictions_per_seq, rng):
  """Create `TrainingInstance`s from raw text."""
  all_documents = [[]]
  for input_file in input_files:
    with tf.gfile.GFile(input_file, "r") as reader:
      while True:
        line = tokenization.convert_to_unicode(reader.readline())
        if not line:
          break
        line = line.strip()

        # Empty lines are used as document delimiters
        if not line:
          all_documents.append([])
        tokens = tokenizer.tokenize(line)
        if tokens:
          all_documents[-1].append(tokens)

  # Remove empty documents
  all_documents = [x for x in all_documents if x]
  rng.shuffle(all_documents)

  vocab_words = list(tokenizer.vocab.keys())
  instances = []
  for _ in range(dupe_factor):
    for document_index in range(len(all_documents)):
      instances.extend(
          create_instances_from_document(
              all_documents, document_index, max_seq_length, short_seq_prob,
              masked_lm_prob, max_predictions_per_seq, vocab_words, rng))

  rng.shuffle(instances)
  return instances

这一步是阅读数据,数据的输入文本可以是一个文件也可以是用逗号分割的若干文件;
文件里用换行来表示句子的边界,即一句一行,同理段落之间用空一行来表示段落的边界,一个段落表示成一个document;具体的构造方法在create_instances_from_document函数里面。

def create_instances_from_document(
    all_documents, document_index, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
  """Creates `TrainingInstance`s for a single document."""
  document = all_documents[document_index]

  # Account for [CLS], [SEP], [SEP]
  max_num_tokens = max_seq_length - 3
  target_seq_length = max_num_tokens
  if rng.random() < short_seq_prob:
    target_seq_length = rng.randint(2, max_num_tokens)

  instances = []
  current_chunk = []
  current_length = 0
  i = 0
  while i < len(document):
    segment = document[i]
    current_chunk.append(segment)
    current_length += len(segment)
    if i == len(document) - 1 or current_length >= target_seq_length:
      if current_chunk:
        # `a_end` is how many segments from `current_chunk` go into the `A`
        # (first) sentence.
        a_end = 1
        if len(current_chunk) >= 2:
          a_end = rng.randint(1, len(current_chunk) - 1)

        tokens_a = []
        for j in range(a_end):
          tokens_a.extend(current_chunk[j])

        tokens_b = []
        # Random next
        is_random_next = False
        if len(current_chunk) == 1 or rng.random() < 0.5:
          is_random_next = True
          target_b_length = target_seq_length - len(tokens_a)

          for _ in range(10):
            random_document_index = rng.randint(0, len(all_documents) - 1)
            if random_document_index != document_index:
              break

          random_document = all_documents[random_document_index]
          random_start = rng.randint(0, len(random_document) - 1)
          for j in range(random_start, len(random_document)):
            tokens_b.extend(random_document[j])
            if len(tokens_b) >= target_b_length:
              break
  
          num_unused_segments = len(current_chunk) - a_end
          i -= num_unused_segments
        # Actual next
        else:
          is_random_next = False
          for j in range(a_end, len(current_chunk)):
            tokens_b.extend(current_chunk[j])
        truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

        assert len(tokens_a) >= 1
        assert len(tokens_b) >= 1

        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
          tokens.append(token)
          segment_ids.append(0)

        tokens.append("[SEP]")
        segment_ids.append(0)

        for token in tokens_b:
          tokens.append(token)
          segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)

        (tokens, masked_lm_positions,
         masked_lm_labels) = create_masked_lm_predictions(
             tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
        instance = TrainingInstance(
            tokens=tokens,
            segment_ids=segment_ids,
            is_random_next=is_random_next,
            masked_lm_positions=masked_lm_positions,
            masked_lm_labels=masked_lm_labels)
        instances.append(instance)
      current_chunk = []
      current_length = 0
    i += 1
  return instances

这一段算是整个模块的核心了。

instance = TrainingInstance(
            tokens=tokens,
            segment_ids=segment_ids,
            is_random_next=is_random_next,
            masked_lm_positions=masked_lm_positions,
            masked_lm_labels=masked_lm_labels)

1)一个instance 包含一个tokens,实际上就是输入的词序列;该序列表现形式为:

[CLS] A [SEP] B [SEP]

A=[token_0, token_1, ...,token_i]
B=[token_i+1, token_i+2, ...,token_n-1]

其中:
2<= n < max_seq_length - 3 (in short_seq_prob)
n=max_seq_length - 3 (in 1-short_seq_prob)

token 最后表现形式如下图所示:


tokens示意图

segment_ids 指的形式为[0,0,0...1,1,111] 0的个数为i+1个,1的个数为max_seq_length - (i+1)
对应到模型输入就是token_type

is_random_next:其实就是上图的Label,0.5的概率为True(和当只有一个segment的时候),如果为True则B和A不属于同一document。剩下的情况为False,则B为A同一document的后续句子。

masked_lm_positions:序列里被[MASK]的位置;

masked_lm_labels:序列里被[MASK]的token

2)在create_masked_lm_predictions函数里,一个序列在指定MASK数量之后,有80%被真正MASK,10%还是保留原来token,10%被随机替换成其他token。

4、保存instance

def write_instance_to_example_files(instances, tokenizer, max_seq_length,
                                    max_predictions_per_seq, output_files):
  """Create TF example files from `TrainingInstance`s."""
  writers = []
  for output_file in output_files:
    writers.append(tf.python_io.TFRecordWriter(output_file))

  writer_index = 0

  total_written = 0
  for (inst_index, instance) in enumerate(instances):
    input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
    input_mask = [1] * len(input_ids)
    segment_ids = list(instance.segment_ids)
    assert len(input_ids) <= max_seq_length

    while len(input_ids) < max_seq_length:
      input_ids.append(0)
      input_mask.append(0)
      segment_ids.append(0)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    masked_lm_positions = list(instance.masked_lm_positions)
    masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
    masked_lm_weights = [1.0] * len(masked_lm_ids)

    while len(masked_lm_positions) < max_predictions_per_seq:
      masked_lm_positions.append(0)
      masked_lm_ids.append(0)
      masked_lm_weights.append(0.0)

    next_sentence_label = 1 if instance.is_random_next else 0

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(input_ids)
    features["input_mask"] = create_int_feature(input_mask)
    features["segment_ids"] = create_int_feature(segment_ids)
    features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
    features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
    features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
    features["next_sentence_labels"] = create_int_feature([next_sentence_label])

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))

    writers[writer_index].write(tf_example.SerializeToString())
    writer_index = (writer_index + 1) % len(writers)

    total_written += 1

    if inst_index < 20:
      tf.logging.info("*** Example ***")
      tf.logging.info("tokens: %s" % " ".join(
          [tokenization.printable_text(x) for x in instance.tokens]))

      for feature_name in features.keys():
        feature = features[feature_name]
        values = []
        if feature.int64_list.value:
          values = feature.int64_list.value
        elif feature.float_list.value:
          values = feature.float_list.value
        tf.logging.info(
            "%s: %s" % (feature_name, " ".join([str(x) for x in values])))

  for writer in writers:
    writer.close()

  tf.logging.info("Wrote %d total instances", total_written)

instance保存没什么好说的,只有两点:

while len(input_ids) < max_seq_length:
      input_ids.append(0)
      input_mask.append(0)
      segment_ids.append(0)

1)之前不是有short_seq_prob的概率导致样本的长度小于max_predictions_per_seq吗,这里把这些样本补齐,padding为0,同样的还有input_mask和segment_ids;
2) 把instance的is_random_next转化成变量next_sentence_label保存。

为了验证这个数据模块对中文输入输出的支持,我做了个测试:

python3 create_pretraining_data.py   --input_file=/tmp/zh_test.txt   --output_file=/tmp/output.txt   --vocab_file=$BERT_ZH_DIR/vocab.txt

zh_test.txt是我脸滚键盘随意输入的一些汉字,共有两段,每段两句话:

酒店附近开房的艰苦的飞机飞抵发窘惹风波,觉得覅奇偶均衡能否v不。
极度疯狂减肥的人能否打开v高科技就而后就覅哦冏结构i恶如桂萼黑人牙膏覅u我也【发票未开u俄日附件二我就佛i额外阶级感v,我为何军方的我i和服i好热哦iu均为辐9为u

ui和覅文化覅哦佛为进度覅u蛊蛾i巨乳古人规格i兼顾如果我是破看到v个ui就火热i今年的付款了几个vi哦素问。就觉发给金佛i为借口破碎的梦
i觉得覅u而非各位i风格较为哦个粉色哦i多发几个v二哥i文件哦i怪兽决斗盘可加热管覅u个人文集狗哥

vocab.txt是下载的bert中文预训练模型里的词典

最后的部分输出如下所示:

INFO:tensorflow:*** Example ***
INFO:tensorflow:tokens: [CLS] i 觉 得 [UNK] u [MASK] 非 [MASK] 位 i 风 格 较 ##by 哦 个 驅 色 哦 i 多 发 [MASK] 个 v 二 哥 i 文 件 哦 i 怪 [MASK] 决 斗 盘 可 加 热 管 [MASK] u [MASK] [MASK] 文 集 狗 哥 [SEP] [MASK] [UNK] 奇 偶 均 衡 能 否 v 不 。 极 [MASK] 疯 狂 减 肥 的 人 能 否 打 开 v 高 科 技 就 而 [MASK] 就 [UNK] 哦 冏 结 构 i 恶 如 桂 萼 黑 人 牙 膏 [UNK] u 我 也 【 发 票 未 开 [MASK] 俄 日 [MASK] 件 二 我 就 佛 i 额 [MASK] 阶 [MASK] 感 v [MASK] 我 为 [MASK] 军 方 [SEP]
INFO:tensorflow:input_ids: 101 151 6230 2533 100 163 103 7478 103 855 151 7599 3419 6772 8684 1521 702 7705 5682 1521 151 1914 1355 103 702 164 753 1520 151 3152 816 1521 151 2597 103 1104 3159 4669 1377 1217 4178 5052 103 163 103 103 3152 7415 4318 1520 102 103 100 1936 981 1772 6130 5543 1415 164 679 511 3353 103 4556 4312 1121 5503 4638 782 5543 1415 2802 2458 164 7770 4906 2825 2218 5445 103 2218 100 1521 1087 5310 3354 151 2626 1963 3424 5861 7946 782 4280 5601 100 163 2769 738 523 1355 4873 3313 2458 103 915 3189 103 816 753 2769 2218 867 151 7583 103 7348 103 2697 164 103 2769 711 103 1092 3175 102
INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
INFO:tensorflow:segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
INFO:tensorflow:masked_lm_positions: 6 8 14 17 23 34 42 44 45 46 51 63 80 105 108 116 118 121 124 0
INFO:tensorflow:masked_lm_ids: 5445 1392 711 5106 1126 1077 100 702 782 3152 2533 2428 1400 163 7353 1912 5277 8024 862 0
INFO:tensorflow:masked_lm_weights: 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 0.0
INFO:tensorflow:next_sentence_labels: 1

可以看到token序列里的中文确实是以字的形式出现的

三、run_pretraining.py

终于到预训练的执行模块了,里面大部分都是tensorflow训练的常规代码,感觉没什么好分析的。

看过前面的内容和我前两章内容的朋友我想已经初步知道预训练的整个逻辑了,这里作一个简单的介绍:

1、X和Y的确定

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]
    next_sentence_labels = features["next_sentence_labels"]
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

其中input_ids、input_mask 、segment_ids 作为X,剩下的masked_lm_positions、masked_lm_ids 、masked_lm_weights 、next_sentence_labels 共同作为Y

2、 loss

    (masked_lm_loss,
     masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
         bert_config, model.get_sequence_output(), model.get_embedding_table(),
         masked_lm_positions, masked_lm_ids, masked_lm_weights)

    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_log_probs) = get_next_sentence_output(
         bert_config, model.get_pooled_output(), next_sentence_labels)

    total_loss = masked_lm_loss + next_sentence_loss

可以看到loss 分别由masked_lm_loss和next_sentence_loss组成,masked_lm_loss针对的是语言模型对MASK起来的标签的预测,即上下文语境预测当前词;而next_sentence_loss是对于句子关系的预测。前者在迁移学习中可以用于标注类任务(分词、NER等),后者可以用于句子关系任务(QA、自然语言推理等)。

需要多说一句的是,masked_lm_loss,用到了模型的sequence_output和embedding_table,这是因为对多个MASK的标签进行预测是一个标注问题,所以需要获取最后一层的整个sequence,而embedding_table用来反embedding,这样就映射到token的学习了。而next_sentence_loss用到的是pooled_output,对应的是第一个token [CLS],它一般用于分类任务的学习。

总结:

本文介绍了以下几个内容:

1、tokenization模块:我把它叫做对原始文本段的解析,只有解析过后才能标准化输入;

2、create_pretraining_data模块:对原始数据进行转换,原始数据本是无标签的数据,通过句子的拼接可以产生句子关系的标签,通过MASK可以产生标注的标签,其本质是语言模型的应用;

3、run_pretraining模块:在执行预训练的时候针对以上两种标签分别利用bert模型的不同输出部件,计算loss,然后进行梯度下降优化。

本文系列
Bert系列(一)——demo运行
Bert系列(二)——模型主体源码解读
Bert系列(四)——源码解读之Fine-tune
Bert系列(五)——中文分词实践 F1 97.8%(附代码)

Reference
1.https://github.com/google-research/bert
2.BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 200,176评论 5 469
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 84,190评论 2 377
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 147,232评论 0 332
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 53,953评论 1 272
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 62,879评论 5 360
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,177评论 1 277
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,626评论 3 390
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,295评论 0 254
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,436评论 1 294
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,365评论 2 317
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,414评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,096评论 3 315
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,685评论 3 303
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,771评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,987评论 1 255
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,438评论 2 346
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,032评论 2 341

推荐阅读更多精彩内容