bert蒸馏初探

目录:

  1. 指标结果
    • 指标
    • 折线图
  2. 小结
  3. 方案

1. 指标结果

数据一:电网数据

teacher-dev

                 precision    recall  f1-score   support

       accuracy                         0.9138        58
      macro avg     0.9004    0.8791    0.8827        58
   weighted avg     0.9206    0.9138    0.9111        58

student-dev

bert-layer1(约97Mb)
epoch-19
                 precision    recall  f1-score   support

       accuracy                         0.8103        58
      macro avg     0.6136    0.6035    0.5917        58
   weighted avg     0.8448    0.8103    0.8107        58

epoch-29
                 precision    recall  f1-score   support

       accuracy                         0.8448        58
      macro avg     0.6305    0.6175    0.6118        58
   weighted avg     0.8661    0.8448    0.8468        58

epoch-39
                 precision    recall  f1-score   support

       accuracy                         0.8966        58
      macro avg     0.7522    0.6769    0.6964        58
   weighted avg     0.9321    0.8966    0.9060        58

epoch-49
                 precision    recall  f1-score   support

       accuracy                         0.8276        58
      macro avg     0.6204    0.6074    0.6011        58
   weighted avg     0.8508    0.8276    0.8296        58

epoch-59
                 precision    recall  f1-score   support

       accuracy                         0.8448        58
      macro avg     0.6277    0.6144    0.6098        58
   weighted avg     0.8590    0.8448    0.8449        58

bert-layer3(约150Mb)
epoch  9
                 precision    recall  f1-score   support

       accuracy                         0.8448        58
      macro avg     0.8619    0.8411    0.8358        58
   weighted avg     0.8542    0.8448    0.8395        58

epoch  19
                 precision    recall  f1-score   support

       accuracy                         0.9310        58
      macro avg     0.9198    0.9013    0.8969        58
   weighted avg     0.9387    0.9310    0.9299        58

epoch  29
                 precision    recall  f1-score   support

       accuracy                         0.9138        58
      macro avg     0.8984    0.8791    0.8735        58
   weighted avg     0.9196    0.9138    0.9104        58

epoch  39
                 precision    recall  f1-score   support

       accuracy                         0.9138        58
      macro avg     0.8984    0.8791    0.8735        58
   weighted avg     0.9196    0.9138    0.9104        58

epoch  49
                 precision    recall  f1-score   support

       accuracy                         0.9138        58
      macro avg     0.8984    0.8791    0.8735        58
   weighted avg     0.9196    0.9138    0.9104        58

epoch  59
                 precision    recall  f1-score   support

       accuracy                         0.9138        58
      macro avg     0.8984    0.8791    0.8735        58
   weighted avg     0.9196    0.9138    0.9104        58
bert-layer3

修改optimizer为:torch.optim.SGD(student_model.parameters(), lr=0.05)AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

epoch  9
                 precision    recall  f1-score   support

       accuracy                         0.8103        58
      macro avg     0.7917    0.8240    0.7970        58
   weighted avg     0.8203    0.8103    0.8010        58

epoch  19
                 precision    recall  f1-score   support

       accuracy                         0.8621        58
      macro avg     0.8899    0.8443    0.8560        58
   weighted avg     0.8641    0.8621    0.8585        58

epoch  29
                 precision    recall  f1-score   support

       accuracy                         0.8793        58
      macro avg     0.9019    0.8529    0.8658        58
   weighted avg     0.8816    0.8793    0.8752        58

epoch  39
                 precision    recall  f1-score   support

       accuracy                         0.8793        58
      macro avg     0.8814    0.8529    0.8537        58
   weighted avg     0.8834    0.8793    0.8755        58

epoch  49
                 precision    recall  f1-score   support

       accuracy                         0.8621        58
      macro avg     0.8252    0.8443    0.8304        58
   weighted avg     0.8648    0.8621    0.8601        58

epoch  59
                 precision    recall  f1-score   support

       accuracy                         0.8448        58
      macro avg     0.8058    0.8358    0.8141        58
   weighted avg     0.8573    0.8448    0.8461        58


数据二:某品牌奶粉数据

teacher-dev

              precision    recall  f1-score   support

          测评     0.7800    0.7800    0.7800        50
          种草     0.8856    0.8754    0.8805       345
          科普     0.6636    0.6887    0.6759       106

    accuracy                         0.8263       501
   macro avg     0.7764    0.7813    0.7788       501
weighted avg     0.8281    0.8263    0.8272       501

student-dev

bert-layer-1
epoch 9
              precision    recall  f1-score   support

    accuracy                         0.6886       501
   macro avg     0.2295    0.3333    0.2719       501
weighted avg     0.4742    0.6886    0.5616       501

epoch 19
              precision    recall  f1-score   support

    accuracy                         0.7265       501
   macro avg     0.6689    0.6261    0.6331       501
weighted avg     0.7444    0.7265    0.7295       501

