2019-01-16 解析bert代码

代码文件为bert_lstm_ner.py,下面进行逐行解析:

tf.logging.set_verbosity(tf.logging.INFO)#运行代码时,将会看到info日志输出INFO:tensorflow:loss = 1.18812, step = 1INFO:tensorflow:loss = #0.210323, step = 101INFO:tensorflow:loss = 0.109025, step = 201

processors = {

        "ner": NerProcessor

    }#定义一个ner:NerProcessor的字典

bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)#将bert参数传到bert_config中

if FLAGS.max_seq_length > bert_config.max_position_embeddings:#假如最大总输入序列长度大于bert最大的wordembedding长度,报错

    raise ValueError(

        "Cannot use sequence length %d because the BERT model "

        "was only trained up to sequence length %d" %

        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

# 在train 的时候,才删除上一轮产出的文件,在predicted 的时候不做clean

if FLAGS.clean and FLAGS.do_train:#默认是两个ture

    if os.path.exists(FLAGS.output_dir):#假如输出文件位置存在

        def del_file(path):#设置个删文件的函数

            ls = os.listdir(path)#listdir函数返回文件夹中的所有文件名字

            for i in ls:

                c_path = os.path.join(path, i)#os.path.join()函数用于路径拼接文件路径

                if os.path.isdir(c_path):#如果该文件存在

                    del_file(c_path)#删除文件

                else:

                    os.remove(c_path)#删除文件

        try:

            del_file(FLAGS.output_dir)#尝试删除文件,否则报错

        except Exception as e:

            print(e)

            print('pleace remove the files of output dir and data.conf')

            exit(-1)

    if os.path.exists(FLAGS.data_config_path):#如果保存数据的位置存在

        try:

            os.remove(FLAGS.data_config_path)#尝试删除

        except Exception as e:

            print(e)

            print('pleace remove the files of output dir and data.conf')

            exit(-1)

task_name = FLAGS.task_name.lower()#task_name是要训练的任务的名称,值为ner

if task_name not in processors:#如果processor里面没有ner,报错

    raise ValueError("Task not found: %s" % (task_name))

processor = processors[task_name]()#返回NerProcessor()函数

label_list = processor.get_labels()#label_list值为["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]

tokenizer = tokenization.FullTokenizer(

    vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)#输出函数(对bert的词汇文件在进行变小写后进行fulltokenizer)

tpu_cluster_resolver = None#不使用tpu集群

if FLAGS.use_tpu and FLAGS.tpu_name:#不考虑

    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(

        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2#如果为PER_HOST_V1或PER_HOST_V2,则在每个主机上调用一次input_fn。 #使用每核心输入管道配置,每个核心调用一次。 具有全局批量大小

run_config = tf.contrib.tpu.RunConfig(#定义tpu函数

    cluster=tpu_cluster_resolver,#false

    master=FLAGS.master,#none‘TensorFlow master URL.’

    model_dir=FLAGS.output_dir,#输出位置

    save_checkpoints_steps=FLAGS.save_checkpoints_steps,#" 保存模型checkpoint的频率."为1000

    tpu_config=tf.contrib.tpu.TPUConfig(#定义tpu函数2

        iterations_per_loop=FLAGS.iterations_per_loop,#"在每个评估单元调用中要执行多少步骤."1000

        num_shards=FLAGS.num_tpu_cores,#tpu核数,8

        per_host_input_for_training=is_per_host))#PER_HOST_V2

train_examples = None#none

num_train_steps = None#none

num_warmup_steps = None#none

if os.path.exists(FLAGS.data_config_path):#如果data config 文件,保存训练和dev config存在

    with codecs.open(FLAGS.data_config_path) as fd:#打开文件路径

        data_config = json.load(fd)#加载数据到data_config中

else:

    data_config = {}#否则设为空

if FLAGS.do_train:

        # 加载训练数据

    if len(data_config) == 0:#如果为空

        train_examples = processor.get_train_examples(FLAGS.data_dir)#将训练样本输入到变量中

        num_train_steps = int(

            len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)#训练执行总批次数为样本长度/训练总批次*训练总次数

        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)#上面数值*进行线性学习率热身训练的比例。

        data_config['num_train_steps'] = num_train_steps#数据参数设定1

        data_config['num_warmup_steps'] = num_warmup_steps#数据参数设定2

        data_config['num_train_size'] = len(train_examples)#数据参数设定3(数据长度)

    else:

        num_train_steps = int(data_config['num_train_steps'])#直接调用1

        num_warmup_steps = int(data_config['num_warmup_steps'])#直接调用2

    # 返回的model_dn 是一个函数,其定义了模型,训练,评测方法,并且使用钩子参数,加载了BERT模型的参数进行了自己模型的参数初始化过程

    # tf 新的架构方法,通过定义model_fn 函数,定义模型,然后通过EstimatorAPI进行模型的其他工作,Es就可以控制模型的训练,预测,评估工作等。

model_fn = model_fn_builder(

    bert_config=bert_config,#从bert文件中获得

    num_labels=len(label_list) + 1,#标签数量

    init_checkpoint=FLAGS.init_checkpoint,#r'D:\bert\chinese_L-12_H-768_A-12\bert_model.ckpt "初始检查点(通常来自预先训练的bert模型)."

    learning_rate=FLAGS.learning_rate,#学习率 5e-5,

    num_train_steps=num_train_steps,#总批次

    num_warmup_steps=num_warmup_steps,#warmup数

#warmup就是先采用小的学习率(0.01)进行训练,训练了400iterations之后将学习率调整至0.1开始正式训练

    use_tpu=FLAGS.use_tpu,#none

    use_one_hot_embeddings=FLAGS.use_tpu)#none

print(model_fn)

estimator = tf.contrib.tpu.TPUEstimator(#定义评估器

    use_tpu=FLAGS.use_tpu,#none

    model_fn=model_fn,#将上面定义的model加入

    config=run_config,#将上面定义的runconfig参数加入

    train_batch_size=FLAGS.train_batch_size,#训练批次 64

    eval_batch_size=FLAGS.eval_batch_size,#评估批次 8

    predict_batch_size=FLAGS.predict_batch_size)# 预测批次 8

train_file =r'C:\Users\dell\Desktop\Name-Entity-Recognition-master\BERT-BiLSTM-CRF-NER\train.tf_record'

filed_based_convert_examples_to_features(

    train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)#将数据转化为TF_Record 结构,作为模型数据输入:样本,标签,最#大长度,tokenizer,数据

num_train_size = num_train_size = int(data_config['num_train_size'])

tf.logging.info("***** Running training *****")

tf.logging.info("  Num examples = %d", num_train_size)#20864

tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)#64

