Paper: FLAT: Chinese NER Using Flat-Lattice Transformer (ACL 2020)
We go straight to the model itself; the model's input pipeline was covered in detail in the previous post.
V0 version: without BERT
model = Lattice_Transformer_SeqLabel(embeddings['lattice'],
embeddings['bigram'],
args.hidden, # 128
len(vocabs['label']), # label_size
args.head, args.layer, # 8, 1
args.use_abs_pos, # False: whether to use absolute position encoding
args.use_rel_pos, # True: whether to use relative position encoding
args.learn_pos, # False: whether the absolute/relative position encodings are learnable (i.e. require gradients)
args.add_pos, # False: whether to inject position information via concat inside the transformer layer
args.pre, args.post, # '', 'an'
args.ff, # 128x3: number of hidden units in the feed-forward intermediate layer
args.scaled, dropout, # False, dropout
args.use_bigram, # 1
mode, device,
vocabs,
max_seq_len=max_seq_len,
rel_pos_shared=args.rel_pos_shared, # True
k_proj=args.k_proj, # False
q_proj=args.q_proj, # True
v_proj=args.v_proj, # True
r_proj=args.r_proj, # True
self_supervised=args.self_supervised, # False
attn_ff=args.attn_ff, # False 是否在self-attn层最后加一个linear层 -> False: whether to add a final linear layer after self-attention
pos_norm=args.pos_norm, # False: whether to normalize the position encodings
ff_activate=args.ff_activate, # relu
abs_pos_fusion_func=args.abs_pos_fusion_func, # nonlinear_add
embed_dropout_pos=args.embed_dropout_pos, # 0
four_pos_shared=args.four_pos_shared, # True: relative positions only; whether the 4 position embeddings share weights
four_pos_fusion=args.four_pos_fusion, # ff_two: how the 4 position encodings are fused
four_pos_fusion_shared=args.four_pos_fusion_shared, # True: whether to share the fused position embedding across layers
use_pytorch_dropout=args.use_pytorch_dropout # 0
)
Below we walk through some of the key code blocks of Lattice_Transformer_SeqLabel.
Overall structure
- Position encoding
if self.use_rel_pos:
pe = get_embedding(max_seq_len, hidden_size, rel_pos_init=self.rel_pos_init) # [2*max_seq_len+1, hidden_size]
pe_sum = pe.sum(dim=-1, keepdim=True) # [2*max_seq_len+1, 1]
if self.pos_norm:
with torch.no_grad():
pe = pe/pe_sum
self.pe = nn.Parameter(pe, requires_grad=self.learnable_position)
if self.four_pos_shared:
self.pe_ss = self.pe
self.pe_se = self.pe
self.pe_es = self.pe
self.pe_ee = self.pe
else:
self.pe_ss = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
self.pe_se = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
self.pe_es = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
self.pe_ee = nn.Parameter(copy.deepcopy(pe), requires_grad=self.learnable_position)
Sinusoidal position encodings are used here (a small shape check follows the function):
def get_embedding(max_seq_len, embedding_dim, padding_idx=None, rel_pos_init=0):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
rel_pos_init:
if 0, the relative-position matrix covering -max_len..max_len is initialized with indices 0..2*max_len;
if 1, it is initialized with indices -max_len..max_len.
"""
num_embeddings = 2*max_seq_len+1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
if rel_pos_init == 0:
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) # [num_embeddings, half_dim]
else:
emb = torch.arange(-max_seq_len, max_seq_len+1, dtype=torch.float).unsqueeze(1)*emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) # [num_embeddings, embedding_dim]
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
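A minimal shape check of get_embedding (assuming torch and math are imported, as in the repo; the numbers are made up for illustration):
pe = get_embedding(max_seq_len=5, embedding_dim=8)  # rel_pos_init defaults to 0
print(pe.shape)  # torch.Size([11, 8]): one row per relative distance in [-5, 5]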
- Whether to use bigrams (optional)
if self.use_bigram:
self.bigram_size = self.bigram_embed.embedding.weight.size(1)
self.char_input_size = self.lattice_embed.embedding.weight.size(1) + self.bigram_size
else:
self.char_input_size = self.lattice_embed.embedding.weight.size(1)
self.lex_input_size = self.lattice_embed.embedding.weight.size(1)
If bigrams are used, the bigram embedding information is concatenated with the lattice embedding; otherwise only the lattice embedding is used.
- Transformer Encoder
self.char_proj = nn.Linear(self.char_input_size, self.hidden_size)
self.lex_proj = nn.Linear(self.lex_input_size, self.hidden_size)
self.encoder = Transformer_Encoder(self.hidden_size, self.num_heads, self.num_layers,
relative_position=self.use_rel_pos, ...)
- Final layers of the network
self.output = nn.Linear(self.hidden_size, self.label_size)
self.crf = get_crf_zero_init(self.label_size)
self.loss_func = nn.CrossEntropyLoss(ignore_index=-100)
- The model's forward function
def forward(self, lattice, bigrams, seq_len, lex_num, pos_s, pos_e,
target, chars_target=None):
batch_size = lattice.size(0)
max_seq_len_and_lex_num = lattice.size(1)
max_seq_len = bigrams.size(1)
raw_embed = self.lattice_embed(lattice) # lattice embedding
if self.use_bigram:
bigrams_embed = self.bigram_embed(bigrams)
bigrams_embed = torch.cat([bigrams_embed,
torch.zeros(size=[batch_size, max_seq_len_and_lex_num - max_seq_len,
self.bigram_size]).to(bigrams_embed)], dim=1)
# [bs, max_seq_len_and_lex_num, lattice_embed_size+bigram_embed_size]
raw_embed_char = torch.cat([raw_embed, bigrams_embed], dim=-1)
else:
raw_embed_char = raw_embed
if self.embed_dropout_pos == '0':
raw_embed_char = self.embed_dropout(raw_embed_char)
raw_embed = self.gaz_dropout(raw_embed)
# [bs, max_seq_len_and_lex_num, hidden_size]
embed_char = self.char_proj(raw_embed_char)
char_mask = seq_len_to_mask(seq_len, max_len=max_seq_len_and_lex_num).bool() # [bs, max_len]
# Fills elements of self_tensor with 0 where char_mask is False
embed_char.masked_fill_(~(char_mask.unsqueeze(-1)), 0) # [bs, max_len, 1*hidden_size]
# [bs, max_seq_len_and_lex_num, hidden_size]
embed_lex = self.lex_proj(raw_embed)
lex_mask = (seq_len_to_mask(seq_len + lex_num).bool() ^ char_mask.bool()) # True on the trailing lexicon-word part
# Fills elements with 0 where lex_mask is False, i.e. zeros out the leading char part
embed_lex.masked_fill_(~(lex_mask).unsqueeze(-1), 0) # [bs, max_len, 1*hidden_size]
embedding = embed_char + embed_lex
encoded = self.encoder(embedding, seq_len, lex_num=lex_num, pos_s=pos_s, pos_e=pos_e)
encoded = encoded[:, :max_seq_len, :] # only the char part is used for prediction
pred = self.output(encoded)
mask = seq_len_to_mask(seq_len).bool()
if self.training:
loss = self.crf(pred, target, mask).mean(dim=0)
if self.batch_num == 327:
print('{} loss:{}'.format(self.batch_num,loss))
exit()
return {'loss': loss}
else:
pred, path = self.crf.viterbi_decode(pred, mask)
result = {'pred': pred}
if self.self_supervised:
chars_pred = self.output_self_supervised(encoded)
result['chars_pred'] = chars_pred
return result
- embedding = embed_char + embed_lex is the embedding fed into the whole model.
- embed_char uses char_mask to zero out the trailing lexicon-word part, keeping only the embeddings of the leading characters (see the toy mask example after this list).
- embed_lex uses lex_mask to zero out the leading character part, keeping only the embeddings of the trailing lexicon words.
- So embed_char and embed_lex, projected through self.char_proj and self.lex_proj respectively, learn (train) separate embeddings for characters and lexicon words.
- encoded = self.encoder(...) covers the Self-Attention, Add & Norm, FFN, Add & Norm layers.
- pred = self.output(encoded) is the final linear layer.
- loss = self.crf(pred, target, mask) is the final CRF layer.
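To make char_mask and lex_mask concrete, here is a minimal toy sketch. It uses a local stand-in for fastNLP's seq_len_to_mask (the repo imports the real helper), and the lengths are invented for illustration:
import torch

def seq_len_to_mask(seq_len, max_len=None):
    # minimal stand-in for the fastNLP helper used in the repo
    max_len = int(max_len or seq_len.max())
    return torch.arange(max_len)[None, :] < seq_len[:, None]

seq_len = torch.tensor([3])   # 3 characters in the sentence
lex_num = torch.tensor([2])   # 2 matched lexicon words appended after the characters
max_len = 5                   # max_seq_len_and_lex_num = characters + words

char_mask = seq_len_to_mask(seq_len, max_len=max_len)        # True on the char part
lex_mask = seq_len_to_mask(seq_len + lex_num) ^ char_mask    # True on the word part
print(char_mask)  # tensor([[ True,  True,  True, False, False]])
print(lex_mask)   # tensor([[False, False, False,  True,  True]])
embed_char is zeroed where char_mask is False and embed_lex where lex_mask is False, so their sum stitches the two projections into a single lattice sequence.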
The above corresponds to the overall architecture shown in Figure 2 of the paper.
Key details
Next we analyze self.encoder(...), i.e. the main components of Transformer_Encoder:
- Whether to fuse the 4 kinds of position information into the relative position encoding, i.e. Equation (8) in the paper
if self.four_pos_fusion_shared:
self.four_pos_fusion_embedding = \
Four_Pos_Fusion_Embedding(self.pe, self.four_pos_fusion, self.pe_ss, self.pe_se, self.pe_es, self.pe_ee,
self.max_seq_len, self.hidden_size, self.mode)
else:
self.four_pos_fusion_embedding = None
class Four_Pos_Fusion_Embedding(nn.Module):
def __init__(self, pe, four_pos_fusion, pe_ss, pe_se, pe_es, pe_ee,
max_seq_len, hidden_size, mode):
super().__init__()
self.mode = mode
self.hidden_size = hidden_size
self.max_seq_len = max_seq_len
self.pe_ss = pe_ss
self.pe_se = pe_se
self.pe_es = pe_es
self.pe_ee = pe_ee
self.pe = pe # [2*max_seq_len+1, hidden_size]
self.four_pos_fusion = four_pos_fusion
if self.four_pos_fusion == 'ff':
self.pos_fusion_forward = nn.Sequential(nn.Linear(self.hidden_size*4, self.hidden_size),
nn.ReLU(inplace=True))
if self.four_pos_fusion == 'ff_linear':
self.pos_fusion_forward = nn.Linear(self.hidden_size*4, self.hidden_size)
elif self.four_pos_fusion == 'ff_two':
self.pos_fusion_forward = nn.Sequential(nn.Linear(self.hidden_size*2, self.hidden_size),
nn.ReLU(inplace=True))
elif self.four_pos_fusion == 'attn':
self.w_r = nn.Linear(self.hidden_size, self.hidden_size)
self.pos_attn_score = nn.Sequential(nn.Linear(self.hidden_size*4, self.hidden_size*4),
nn.ReLU(),
nn.Linear(self.hidden_size*4, 4),
nn.Softmax(dim=-1))
elif self.four_pos_fusion == 'gate':
self.w_r = nn.Linear(self.hidden_size, self.hidden_size)
self.pos_gate_score = nn.Sequential(nn.Linear(self.hidden_size*4, self.hidden_size*2),
nn.ReLU(),
nn.Linear(self.hidden_size*2, 4*self.hidden_size))
def forward(self, pos_s, pos_e):
batch = pos_s.size(0)
max_seq_len = pos_s.size(1)
# [bs, max_seq_len, 1] - [bs, 1, max_seq_len] = [bs, max_seq_len, max_seq_len]
pos_ss = pos_s.unsqueeze(-1) - pos_s.unsqueeze(-2)
pos_se = pos_s.unsqueeze(-1) - pos_e.unsqueeze(-2)
pos_es = pos_e.unsqueeze(-1) - pos_s.unsqueeze(-2)
pos_ee = pos_e.unsqueeze(-1) - pos_e.unsqueeze(-2)
# [bs, max_seq_len, max_seq_len, hidden_size]
pe_ss = self.pe_ss[(pos_ss).view(-1) + self.max_seq_len].view(size=[batch, max_seq_len, max_seq_len, -1])
pe_se = self.pe_se[(pos_se).view(-1) + self.max_seq_len].view(size=[batch, max_seq_len, max_seq_len, -1])
pe_es = self.pe_es[(pos_es).view(-1) + self.max_seq_len].view(size=[batch, max_seq_len, max_seq_len, -1])
pe_ee = self.pe_ee[(pos_ee).view(-1) + self.max_seq_len].view(size=[batch, max_seq_len, max_seq_len, -1])
if self.four_pos_fusion == 'ff':
pe_4 = torch.cat([pe_ss, pe_se, pe_es, pe_ee], dim=-1)
rel_pos_embedding = self.pos_fusion_forward(pe_4)
if self.four_pos_fusion == 'ff_linear':
pe_4 = torch.cat([pe_ss, pe_se, pe_es, pe_ee], dim=-1)
rel_pos_embedding = self.pos_fusion_forward(pe_4)
if self.four_pos_fusion == 'ff_two':
pe_2 = torch.cat([pe_ss, pe_ee], dim=-1)
rel_pos_embedding = self.pos_fusion_forward(pe_2)
elif self.four_pos_fusion == 'attn':
pe_4 = torch.cat([pe_ss, pe_se, pe_es, pe_ee], dim=-1)
attn_score = self.pos_attn_score(pe_4)
pe_4_unflat = self.w_r(pe_4.view(batch, max_seq_len, max_seq_len, 4, self.hidden_size))
pe_4_fusion = (attn_score.unsqueeze(-1) * pe_4_unflat).sum(dim=-2)
rel_pos_embedding = pe_4_fusion
elif self.four_pos_fusion == 'gate':
pe_4 = torch.cat([pe_ss, pe_se, pe_es, pe_ee], dim=-1)
gate_score = self.pos_gate_score(pe_4).view(batch,max_seq_len,max_seq_len,4,self.hidden_size)
gate_score = F.softmax(gate_score, dim=-2)
pe_4_unflat = self.w_r(pe_4.view(batch, max_seq_len, max_seq_len, 4, self.hidden_size))
pe_4_fusion = (gate_score * pe_4_unflat).sum(dim=-2)
rel_pos_embedding = pe_4_fusion
return rel_pos_embedding
- The forward function takes pos_s (Head) and pos_e (Tail) and derives the 4 kinds of position information pos_ss, pos_se, pos_es, pos_ee (see the toy example after this list).
- The 4 kinds of position information are then converted into the corresponding position encodings pe_ss, pe_se, pe_es, pe_ee.
- Finally the 4 position encodings are fused. There are 5 fusion options: ff is a fully-connected layer with a nonlinear activation; attn first computes a weight for each position encoding and then takes a weighted sum; gate is similar to attn, except the weighting has one extra dimension. The default here is ff_two, which produces the fused relative position encoding rel_pos_embedding with shape [bs, max_seq_len, max_seq_len, hidden_size].
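A toy example (not repo code) of how the span offsets come out for a 2-character sentence whose two characters are also matched as one word:
import torch

# lattice of 3 tokens: char 0 (span 0-0), char 1 (span 1-1), and a word covering both (span 0-1)
pos_s = torch.tensor([[0, 1, 0]])   # head positions, shape [bs=1, 3]
pos_e = torch.tensor([[0, 1, 1]])   # tail positions

pos_ss = pos_s.unsqueeze(-1) - pos_s.unsqueeze(-2)  # head-to-head offsets, [1, 3, 3]
pos_ee = pos_e.unsqueeze(-1) - pos_e.unsqueeze(-2)  # tail-to-tail offsets, [1, 3, 3]
print(pos_ss[0])
# tensor([[ 0, -1,  0],
#         [ 1,  0,  1],
#         [ 0, -1,  0]])
print(pos_ee[0])
# tensor([[ 0, -1, -1],
#         [ 1,  0,  0],
#         [ 1,  0,  0]])
Each offset is then shifted by self.max_seq_len so it can index into the pe_* tables, whose rows cover distances from -max_seq_len to max_seq_len.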
- The core component: Transformer_Encoder_Layer()
for i in range(self.num_layers):
setattr(self, 'layer_{}'.format(i), Transformer_Encoder_Layer(hidden_size, num_heads, ...))
The forward function of Transformer_Encoder:
def forward(self, inp, seq_len, lex_num=0, pos_s=None, pos_e=None, print_=False):
output = inp
if self.relative_position:
if self.four_pos_fusion_shared and self.lattice:
rel_pos_embedding = self.four_pos_fusion_embedding(pos_s, pos_e)
else:
rel_pos_embedding = None
else:
rel_pos_embedding = None
for i in range(self.num_layers):
now_layer = getattr(self, 'layer_{}'.format(i)) # the i-th of the stacked Transformer_Encoder_Layer modules
output = now_layer(output, seq_len, lex_num=lex_num, pos_s=pos_s, pos_e=pos_e,
rel_pos_embedding=rel_pos_embedding, print_=print_)
output = self.layer_preprocess(output)
return output
We can see that now_layer invokes each Transformer_Encoder_Layer() in turn for the forward pass, and the relative position encoding rel_pos_embedding is passed in as well.
We therefore analyze the key parts of the most central code block, Transformer_Encoder_Layer():
- The layer's forward function
self.ff = Positionwise_FeedForward([hidden_size, ff_size, hidden_size], self.dropout,ff_activate=self.ff_activate,
use_pytorch_dropout=self.use_pytorch_dropout)
def forward(self, inp, seq_len, lex_num=0, pos_s=None, pos_e=None, rel_pos_embedding=None,
print_=False):
output = inp
output = self.layer_preprocess(output)
if self.lattice:
if self.relative_position:
if rel_pos_embedding is None:
rel_pos_embedding = self.four_pos_fusion_embedding(pos_s,pos_e)
output = self.attn(output, output, output, seq_len, pos_s=pos_s, pos_e=pos_e, lex_num=lex_num,
rel_pos_embedding=rel_pos_embedding)
else:
output = self.attn(output, output, output, seq_len, lex_num)
else:
output = self.attn(output, output, output, seq_len)
output = self.layer_postprocess(output)
output = self.layer_preprocess(output)
output = self.ff(output, print_)
output = self.layer_postprocess(output)
return output
- self.attn() is the self-attention layer.
- self.layer_postprocess() performs the Add & Norm step (note that the authors' implementation has a bug here: the residual connection is not actually applied). A generic sketch of the standard operation follows this list.
- self.ff() is the FFN layer.
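For reference, a generic post-norm Add & Norm block looks like the sketch below. This is only an illustration of the standard operation, not the repo's Layer_Process module, which, per the note above, omits the residual add:
import torch.nn as nn

class AddAndNorm(nn.Module):
    # generic residual Add & Norm, for illustration only
    def __init__(self, hidden_size, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, residual, sublayer_out):
        # standard transformer post-processing: residual add, then LayerNorm
        return self.norm(residual + self.dropout(sublayer_out))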
- The self-attention layer incorporates the relative position encoding information
self.attn = MultiHead_Attention_Lattice_rel_save_gpumm(self.hidden_size, self.num_heads,
pe=self.pe,
pe_ss=self.pe_ss,
pe_se=self.pe_se,
pe_es=self.pe_es,
pe_ee=self.pe_ee,
scaled=self.scaled,
mode=self.mode,
max_seq_len=self.max_seq_len,
dvc=self.dvc,
k_proj=self.k_proj,
q_proj=self.q_proj,
v_proj=self.v_proj,
r_proj=self.r_proj,
attn_dropout=self.dropout['attn'],
ff_final=self.attn_ff, # False
four_pos_fusion=self.four_pos_fusion,
use_pytorch_dropout=self.use_pytorch_dropout)
Let's look at the key parts of the self-attention layer in detail:
class MultiHead_Attention_Lattice_rel_save_gpumm(nn.Module):
def __init__(self, hidden_size, num_heads, ...):
...  # some parts omitted
self.per_head_size = self.hidden_size // self.num_heads
self.w_k = nn.Linear(self.hidden_size, self.hidden_size)
self.w_q = nn.Linear(self.hidden_size, self.hidden_size)
self.w_v = nn.Linear(self.hidden_size, self.hidden_size)
self.w_r = nn.Linear(self.hidden_size, self.hidden_size)
self.w_final = nn.Linear(self.hidden_size, self.hidden_size)
self.u = nn.Parameter(torch.Tensor(self.num_heads, self.per_head_size))
self.v = nn.Parameter(torch.Tensor(self.num_heads, self.per_head_size))
def forward(self, key, query, value, seq_len, lex_num, pos_s, pos_e, rel_pos_embedding):
if self.k_proj:
key = self.w_k(key)
if self.q_proj:
query = self.w_q(query)
if self.v_proj:
value = self.w_v(value)
if self.r_proj:
# [bs, max_seq_len, max_seq_len, hidden_size]
rel_pos_embedding = self.w_r(rel_pos_embedding)
batch = key.size(0)
max_seq_len = key.size(1)
# batch * seq_len * n_head * per_head_size
key = torch.reshape(key, [batch, max_seq_len, self.num_heads, self.per_head_size])
query = torch.reshape(query, [batch, max_seq_len, self.num_heads, self.per_head_size])
value = torch.reshape(value, [batch, max_seq_len, self.num_heads, self.per_head_size])
rel_pos_embedding = torch.reshape(rel_pos_embedding,
[batch, max_seq_len, max_seq_len, self.num_heads, self.per_head_size])
# batch * n_head * seq_len * per_head_size
key = key.transpose(1, 2)
query = query.transpose(1, 2)
value = value.transpose(1, 2)
# batch * n_head * per_head_size * key_len
key = key.transpose(-1, -2)
# u_for_c: 1(batch broadcast) * num_heads * 1 * per_head_size
u_for_c = self.u.unsqueeze(0).unsqueeze(-2)
query_and_u_for_c = query + u_for_c
# batch * n_head * seq_len * seq_len
A_C = torch.matmul(query_and_u_for_c, key)
rel_pos_embedding_for_b = rel_pos_embedding.permute(0, 3, 1, 4, 2)
# after above, rel_pos_embedding: batch * num_head * query_len * per_head_size * key_len
query_for_b = query.view([batch, self.num_heads, max_seq_len, 1, self.per_head_size])
# after above, query_for_b: batch * num_head * query_len * 1 * per_head_size
query_for_b_and_v_for_d = query_for_b + self.v.view(1, self.num_heads, 1, 1, self.per_head_size)
B_D = torch.matmul(query_for_b_and_v_for_d, rel_pos_embedding_for_b).squeeze(-2)
attn_score_raw = A_C + B_D
if self.scaled:
attn_score_raw = attn_score_raw / math.sqrt(self.per_head_size)
mask = seq_len_to_mask(seq_len+lex_num).bool().unsqueeze(1).unsqueeze(1)
attn_score_raw_masked = attn_score_raw.masked_fill(~mask, -1e15)
attn_score = F.softmax(attn_score_raw_masked,dim=-1)
attn_score = self.dropout(attn_score)
value_weighted_sum = torch.matmul(attn_score, value)
result = value_weighted_sum.transpose(1, 2).contiguous(). \
reshape(batch, max_seq_len, self.hidden_size)
return result # [batch, max_seq_len, hidden_size]
- The letters a, b, c, d (and their uppercase forms) in the variable names refer to the first, second, third, and fourth terms of Equation (11) in the paper.
- A_C is the sum of the first and third terms of Equation (11).
- B_D is the sum of the second and fourth terms of Equation (11) (a quick numerical check follows this list).
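A quick numerical check (a toy sketch, not repo code) of why the first and third terms can be computed together: adding the global bias u to the query before multiplying with the keys equals computing the content term and the bias term separately and summing. The same argument applies to v and B_D, with rel_pos_embedding in place of the keys:
import torch

torch.manual_seed(0)
q = torch.randn(4, 8)   # queries for 4 positions, per-head size 8
k = torch.randn(4, 8)   # keys
u = torch.randn(8)      # global content bias

fused = (q + u) @ k.t()          # how the code forms A_C
split = q @ k.t() + u @ k.t()    # first term + third term computed separately
print(torch.allclose(fused, split, atol=1e-6))  # True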
This concludes our fairly detailed walkthrough of the key code behind the FLAT model architecture.
In addition, the authors also provide a V1 version; its main difference from V0 is the use of BERT embeddings.
References:
FLAT: Chinese NER Using Flat-Lattice Transformer (github.com)
NLP项目实践——中文序列标注Flat Lattice代码解读、运行与使用_CSDN博客