用Tensorflow实现中文手写汉字识别

姓名:冯子豪

学号:16020199001

转载自https://zhuanlan.zhihu.com/p/24899387

【嵌牛导读】TensorFlow是Google基于DistBelief进行研发的第二代人工智能学习系统,被广泛用于语音识别或图像识别等多项机器深度学习领域。其命名来源于本身的运行原理。Tensor(张量)意味着N维数组,Flow(流)意味着基于数据流图的计算,TensorFlow代表着张量从图象的一端流动到另一端计算过程,是将复杂的数据结构传输至人工智能神经网中进行分析和处理的过程。

TensorFlow完全开源,任何人都可以使用。可在小到一部智能手机、大到数千台数据中心服务器的各种设备上运行。

本文带大家深入解析Tensorflow系统的技术实践,由浅入深,一起走上深度学习的进阶之路。

【嵌牛鼻子】Tensorflow 模式识别 手写汉字 

【嵌牛提问】之前可能大家已经接触了mnist数据集,做出了手写数字识别,那么如何在此基础上进一步做出手写汉字识别呢?

【嵌牛正文】

Goal

本文目标是利用TensorFlow做一个简单的图像分类器,在比较大的数据集上,尽可能高效地做图像相关处理,从Train,Validation到Inference,是一个比较基本的Example, 从一个基本的任务学习如果在TensorFlow下做高效地图像读取,基本的图像处理,整个项目很简单,但其中有一些trick,在实际项目当中有很大的好处, 比如绝对不要一次读入所有的 的数据到内存(尽管在Mnist这类级别的例子上经常出现)…

最开始看到是这篇blog里面的TensorFlow练习22: 手写汉字识别, 但是这篇文章只用了140训练与测试,试了下代码 很快,但是当扩展到所有的时,发现32g的内存都不够用,这才注意到原文中都是用numpy,会先把所有的数据放入到内存,但这个不必须的,无论在MXNet还是TensorFlow中都是不必 须的,MXNet使用的是DataIter,会在程序运行的过程中异步读取数据,TensorFlow也是这样的,TensorFlow封装了高级的api,用来做数据的读取,比如TFRecord,还有就是从filenames中读取, 来异步读取文件,然后做shuffle batch,再feed到模型的Graph中来做模型参数的更新。具体在tf如何做数据的读取可以看看reading data in tensorflow

这里我会拿到所有的数据集来做训练与测试,算作是对斗大的熊猫上面那篇文章的一个扩展。

Batch Generate

数据集来自于中科院自动化研究所,感谢分享精神!!!具体下载:

wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zipwget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip

解压后发现是一些gnt文件,然后用了斗大的熊猫里面的代码,将所有文件都转化为对应label目录下的所有png的图片。(注意在HWDB1.1trn_gnt.zip解压后是alz文件,需要再次解压 我在mac没有找到合适的工具,windows上有alz的解压工具)。

import osimport numpy as npimport structfrom PIL import Imagedata_dir = '../data'train_data_dir = os.path.join(data_dir, 'HWDB1.1trn_gnt')test_data_dir = os.path.join(data_dir, 'HWDB1.1tst_gnt')def read_from_gnt_dir(gnt_dir=train_data_dir):    def one_file(f):        header_size = 10        while True:            header = np.fromfile(f, dtype='uint8', count=header_size)            if not header.size: break            sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)            tagcode = header[5] + (header[4]<<8)            width = header[6] + (header[7]<<8)            height = header[8] + (header[9]<<8)            if header_size + width*height != sample_size:                break            image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))            yield image, tagcode    for file_name in os.listdir(gnt_dir):        if file_name.endswith('.gnt'):            file_path = os.path.join(gnt_dir, file_name)            with open(file_path, 'rb') as f:                for image, tagcode in one_file(f):                    yield image, tagcodechar_set = set()for _, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):    tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')    char_set.add(tagcode_unicode)char_list = list(char_set)char_dict = dict(zip(sorted(char_list), range(len(char_list))))print len(char_dict)import picklef = open('char_dict', 'wb')pickle.dump(char_dict, f)f.close()train_counter = 0test_counter = 0for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):    tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')    im = Image.fromarray(image)    dir_name = '../data/train/' + '%0.5d'%char_dict[tagcode_unicode]    if not os.path.exists(dir_name):        os.mkdir(dir_name)    im.convert('RGB').save(dir_name+'/' + str(train_counter) + '.png')    train_counter += 1for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir):    tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')    im = Image.fromarray(image)    dir_name = '../data/test/' + '%0.5d'%char_dict[tagcode_unicode]    if not os.path.exists(dir_name):        os.mkdir(dir_name)    im.convert('RGB').save(dir_name+'/' + str(test_counter) + '.png')    test_counter += 1

