Contents:
- Metric results
  - Metrics
  - Line charts
- Summary
- Approach
1. Metric results
Dataset 1: power grid data
teacher-dev
precision recall f1-score support
accuracy 0.9138 58
macro avg 0.9004 0.8791 0.8827 58
weighted avg 0.9206 0.9138 0.9111 58
student-dev
bert-layer1 (approx. 97 MB)
epoch 19
precision recall f1-score support
accuracy 0.8103 58
macro avg 0.6136 0.6035 0.5917 58
weighted avg 0.8448 0.8103 0.8107 58
epoch 29
precision recall f1-score support
accuracy 0.8448 58
macro avg 0.6305 0.6175 0.6118 58
weighted avg 0.8661 0.8448 0.8468 58
epoch 39
precision recall f1-score support
accuracy 0.8966 58
macro avg 0.7522 0.6769 0.6964 58
weighted avg 0.9321 0.8966 0.9060 58
epoch 49
precision recall f1-score support
accuracy 0.8276 58
macro avg 0.6204 0.6074 0.6011 58
weighted avg 0.8508 0.8276 0.8296 58
epoch 59
precision recall f1-score support
accuracy 0.8448 58
macro avg 0.6277 0.6144 0.6098 58
weighted avg 0.8590 0.8448 0.8449 58
bert-layer3 (approx. 150 MB)
epoch 9
precision recall f1-score support
accuracy 0.8448 58
macro avg 0.8619 0.8411 0.8358 58
weighted avg 0.8542 0.8448 0.8395 58
epoch 19
precision recall f1-score support
accuracy 0.9310 58
macro avg 0.9198 0.9013 0.8969 58
weighted avg 0.9387 0.9310 0.9299 58
epoch 29
precision recall f1-score support
accuracy 0.9138 58
macro avg 0.8984 0.8791 0.8735 58
weighted avg 0.9196 0.9138 0.9104 58
epoch 39
precision recall f1-score support
accuracy 0.9138 58
macro avg 0.8984 0.8791 0.8735 58
weighted avg 0.9196 0.9138 0.9104 58
epoch 49
precision recall f1-score support
accuracy 0.9138 58
macro avg 0.8984 0.8791 0.8735 58
weighted avg 0.9196 0.9138 0.9104 58
epoch 59
precision recall f1-score support
accuracy 0.9138 58
macro avg 0.8984 0.8791 0.8735 58
weighted avg 0.9196 0.9138 0.9104 58
bert-layer3
optimizer changed from torch.optim.SGD(student_model.parameters(), lr=0.05)
to AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
epoch 9
precision recall f1-score support
accuracy 0.8103 58
macro avg 0.7917 0.8240 0.7970 58
weighted avg 0.8203 0.8103 0.8010 58
epoch 19
precision recall f1-score support
accuracy 0.8621 58
macro avg 0.8899 0.8443 0.8560 58
weighted avg 0.8641 0.8621 0.8585 58
epoch 29
precision recall f1-score support
accuracy 0.8793 58
macro avg 0.9019 0.8529 0.8658 58
weighted avg 0.8816 0.8793 0.8752 58
epoch 39
precision recall f1-score support
accuracy 0.8793 58
macro avg 0.8814 0.8529 0.8537 58
weighted avg 0.8834 0.8793 0.8755 58
epoch 49
precision recall f1-score support
accuracy 0.8621 58
macro avg 0.8252 0.8443 0.8304 58
weighted avg 0.8648 0.8621 0.8601 58
epoch 59
precision recall f1-score support
accuracy 0.8448 58
macro avg 0.8058 0.8358 0.8141 58
weighted avg 0.8573 0.8448 0.8461 58
Dataset 2: milk powder data for a certain brand
teacher-dev
precision recall f1-score support
测评 0.7800 0.7800 0.7800 50
种草 0.8856 0.8754 0.8805 345
科普 0.6636 0.6887 0.6759 106
accuracy 0.8263 501
macro avg 0.7764 0.7813 0.7788 501
weighted avg 0.8281 0.8263 0.8272 501
student-dev
bert-layer-1
epoch 9
precision recall f1-score support
accuracy 0.6886 501
macro avg 0.2295 0.3333 0.2719 501
weighted avg 0.4742 0.6886 0.5616 501
epoch 19
precision recall f1-score support
accuracy 0.7265 501
macro avg 0.6689 0.6261 0.6331 501
weighted avg 0.7444 0.7265 0.7295 501
epoch 29
precision recall f1-score support
accuracy 0.7106 501
macro avg 0.6198 0.5992 0.6067 501
weighted avg 0.7152 0.7106 0.7118 501
epoch 38
precision recall f1-score support
accuracy 0.7146 501
macro avg 0.6200 0.5898 0.6007 501
weighted avg 0.7161 0.7146 0.7139 501
epoch 49
precision recall f1-score support
accuracy 0.7146 501
macro avg 0.6451 0.6156 0.6243 501
weighted avg 0.7269 0.7146 0.7182 501
bert-layer3
epoch 9
precision recall f1-score support
accuracy 0.7605 501
macro avg 0.7092 0.6881 0.6920 501
weighted avg 0.7782 0.7605 0.7659 501
epoch 19
precision recall f1-score support
accuracy 0.7764 501
macro avg 0.7030 0.7069 0.7049 501
weighted avg 0.7787 0.7764 0.7775 501
epoch 29
precision recall f1-score support
accuracy 0.7764 501
macro avg 0.6991 0.6977 0.6981 501
weighted avg 0.7791 0.7764 0.7776 501
bert-layer3
optimizer changed from torch.optim.SGD(student_model.parameters(), lr=0.05)
to AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
epoch 9
precision recall f1-score support
accuracy 0.7685 501
macro avg 0.6756 0.6662 0.6690 501
weighted avg 0.7734 0.7685 0.7701 501
epoch 19
precision recall f1-score support
accuracy 0.8064 501
macro avg 0.7360 0.7261 0.7295 501
weighted avg 0.8106 0.8064 0.8078 501
epoch 29
precision recall f1-score support
accuracy 0.7784 501
macro avg 0.7089 0.6968 0.6998 501
weighted avg 0.7870 0.7784 0.7812 501
(Line charts: student performance on the milk powder data, for bert-layer-1, bert-layer-3 (SGD), and bert-layer-3 (AdamW).)
2. Summary
- The 1-layer student is not as good as the 3-layer one (unsurprisingly).
- SGD vs. AdamW: no particularly obvious difference so far, so for now I treat it as a hyperparameter-tuning issue. Update: judging from the charts, SGD seems somewhat more stable; training collapses easily, with the metrics dropping sharply in later epochs.
- The weighting between the two losses and the Temperature also feel quite arbitrary; there is room for more tuning.
- The student side is not yet fully polished.
- If the teacher is strong, even a 1-layer student performs reasonably well; if the teacher is not ideal, a structurally weak student suffers more.
3. Approach
For the teacher model, my code returns: return loss, dense_2_with_softmax, dense_2_output, where dense_2_output is the logits; this is used later.
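As a rough, hypothetical illustration of where those three return values could come from (the real head is not shown here; the names dense_2, pooled, and the loss computation are assumptions), a classification head of roughly this shape would produce them:
import torch.nn as nn
import torch.nn.functional as F

class TeacherHead(nn.Module):
    # Hypothetical sketch of the classification head described above.
    def __init__(self, hidden_size, label_num):
        super().__init__()
        self.dense_2 = nn.Linear(hidden_size, label_num)

    def forward(self, pooled, label_ids=None):
        dense_2_output = self.dense_2(pooled)                     # logits
        dense_2_with_softmax = F.softmax(dense_2_output, dim=-1)  # class probabilities
        loss = F.cross_entropy(dense_2_output, label_ids) if label_ids is not None else None
        return loss, dense_2_with_softmax, dense_2_output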
The student model has essentially the same structure as the teacher, but with different settings in bert_config: here num_attention_heads is set to 3, and num_hidden_layers is set to 1 and 3 in separate runs.
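As a minimal sketch of that configuration change (assuming a HuggingFace-style BertConfig; the actual values live in the config behind args.bert_model_toy):
from transformers import BertConfig, BertModel

# Hypothetical illustration: same architecture family as the teacher, but a much smaller encoder.
student_config = BertConfig(
    num_attention_heads=3,  # reduced head count, as described above
    num_hidden_layers=1,    # 1 and 3 were both tried
)
student_encoder = BertModel(student_config)
print(sum(p.numel() for p in student_encoder.parameters()))  # rough parameter count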
The useful parts of the training code are as follows:
# Load the trained teacher checkpoint and build the (smaller) student.
model = torch.load("/data/static_MODEL/event_extract/sentence_classify_daneng_teacher/fold_0/model_epoch_23_p_1.0000_r_1.0000_f_1.0000.pt")
model.eval()  # the teacher is only used for inference
student_model = Student(args.bert_model_toy, args.label_num)
Here model is the teacher, loaded directly from the trained checkpoint, which is why it is put into eval() mode.
Two optimizer setups were tried:
# Option 1: the same setup as used for the teacher model (AdamW with weight-decay groups).
param_optimizer = list(student_model.named_parameters())
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay': args.weight_decay},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
warmup_steps = int(args.warmup_proportion * num_train_optimization_steps)
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

# Option 2: plain SGD, copied from an online example.
optimizer = torch.optim.SGD(student_model.parameters(), lr=0.05)
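warmup_steps is computed above but not used in the snippet; presumably it feeds a warmup scheduler. A hedged sketch of what that could look like with the transformers helper (this part is not in the original code):
from transformers import get_linear_schedule_with_warmup

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_train_optimization_steps,
)
# scheduler.step() would then be called after each optimizer.step().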
During training in do_train, the following runs for each batch:
with torch.no_grad():
    # Teacher forward pass: no gradients are needed for the frozen teacher.
    _, teacher_output_with_softmax, teacher_output = model(input_ids, segment_ids, input_mask, label_ids)
# Student forward pass (gradients flow through this one).
student_output, student_output_with_softmax = student_model(input_ids, segment_ids, input_mask, label_ids)
student_output and teacher_output are used below: in essence, the student learns to fit the teacher's output distribution. The formulation most commonly seen in papers is:
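A sketch of that standard objective, written so that it matches the code below (z_s: student logits, z_t: teacher logits, y: gold labels, T: temperature, alpha: the weight between the two loss terms):
L = alpha * T^2 * KL( softmax(z_t / T) || softmax(z_s / T) ) + (1 - alpha) * CE(z_s, y)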
In the current experiment I copied the following code:
def distillation(y, teacher_scores, labels, T, alpha):
    # Soft targets: student log-probs vs. teacher probs, both softened by temperature T.
    p = F.log_softmax(y / T, dim=1)
    q = F.softmax(teacher_scores / T, dim=1)
    # KL term, scaled by T^2 (standard in distillation) and averaged over the batch.
    l_kl = F.kl_div(p, q, reduction='sum') * (T ** 2) / y.shape[0]
    # Hard-label cross entropy on the raw student logits.
    l_ce = F.cross_entropy(y, labels)
    return l_kl * alpha + l_ce * (1. - alpha)

# Call: student/teacher logits from the forward passes above, with alpha = 0.2.
loss = distillation(student_output, teacher_output, label_ids.long(), T, 0.2)
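For completeness, a minimal sketch of how the rest of each training step presumably proceeds (standard PyTorch; this part is not shown in the original snippet):
optimizer.zero_grad()
loss.backward()
optimizer.step()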