tf.logging.info("  Num steps = %d", num_train_steps)#978

train_input_fn = file_based_input_fn_builder(

    input_file=train_file,#训练文件

    seq_length=FLAGS.max_seq_length,#最大序列长度 128

    is_training=True,#确定训练

    drop_remainder=True)#没查到。。。

estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)#进行训练

if FLAGS.do_eval:#进行评估

    if data_config.get('eval.tf_record_path', '') == '':#如果字典中没有评估路径

        eval_examples = processor.get_dev_examples(FLAGS.data_dir)#读到data_dir的dev文件

        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")#获得输出位置的eval.tf_record文件

        filed_based_convert_examples_to_features(

            eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)#将评估文件转换

        data_config['eval.tf_record_path'] = eval_file#将评估文件加入数据

        data_config['num_eval_size'] = len(eval_examples)#将评估文件长度加入数据

    else:

        eval_file = data_config['eval.tf_record_path']#将评估数据文件读出

        # 打印验证集数据信息

    num_eval_size = data_config.get('num_eval_size', 0)#将评估文件长度读出

    tf.logging.info("***** Running evaluation *****")

    tf.logging.info("  Num examples = %d", num_eval_size)#2318

    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)#8

    eval_steps = None

    if FLAGS.use_tpu:#none

        eval_steps = int(num_eval_size / FLAGS.eval_batch_size)#不管

    eval_drop_remainder = True if FLAGS.use_tpu else False#false

    eval_input_fn = file_based_input_fn_builder(

        input_file=eval_file,#评估文件

        seq_length=FLAGS.max_seq_length,#最大序列长度

        is_training=False,#不训练

        drop_remainder=eval_drop_remainder)#none

    result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)#step=none(这里报错)

    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")#输出文件

    with codecs.open(output_eval_file, "w", encoding='utf-8') as writer:

        tf.logging.info("***** Eval results *****")

        for key in sorted(result.keys()):

            tf.logging.info("  %s = %s", key, str(result[key]))#报出文件

            writer.write("%s = %s\n" % (key, str(result[key])))#写入文件