epoch 29
              precision    recall  f1-score   support

    accuracy                         0.7106       501
   macro avg     0.6198    0.5992    0.6067       501
weighted avg     0.7152    0.7106    0.7118       501

epoch 38
              precision    recall  f1-score   support

    accuracy                         0.7146       501
   macro avg     0.6200    0.5898    0.6007       501
weighted avg     0.7161    0.7146    0.7139       501

epoch 49
              precision    recall  f1-score   support

    accuracy                         0.7146       501
   macro avg     0.6451    0.6156    0.6243       501
weighted avg     0.7269    0.7146    0.7182       501
bert-layer3
epoch 9
              precision    recall  f1-score   support

    accuracy                         0.7605       501
   macro avg     0.7092    0.6881    0.6920       501
weighted avg     0.7782    0.7605    0.7659       501

epoch 19
              precision    recall  f1-score   support

    accuracy                         0.7764       501
   macro avg     0.7030    0.7069    0.7049       501
weighted avg     0.7787    0.7764    0.7775       501

epoch 29
              precision    recall  f1-score   support

    accuracy                         0.7764       501
   macro avg     0.6991    0.6977    0.6981       501
weighted avg     0.7791    0.7764    0.7776       501
bert-layer3

修改optimizer为:torch.optim.SGD(student_model.parameters(), lr=0.05)AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

epoch 9
              precision    recall  f1-score   support

    accuracy                         0.7685       501
   macro avg     0.6756    0.6662    0.6690       501
weighted avg     0.7734    0.7685    0.7701       501

epoch 19
              precision    recall  f1-score   support

    accuracy                         0.8064       501
   macro avg     0.7360    0.7261    0.7295       501
weighted avg     0.8106    0.8064    0.8078       501

epoch 29
              precision    recall  f1-score   support

    accuracy                         0.7784       501
   macro avg     0.7089    0.6968    0.6998       501
weighted avg     0.7870    0.7784    0.7812       501

奶粉数据上的student表现:

bert-layer-1

image.png

bert-layer-3(SGD)

image.png

bert-layer-3(Adamw)

image.png

2. 小结

  • 1层的layer没有3层的好使(废话
  • SGDAdamW没感觉到特别特别明显差异,先当作炼丹问题 | update:看下图的话感觉SGD相对更稳定一些
  • 训练过程很容易崩掉啊,后面降得跟啥似的
  • 两个loss得权重比例和Temperature感觉取值也很玄学,可炼
  • 目前student部分不够完善
  • 如果teacher结果好,1层的student表现还行;如果teacher表现不是非常理想,那student如果结构弱也比较吃亏

3. 方案

对于teacher模型,我在代码中返回的是return loss, dense_2_with_softmax, dense_2_output,其中dense_2_output即为logits,这个后面会用到。

对于student模型,与teacher的模型结构基本上完全一样,但是在bert_config里面有不同的设置,我在这里将num_attention_heads设置为3,将num_hidden_layers分别设置成1和3进行了尝试。

训练部分有用部分如下:

model = torch.load("/data/static_MODEL/event_extract/sentence_classify_daneng_teacher/fold_0/model_epoch_23_p_1.0000_r_1.0000_f_1.0000.pt")
student_model = Student(args.bert_model_toy, args.label_num)

这里model即为teacher,是直接从训练好的模型加载的,故设为*.eval()

optimizer也是用了两种进行尝试,分别是

# 第一种方法:teacher model即用的这个
param_optimizer = list(student_model.named_parameters())
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
warmup_steps = int(args.warmup_proportion * num_train_optimization_steps)
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

# 第二种方法:从网上直接粘贴过来的
optimizer = torch.optim.SGD(student_model.parameters(), lr=0.05)

do_train的训练过程中,对于每个batch数据,进行:

with torch.no_grad():
    _, teacher_output_with_softmax, teacher_output = model(input_ids, segment_ids, input_mask, label_ids)
student_output, student_output_with_softmax = student_model(input_ids, segment_ids, input_mask, label_ids)

后面会用到student_outputteacher_output,实际上就是student去学习teacher的分布,对于论文比较常见的是:

image.png

在当前的实验中是摘抄了这段代码:

def distillation(y, teacher_scores, labels, T, alpha):
    p = F.log_softmax(y/T, dim=1)
    q = F.softmax(teacher_scores/T, dim=1)
    l_kl = F.kl_div(p, q, size_average=False) * (T**2) / y.shape[0]
    l_ce = F.cross_entropy(y, labels)

    return l_kl * alpha + l_ce * (1. - alpha)

# 调用
loss = distillation(student_output, teacher_output, label_ids.long(), T, 0.2)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,684评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,143评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,214评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,788评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,796评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,665评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,027评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,679评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 41,346评论 1 299
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,664评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,766评论 1 331
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,412评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,015评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,974评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,203评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,073评论 2 350
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,501评论 2 343

推荐阅读更多精彩内容