处理好的数据,放到了云盘,大家可以直接在我的云盘来下载处理好的数据集HWDB1. 这里说明下,char_dict是汉字和对应的数字label的记录。

得到数据集后,就要考虑如何读取了,一次用numpy读入内存在很多小数据集上是可以行的,但是在稍微大点的数据集上内存就成了瓶颈,但是不要害怕,TensorFlow有自己的方法:

def batch_data(file_labels,sess, batch_size=128):    image_list = [file_label[0] for file_label in file_labels]    label_list = [int(file_label[1]) for file_label in file_labels]    print 'tag2 {0}'.format(len(image_list))    images_tensor = tf.convert_to_tensor(image_list, dtype=tf.string)    labels_tensor = tf.convert_to_tensor(label_list, dtype=tf.int64)    input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor])    labels = input_queue[1]    images_content = tf.read_file(input_queue[0])    # images = tf.image.decode_png(images_content, channels=1)    images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)    # images = images / 256    images =  pre_process(images)    # print images.get_shape()    # one hot    labels = tf.one_hot(labels, 3755)    image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,min_after_dequeue=10000)    # print 'image_batch', image_batch.get_shape()    coord = tf.train.Coordinator()    threads = tf.train.start_queue_runners(sess=sess, coord=coord)    return image_batch, label_batch, coord, threads

简单介绍下,首先你需要得到所有的图像的path和对应的label的列表,利用tf.convert_to_tensor转换为对应的tensor, 利用tf.train.slice_input_producer将image_list ,label_list做一个slice处理,然后做图像的读取、预处理,以及label的one_hot表示,然后就是传到tf.train.shuffle_batch产生一个个shuffle batch,这些就可以feed到你的 模型。 slice_input_producer和shuffle_batch这类操作内部都是基于queue,是一种异步的处理方式,会在设备中开辟一段空间用作cache,不同的进程会分别一直往cache中塞数据 和取数据,保证内存或显存的占用以及每一个mini-batch不需要等待,直接可以从cache中获取。

Data Augmentation

由于图像场景不复杂,只是做了一些基本的处理,包括图像翻转,改变下亮度等等,这些在TensorFlow里面有现成的api,所以尽量使用TensorFlow来做相关的处理:

def pre_process(images):    if FLAGS.random_flip_up_down:        images = tf.image.random_flip_up_down(images)    if FLAGS.random_flip_left_right:        images = tf.image.random_flip_left_right(images)    if FLAGS.random_brightness:        images = tf.image.random_brightness(images, max_delta=0.3)    if FLAGS.random_contrast:        images = tf.image.random_contrast(images, 0.8, 1.2)    new_size = tf.constant([FLAGS.image_size,FLAGS.image_size], dtype=tf.int32)    images = tf.image.resize_images(images, new_size)    return images

Build Graph

这里很简单的构造了一个两个卷积+一个全连接层的网络,没有做什么更深的设计,感觉意义不大,设计了一个dict,用来返回后面要用的所有op,还有就是为了方便再训练中查看loss和accuracy, 没有什么特别的,很容易理解, labels 为None时 方便做inference。

def network(images, labels=None):    endpoints = {}    conv_1 = slim.conv2d(images, 32, [3,3],1, padding='SAME')    max_pool_1 = slim.max_pool2d(conv_1, [2,2],[2,2], padding='SAME')    conv_2 = slim.conv2d(max_pool_1, 64, [3,3],padding='SAME')    max_pool_2 = slim.max_pool2d(conv_2, [2,2],[2,2], padding='SAME')    flatten = slim.flatten(max_pool_2)    out = slim.fully_connected(flatten,3755, activation_fn=None)    global_step = tf.Variable(initial_value=0)    if labels is not None:        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(out, labels))        train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss, global_step=global_step)        accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), tf.argmax(labels, 1)), tf.float32))        tf.summary.scalar('loss', loss)        tf.summary.scalar('accuracy', accuracy)        merged_summary_op = tf.summary.merge_all()    output_score = tf.nn.softmax(out)    predict_val_top3, predict_index_top3 = tf.nn.top_k(output_score, k=3)    endpoints['global_step'] = global_step    if labels is not None:        endpoints['labels'] = labels        endpoints['train_op'] = train_op        endpoints['loss'] = loss        endpoints['accuracy'] = accuracy        endpoints['merged_summary_op'] = merged_summary_op    endpoints['output_score'] = output_score    endpoints['predict_val_top3'] = predict_val_top3    endpoints['predict_index_top3'] = predict_index_top3    return endpoints

