论文 FLAT: Chinese NER Using Flat-Lattice Transformer(ACL 2020)
在前两篇中,我们对FLAT模型的输入和网络结构的关键代码进行了解读。本篇我们分析模型的输出以及评价指标。
在上一篇模型介绍中,我们得知,模型的输出pred
会送给CRF层计算loss,即:
pred = self.output(encoded)
mask = seq_len_to_mask(seq_len).bool()
if self.training:
loss = self.crf(pred, target, mask).mean(dim=0)
return {'loss': loss}
else:
# 作者将scores命名为path, 应当为笔误,这里改过来
pred, scores = self.crf.viterbi_decode(pred, mask)
result = {'pred': pred}
return result
这里self.crf()
的具体代码如下:
self.crf = get_crf_zero_init(self.label_size)
def get_crf_zero_init(label_size, include_start_end_trans=False,
allowed_transitions=None, initial_method=None):
import torch.nn as nn
from fastNLP.modules import ConditionalRandomField
crf = ConditionalRandomField(label_size, include_start_end_trans)
crf.trans_m = nn.Parameter(torch.zeros(size=[label_size, label_size], requires_grad=True))
if crf.include_start_end_trans:
crf.start_scores = nn.Parameter(torch.zeros(size=[label_size], requires_grad=True))
crf.end_scores = nn.Parameter(torch.zeros(size=[label_size], requires_grad=True))
return crf
可以发现,这里调用了FastNLP
工具包里的ConditionalRandomField
类,该类提供了forward()以及viterbi_decode()两个方法,分别用于train和inference。
最后可以看到,作者采用的评价指标为:
f1_metric = SpanFPreRecMetric(vocabs['label'], pred='pred', target='target',
seq_len='seq_len', encoding_type=encoding_type)
acc_metric = AccuracyMetric(pred='pred', target='target', seq_len='seq_len')
acc_metric.set_metric_name('label_acc')
metrics = [
f1_metric,
acc_metric
]
这里的SpanFPreRecMetric
和AccuracyMetric
也是FastNLP
工具包里类。
-
SpanFPreRecMetric
以span的方式计算F1, precision, recall -
AccuracyMetric
计算accuracy,这里我理解为是计算token-level的acc
至此,我们已基本解读完FLAT官方开源代码的关键细节,若有不当及错误,欢迎批评指正!
参考:
FLAT: Chinese NER Using Flat-Lattice Transformer (github.com)