# 保存数据的配置文件,避免在以后的训练过程中多次读取训练以及测试数据集,消耗时间

if not os.path.exists(FLAGS.data_config_path):

    with codecs.open(FLAGS.data_config_path, 'a', encoding='utf-8') as fd:

        json.dump(data_config, fd)#把a作为data_config_path存入data_config

if FLAGS.do_predict:#开始预测

    token_path = os.path.join(FLAGS.output_dir, "token_test.txt")#导入测试集输出位置

    if os.path.exists(token_path):#如果测试集存在

        os.remove(token_path)#删了

    with codecs.open(os.path.join(FLAGS.output_dir, 'label2id.pkl'), 'rb') as rf:#打开label2id的文件

        label2id = pickle.load(rf)

        id2label = {value: key for key, value in label2id.items()}#转成字典

    predict_examples = processor.get_test_examples(FLAGS.data_dir)#得到test文件

    predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")#得到预测的tf_record文件

    filed_based_convert_examples_to_features(predict_examples, label_list,

                                                FLAGS.max_seq_length, tokenizer,

                                                predict_file, mode="test")#建立测试的tf_record文件

    tf.logging.info("***** Running prediction*****")

    tf.logging.info("  Num examples = %d", len(predict_examples))#4636

    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)#8

    if FLAGS.use_tpu:

            # Warning: According to tpu_estimator.py Prediction on TPU is an

            # experimental feature and hence not supported here

            raise ValueError("Prediction in TPU not supported")

    predict_drop_remainder = True if FLAGS.use_tpu else False#false

    predict_input_fn = file_based_input_fn_builder(

        input_file=predict_file,#输入文件

        seq_length=FLAGS.max_seq_length,#最大序列

        is_training=False,#不训练

        drop_remainder=predict_drop_remainder)#none

    predicted_result = estimator.evaluate(input_fn=predict_input_fn)#报错。。。

    output_eval_file = os.path.join(FLAGS.output_dir, "predicted_results.txt")#输出预测结果

    with codecs.open(output_eval_file, "w", encoding='utf-8') as writer:

        tf.logging.info("***** Predict results *****")

        for key in sorted(predicted_result.keys()):

            tf.logging.info("  %s = %s", key, str(predicted_result[key]))

            writer.write("%s = %s\n" % (key, str(predicted_result[key])))#写入文件

    result = estimator.predict(input_fn=predict_input_fn)#预测

    output_predict_file = os.path.join(FLAGS.output_dir, "label_test.txt")#输出文件

    def result_to_pair(writer):#这里是写入函数

        for predict_line, prediction in zip(predict_examples, result):

            idx = 0

            line = ''

            line_token = str(predict_line.text).split(' ')

            label_token = str(predict_line.label).split(' ')

            if len(line_token) != len(label_token):

                tf.logging.info(predict_line.text)

                tf.logging.info(predict_line.label)

            for id in prediction:

                if id == 0:

                    continue

                curr_labels = id2label[id]

                if curr_labels in ['[CLS]', '[SEP]']:

                    continue

                    # 不知道为什么,这里会出现idx out of range 的错误。。。do not know why here cache list out of range exception!

                try:

                    line += line_token[idx] + ' ' + label_token[idx] + ' ' + curr_labels + '\n'

                except Exception as e:

                    tf.logging.info(e)

                    tf.logging.info(predict_line.text)

                    tf.logging.info(predict_line.label)

                    line = ''

                    break

                idx += 1

            writer.write(line + '\n')

    with codecs.open(output_predict_file, 'w', encoding='utf-8') as writer:

        result_to_pair(writer)#写入文件

    from conlleval import return_report

    eval_result = return_report(output_predict_file)#百度找不到,猜测是得到评估结果的函数

    print(eval_result)

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,214评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,307评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,543评论 0 341
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,221评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,224评论 5 371
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,007评论 1 284
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,313评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,956评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,441评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,925评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,018评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,685评论 4 322
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,234评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,240评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,464评论 1 261
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,467评论 2 352
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,762评论 2 345