Train

train函数包括从已有checkpoint中restore,得到step,快速恢复训练过程,训练主要是每一次得到mini-batch,更新参数,每隔eval_steps后做一次train batch的eval,每隔save_steps 后保存一次checkpoint。

def train():    sess = tf.Session()    file_labels = get_imagesfile(FLAGS.train_data_dir)    images, labels, coord, threads = batch_data(file_labels, sess)    endpoints = network(images, labels)    saver = tf.train.Saver()    sess.run(tf.global_variables_initializer())    train_writer = tf.train.SummaryWriter('./log' + '/train',sess.graph)    test_writer = tf.train.SummaryWriter('./log' + '/val')    start_step = 0    if FLAGS.restore:        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)        if ckpt:            saver.restore(sess, ckpt)            print "restore from the checkpoint {0}".format(ckpt)            start_step += int(ckpt.split('-')[-1])    logger.info(':::Training Start:::')    try:        while not coord.should_stop():        # logger.info('step {0} start'.format(i))            start_time = time.time()            _, loss_val, train_summary, step = sess.run([endpoints['train_op'], endpoints['loss'], endpoints['merged_summary_op'], endpoints['global_step']])            train_writer.add_summary(train_summary, step)            end_time = time.time()            logger.info("the step {0} takes {1} loss {2}".format(step, end_time-start_time, loss_val))            if step > FLAGS.max_steps:                break            # logger.info("the step {0} takes {1} loss {2}".format(i, end_time-start_time, loss_val))            if step % FLAGS.eval_steps == 1:                accuracy_val,test_summary, step = sess.run([endpoints['accuracy'], endpoints['merged_summary_op'], endpoints['global_step']])                test_writer.add_summary(test_summary, step)                logger.info('===============Eval a batch in Train data=======================')                logger.info( 'the step {0} accuracy {1}'.format(step, accuracy_val))                logger.info('===============Eval a batch in Train data=======================')            if step % FLAGS.save_steps == 1:                logger.info('Save the ckpt of {0}'.format(step))                saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'), global_step=endpoints['global_step'])    except tf.errors.OutOfRangeError:        # print "============train finished========="        logger.info('==================Train Finished================')        saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'), global_step=endpoints['global_step'])    finally:        coord.request_stop()    coord.join(threads)    sess.close()

Graph

Loss and Accuracy

Validation

训练完成之后,想对最终的模型在测试数据集上做一个评估,这里我也曾经尝试利用batch_data,将slice_input_producer中epoch设置为1,来做相关的工作,但是发现这里无法和train 共用,会出现epoch无初始化值的问题(train中传epoch为None),所以这里自己写了shuffle batch的逻辑,将测试集的images和labels通过feed_dict传进到网络,得到模型的输出, 然后做相关指标的计算:

def validation():    # it should be fixed by using placeholder with epoch num in train stage    sess = tf.Session()    file_labels = get_imagesfile(FLAGS.test_data_dir)    test_size = len(file_labels)    print test_size    val_batch_size = FLAGS.val_batch_size    test_steps = test_size / val_batch_size    print test_steps    # images, labels, coord, threads= batch_data(file_labels, sess)    images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1])    labels = tf.placeholder(dtype=tf.int32, shape=[None,3755])    # read batch images from file_labels    # images_batch = np.zeros([128,64,64,1])    # labels_batch = np.zeros([128,3755])    # labels_batch[0][20] = 1    #    endpoints = network(images, labels)    saver = tf.train.Saver()    ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)    if ckpt:        saver.restore(sess, ckpt)        # logger.info("restore from the checkpoint {0}".format(ckpt))    # logger.info('Start validation')    final_predict_val = []    final_predict_index = []    groundtruth = []    for i in range(test_steps):        start = i* val_batch_size        end = (i+1)*val_batch_size        images_batch = []        labels_batch = []        labels_max_batch = []        logger.info('=======start validation on {0}/{1} batch========='.format(i, test_steps))        for j in range(start,end):            image_path = file_labels[j][0]            temp_image = Image.open(image_path).convert('L')            temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size),Image.ANTIALIAS)            temp_label = np.zeros([3755])            label = int(file_labels[j][1])            # print label            temp_label[label] = 1            # print "====",np.asarray(temp_image).shape            labels_batch.append(temp_label)            # print "====",np.asarray(temp_image).shape            images_batch.append(np.asarray(temp_image)/255.0)            labels_max_batch.append(label)        # print images_batch        images_batch = np.array(images_batch).reshape([-1, 64, 64, 1])        labels_batch = np.array(labels_batch)        batch_predict_val, batch_predict_index = sess.run([endpoints['predict_val_top3'],                        endpoints['predict_index_top3']], feed_dict={images:images_batch, labels:labels_batch})        logger.info('=======validation on {0}/{1} batch end========='.format(i, test_steps))        final_predict_val += batch_predict_val.tolist()        final_predict_index += batch_predict_index.tolist()        groundtruth += labels_max_batch    sess.close()    return final_predict_val, final_predict_index, groundtruth

