在深度学习中,样本不均衡是指不同类别的样本数据量差别比较大,利用不均衡样本训练出来的模型泛化能力差,并且容易发生过拟合。
对于图像的分类问题,通过各个类别的样本数量就可以判断出训练样本是否过拟合。对于OCR识别问题,我们的目的是准确的识别出文本中的每个字符,训练样本通常是不定长度不同字符的组合,因此这里的样本不均衡指的是不同字符的数量差别比较大,无法简单通过图像的数量判断,但是可以对训练样本和对应的标注文档进行遍历,从而获得字典中每个字符的出现频率,进一步判断训练样本是否均衡!
贴上完整的代码!
最后贴上完整的代码:
# -*- coding: utf-8 -*-
# Usage:
# python /input/image/and/txt/folder/path/ /output/folder/path
import os
import sys
import glob as gb
def get_label_files(folder):
in_path = os.path.join(folder, "*.txt")
files = []
for txt_file in gb.glob(in_path):
img_file = txt_file[:-3] + "jpg"
if os.path.exists(img_file):
files.append(os.path.basename(txt_file))
return files
if __name__ == '__main__':
# input_folder = sys.argv[1]
# output_folder = sys.argv[2]
# get images and corresponding txt files
files = []
files = get_label_files(input_folder)
# get dict
dict_file = './dml_digital.txt'
result = dict()
with open(dict_file, 'r') as file:
for line in file:
line = line.decode('utf-8').strip()
result[line] = 0
# trace files to get statistics
for label_file in files:
with open(os.path.join(input_folder, label_file)) as file:
text = file.read().decode("utf-8").strip()
length = len(text)
for k in range(length):
if result.has_key(text[k]):
result[text[k]] += 1
# write statistics
output_file_path = os.path.join(output_folder,os.path.basename(input_folder) + '.txt')
with open(output_file_path, 'w+') as f:
for key, val in result.items():
f.write(key.encode("utf-8") + ' ' + str(val) + '\n')
print('Finished calculation!')
当代码执行完成后,在 output_folder 目录下就可以看到与训练样本所在文件夹同名的txt文件,下面是txt内容的部分截图,第一列是字典中的字符,第二列是训练集中该字符出现的次数,可以很明显的看出训练样本是不均衡的,阿拉伯数字的数量远远高于大写字母和汉字,接下来我们就可以根据识别需求对训练样本进行调整了,收工!其实样本的均衡性判断是训练前的工作,活生生的被我拖到了现在,又一次战胜了拖延症,奖励自己一块西瓜🍉!