motivation
现有的OCR识别主要基于encoder-decoder的架构,采用基于CNN的encoder进行图像特征理解;基于RNN的decoder完成文本生成,再使用CTC/attention对齐。
TrOCR的目的:
1、OCR领域大规模数据预训练模型;
2、引入transformer的子监督替换CNN
网络结构
图像transformer作为encoder,输入为图像patch;文本transformer作为decoder,输入为text sequence
encoder
将输入切片图像进行固定大小的patch划分,经过线性映射,加上position embedding,输入transformer的encoder,与ViT模型结构一致;输出为切片图像的视觉特征;保留[CLS]全局token代表整图特征;
decoder
输入为encoder的输出 + 之前产生的wordpiece;输出为wordpiece 。
过程如下:
数据增强
随机旋转、高斯模糊、图像膨胀、图像腐蚀、下采样、添加下划线、保持原样。机会均等地随机七选一
训练
第一个阶段,包含上亿张打印体文本行的图像以及对应文本标注的数据集;第二个阶段,包含两个相对较小的数据集,分别对应打印体文本识别任务和手写体文本识别任务,均包含上百万的文本行图像
输入wordpiece序列向后旋转,把“[EOS]”符号挪到首位,输入到解码器中,并使用交叉熵损失函数来监督解码器的输出。
TrOCR_Base:330M参数
TrOCR_Large: 558M参数
推理
encoder前向计算即可;
decoder从“[EOS]”符号开始迭代预测之后的 wordpiece,并把预测出的 wordpiece 作为下一次decoding的输入。
decoder shifted right
实质上是decoding时,给输出添加起始符/结束符,方便预测第一个Token/结束预测过程。