在训练20w个step之后,大概能达到在测试集上能够达到:

相信如果在网络设计上多花点时间能够在一定程度上提升accuracy和top 3 accuracy.有兴趣的小伙伴们可以玩玩这个数据集。

Inference

def inference(image):    temp_image = Image.open(image).convert('L')    temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size),Image.ANTIALIAS)    sess = tf.Session()    logger.info('========start inference============')    images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1])    endpoints = network(images)    saver = tf.train.Saver()    ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)    if ckpt:        saver.restore(sess, ckpt)    predict_val, predict_index = sess.run([endpoints['predict_val_top3'],endpoints['predict_index_top3']], feed_dict={images:temp_image})    sess.close()    return final_predict_val, final_predict_index

运气挺好,随便找了张图片就能准确识别出来

Summary

综上,就是利用tensorflow做中文手写识别的全部,从如何使用tensorflow内部的queue来有效读入数据,到如何设计network, 到如何做train,validation,inference,珍格格流程比较清晰, 美中不足的是,原本打算是在训练过程中,来对测试集做评估,但是在使用queue读test_data_dir下的filenames,和train本身的好像有点问题,不过应该是可以解决的,我这里就pass了。另外可能 还有一些可以改善的地方,比如感觉可以把batch data one hot的部分写入到network,这样,减缓在validation时内存会因为onehot的sparse开销比较大。

感觉这个中文手写汉字数据集价值很大,后面感觉会有好多可以玩的,比如

可以参考项亮大神的这篇文章端到端的OCR:验证码识别做定长的字符识别和不定长的字符识别,定长的基本原理是说,可以把最终输出扩展为k个输出, 每个值表示对应的字符label,这样cnn模型在feature extract之后就可以自己去识别对应字符而无需人工切割;而LSTM+CTC来解决不定长的验证码,类似于将音频解码为汉字

最近GAN特别火,感觉可以考虑用这个数据来做某个字的生成,和text2img那个项目text-to-image

这部分的代码都在我的github上tensorflow-101,有遇到相关功能,想参考代码的可以去上面找找,没准就能解决你们遇到的一些小问题.

——————

相关阅读推荐:

机器学习进阶笔记之七 | MXnet初体验

机器学习进阶笔记之六 | 深入理解Fast Neural Style

机器学习进阶笔记之五 | 深入理解VGG\Residual Network

机器学习进阶笔记之四 | 深入理解GoogLeNet

机器学习进阶笔记之三 | 深入理解Alexnet 

机器学习进阶笔记之二 | 深入理解Neural Style

机器学习进阶笔记之一 | TensorFlow安装与入门 

​本文由『UCloud内核与虚拟化研发团队』提供。

关于作者:

Burness(

@段石石

 ), UCloud平台研发中心深度学习研发工程师,tflearn Contributor & tensorflow Contributor,做过电商推荐、精准化营销相关算法工作,专注于分布式深度学习框架、计算机视觉算法研究,平时喜欢玩玩算法,研究研究开源的项目,偶尔也会去一些数据比赛打打酱油,生活中是个极客,对新技术、新技能痴迷。

你可以在Github上找到他:http://hacker.duanshishi.com/

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 196,099评论 5 462
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 82,473评论 2 373
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 143,229评论 0 325
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,570评论 1 267
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 61,427评论 5 358
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,335评论 1 273
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,737评论 3 386
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,392评论 0 254
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,693评论 1 294
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,730评论 2 312
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,512评论 1 326
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,349评论 3 314
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,750评论 3 299
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,017评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,290评论 1 251
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 41,706评论 2 342
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 40,904评论 2 335

推荐阅读更多精彩内容