Straight to the code.
##################### Model definition and training code ######################
# coding: utf-8
#!/usr/bin/env python2
"""
TF CNN+LSTM+CTC: train a model to recognize images of variable-length digit strings
@author: pengyuanjie
"""
from genIDCard import *
import numpy as np
import time
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
# Constants
# Image size, 32 x 256
OUTPUT_SHAPE = (32, 256)
# Maximum number of training epochs
num_epochs = 1000
num_hidden = 64
num_layers = 1
obj = gen_id_card()
num_classes = obj.len + 1 + 1  # 10 digits + blank + CTC blank
# Initial learning rate
INITIAL_LEARNING_RATE = 1e-3
DECAY_STEPS = 5000
REPORT_STEPS = 10
LEARNING_RATE_DECAY_FACTOR = 0.9  # the learning rate decay factor
MOMENTUM = 0.9
########################################
FIRST_INDEX = 1
DIGITS = "0123456789uvw"
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
CHARS = list(DIGITS + LETTERS)
char2num_dic = {'0': '0', '1': '1', '2': '2', '3': '3', '4': '4', '5': '5', '6': '6', '7': '7',
                '8': '8', '9': '9', 'a': '11', 'b': '12', 'c': '13', 'd': '14', 'e': '15', 'f': '16',
                'g': '17', 'h': '18', 'i': '19', 'j': '20', 'k': '21', 'l': '22', 'm': '23', 'n': '24',
                'o': '25', 'p': '26', 'q': '27', 'r': '28', 's': '29', 't': '30', 'u': '31', 'v': '32',
                'w': '33', 'x': '34', 'y': '35', 'z': '36', 'A': '37', 'B': '38', 'C': '39', 'D': '40',
                'E': '41', 'F': '42', 'G': '43', 'H': '44', 'I': '45', 'J': '46', 'K': '47', 'L': '48',
                'M': '49', 'N': '50', 'O': '51', 'P': '52', 'Q': '53', 'R': '54', 'S': '55', 'T': '56',
                'U': '57', 'V': '58', 'W': '59', 'X': '60', 'Y': '61', 'Z': '62'}
num2char_dic = {"0": "0", "1": "1", "2": "2", "3": "3", "4": "4", "5": "5", "6": "6", "7": "7",
                "8": "8", "9": "9", "11": "a", "12": "b", "13": "c", "14": "d", "15": "e", "16": "f",
                "17": "g", "18": "h", "19": "i", "20": "j", "21": "k", "22": "l", "23": "m", "24": "n",
                "25": "o", "26": "p", "27": "q", "28": "r", "29": "s", "30": "t", "31": "u", "32": "v",
                "33": "w", "34": "x", "35": "y", "36": "z", "37": "A", "38": "B", "39": "C", "40": "D",
                "41": "E", "42": "F", "43": "G", "44": "H", "45": "I", "46": "J", "47": "K", "48": "L",
                "49": "M", "50": "N", "51": "O", "52": "P", "53": "Q", "54": "R", "55": "S", "56": "T",
                "57": "U", "58": "V", "59": "W", "60": "X", "61": "Y", "62": "Z"}
#####################################
DIGITS = '0123456789'  # redefined: the model below is trained on digits only
BATCHES = 10
BATCH_SIZE = 64
TRAIN_SIZE = BATCHES * BATCH_SIZE
def decode_sparse_tensor(sparse_tensor):
    # Split the flat index list of the sparse tensor into one index list per sequence
    decoded_indexes = list()
    current_i = 0
    current_seq = []
    for offset, i_and_index in enumerate(sparse_tensor[0]):
        i = i_and_index[0]
        if i != current_i:
            decoded_indexes.append(current_seq)
            current_i = i
            current_seq = list()
        current_seq.append(offset)
    decoded_indexes.append(current_seq)
    result = []
    for index in decoded_indexes:
        result.append(decode_a_seq(index, sparse_tensor))
    return result
def decode_a_seq(indexes, spars_tensor):
    str_decoded = ''.join([CHARS[spars_tensor[1][m] - FIRST_INDEX] for m in indexes])
    # Replace the blank label with nothing
    str_decoded = str_decoded.replace(chr(ord('9') + 1), '')
    # Replace the space label with a space
    str_decoded = str_decoded.replace(chr(ord('0') - 1), ' ')
    return str_decoded
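# Illustrative check of the decoder (my addition): build a sparse triple by hand for a
# single length-3 sequence with values [2, 3, 4]; with FIRST_INDEX = 1 these map to
# CHARS[1..3], i.e. the string "123".
if False:  # flip to True to try it
    st = (np.array([[0, 0], [0, 1], [0, 2]]),  # indices: sequence 0, positions 0..2
          np.array([2, 3, 4]),                 # values (labels)
          np.array([1, 3]))                    # dense shape: 1 sequence, max length 3
    print(decode_sparse_tensor(st))            # -> ['123']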
# Convert a list of sequences into a sparse-tensor triple
def sparse_tuple_from(sequences, dtype=np.int32):
    """
    Create a sparse representation of sequences.
    Args:
        sequences: a list of lists of type dtype where each element is a sequence
    Returns:
        A tuple (indices, values, shape)
    """
    indices = []
    values = []
    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), xrange(len(seq))))
        values.extend(seq)
    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)
    return indices, values, shape
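# Sanity check for sparse_tuple_from (my addition): two sequences of different lengths
# become one (indices, values, shape) triple, the format tf.sparse_placeholder accepts.
if False:
    idx, vals, shp = sparse_tuple_from([[1, 2, 3], [4, 5]])
    print(idx)   # [[0 0] [0 1] [0 2] [1 0] [1 1]]
    print(vals)  # [1 2 3 4 5]
    print(shp)   # [2 3] -> batch of 2, longest sequence is 3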
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.5)
    return tf.Variable(initial)


def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W, stride=(1, 1), padding='SAME'):
    return tf.nn.conv2d(x, W, strides=[1, stride[0], stride[1], 1], padding=padding)


def max_pool(x, ksize=(2, 2), stride=(2, 2)):
    return tf.nn.max_pool(x, ksize=[1, ksize[0], ksize[1], 1],
                          strides=[1, stride[0], stride[1], 1], padding='SAME')


def avg_pool(x, ksize=(2, 2), stride=(2, 2)):
    return tf.nn.avg_pool(x, ksize=[1, ksize[0], ksize[1], 1],
                          strides=[1, stride[0], stride[1], 1], padding='SAME')
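# Quick shape check for the conv/pool helpers (my addition): SAME padding with stride 1
# preserves spatial size, and a (2, 2) max-pool halves both spatial dimensions.
if False:
    x_demo = tf.placeholder(tf.float32, [None, 32, 256, 1])
    w_demo = weight_variable([5, 5, 1, 48])
    print(max_pool(conv2d(x_demo, w_demo)).get_shape())  # (?, 16, 128, 48)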
#############################################################
# Generate one training batch
def get_next_batch(batch_size=128):
    obj = gen_id_card()
    # (batch_size, 256, 32)
    inputs = np.zeros([batch_size, OUTPUT_SHAPE[1], OUTPUT_SHAPE[0]])
    codes = []
    for i in range(batch_size):
        # Generate a string of random length
        image, text, vec = obj.gen_image()
        # np.transpose: (32*256,) => (32, 256) => (256, 32)
        inputs[i, :] = np.transpose(image.reshape((OUTPUT_SHAPE[0], OUTPUT_SHAPE[1])))
        # Map each character to its numeric label
        list4 = list(text)
        for j in range(len(list4)):
            list4[j] = char2num_dic[list4[j]]
        codes.append(list4)
    targets = [np.asarray(i) for i in codes]
    sparse_targets = sparse_tuple_from(targets)
    # (batch_size,), every entry is 256: one LSTM time step per image column
    seq_len = np.ones(inputs.shape[0]) * OUTPUT_SHAPE[1]
    return inputs, sparse_targets, seq_len
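# Illustrative shape check (my addition; requires genIDCard and its .ttf font):
if False:
    bx, bt, bl = get_next_batch(4)
    print(bx.shape)  # (4, 256, 32): batch, time steps (image width), features (height)
    print(bl)        # [256. 256. 256. 256.]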
# CNN network for processing the image
def convolutional_layers():
    # Input, shape [batch_size, max_stepsize, num_features]
    # (the tensor is width-major, batch x 256 x 32, so the size comments below read as height x width)
    inputs = tf.placeholder(tf.float32, [None, None, OUTPUT_SHAPE[0]])
    # First conv layer, 32*256*1 => 16*128*48
    W_conv1 = weight_variable([5, 5, 1, 48])
    b_conv1 = bias_variable([48])
    x_expanded = tf.expand_dims(inputs, 3)
    h_conv1 = tf.nn.relu(conv2d(x_expanded, W_conv1) + b_conv1)
    h_pool1 = max_pool(h_conv1, ksize=(2, 2), stride=(2, 2))
    # Second layer, 16*128*48 => 16*64*64 (pooling along the width only)
    W_conv2 = weight_variable([5, 5, 48, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool(h_conv2, ksize=(2, 1), stride=(2, 1))
    # Third layer, 16*64*64 => 8*32*128
    W_conv3 = weight_variable([5, 5, 64, 128])
    b_conv3 = bias_variable([128])
    h_conv3 = tf.nn.relu(conv2d(h_pool2, W_conv3) + b_conv3)
    h_pool3 = max_pool(h_conv3, ksize=(2, 2), stride=(2, 2))
    # Fully connected layer; 8*32*128 == 16*8*OUTPUT_SHAPE[1] == 32768
    W_fc1 = weight_variable([16 * 8 * OUTPUT_SHAPE[1], OUTPUT_SHAPE[1]])
    b_fc1 = bias_variable([OUTPUT_SHAPE[1]])
    conv_layer_flat = tf.reshape(h_pool3, [-1, 16 * 8 * OUTPUT_SHAPE[1]])
    features = tf.nn.relu(tf.matmul(conv_layer_flat, W_fc1) + b_fc1)
    # (batch_size, 256)
    shape = tf.shape(features)
    features = tf.reshape(features, [shape[0], OUTPUT_SHAPE[1], 1])  # batch_size * OUTPUT_SHAPE[1] * 1
    return inputs, features
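# Note (my observation): convolutional_layers() is defined but never called by
# get_train_model() below, which feeds raw pixel columns straight into the LSTM.
# A possible hookup, sketched under that assumption:
if False:
    cnn_inputs, cnn_features = convolutional_layers()
    # cnn_features has shape (batch, OUTPUT_SHAPE[1], 1) and could replace the raw
    # image columns as the LSTM input sequence.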
def get_train_model():
    # features = convolutional_layers()
    # print features.get_shape()
    inputs = tf.placeholder(tf.float32, [None, None, OUTPUT_SHAPE[0]])
    # Sparse placeholder required by ctc_loss
    targets = tf.sparse_placeholder(tf.int32)
    # 1-D vector of sequence lengths, [batch_size]
    seq_len = tf.placeholder(tf.int32, [None])
    # The LSTM network
    cell = tf.contrib.rnn.LSTMCell(num_hidden, state_is_tuple=True)
    stack = tf.contrib.rnn.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
    outputs, _ = tf.nn.dynamic_rnn(stack, inputs, seq_len, dtype=tf.float32)
    shape = tf.shape(inputs)
    batch_s, max_timesteps = shape[0], shape[1]
    outputs = tf.reshape(outputs, [-1, num_hidden])
    W = tf.Variable(tf.truncated_normal([num_hidden, num_classes], stddev=0.1), name="W")
    b = tf.Variable(tf.constant(0., shape=[num_classes]), name="b")
    logits = tf.matmul(outputs, W) + b
    logits = tf.reshape(logits, [batch_s, -1, num_classes])
    # ctc_loss expects time-major logits: (max_timesteps, batch_size, num_classes)
    logits = tf.transpose(logits, (1, 0, 2))
    return logits, inputs, targets, seq_len, W, b
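# Note on the CTC label space (my addition): tf.nn.ctc_loss reserves the largest index,
# num_classes - 1, for the blank label, so target values must lie in [0, num_classes - 2].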
def train():
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                               global_step,
                                               DECAY_STEPS,
                                               LEARNING_RATE_DECAY_FACTOR,
                                               staircase=True)
    logits, inputs, targets, seq_len, W, b = get_train_model()
    # targets is a sparse tensor
    loss = tf.nn.ctc_loss(labels=targets, inputs=logits, sequence_length=seq_len)
    cost = tf.reduce_mean(loss)
    # optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=MOMENTUM).minimize(cost, global_step=global_step)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, global_step=global_step)
    # After slicing the image into per-step class distributions, ctc_beam_search_decoder
    # keeps the K most probable paths at each step; the greedy alternative (K = 1) is
    # ctc_greedy_decoder.
    decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len, merge_repeated=False)
    # Label error rate: mean edit distance between the decoded output and the targets
    acc = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), targets))
    init = tf.global_variables_initializer()
    def report_accuracy(decoded_list, test_targets):
        original_list = decode_sparse_tensor(test_targets)
        detected_list = decode_sparse_tensor(decoded_list)
        true_number = 0
        if len(original_list) != len(detected_list):
            print("len(original_list)", len(original_list), "len(detected_list)", len(detected_list),
                  " test and detect lengths don't match")
            return
        print("T/F: original(length) <-------> detected(length)")
        for idx, number in enumerate(original_list):
            detect_number = detected_list[idx]
            hit = (number == detect_number)
            # print(hit, number, "(", len(number), ") <-------> ", detect_number, "(", len(detect_number), ")")
            if hit:
                true_number = true_number + 1
        accuracy = true_number * 1.0 / len(original_list)
        print("Test Accuracy:", accuracy)
        if accuracy == 1:
            save_path = saver.save(session, "./ocr.model", global_step=global_step)
    def do_report():
        test_inputs, test_targets, test_seq_len = get_next_batch(BATCH_SIZE)
        test_feed = {inputs: test_inputs,
                     targets: test_targets,
                     seq_len: test_seq_len}
        dd, log_probs, accuracy = session.run([decoded[0], log_prob, acc], test_feed)
        report_accuracy(dd, test_targets)
    def do_batch():
        train_inputs, train_targets, train_seq_len = get_next_batch(BATCH_SIZE)
        feed = {inputs: train_inputs, targets: train_targets, seq_len: train_seq_len}
        b_loss, b_targets, b_logits, b_seq_len, b_cost, steps, _ = session.run(
            [loss, targets, logits, seq_len, cost, global_step, optimizer], feed)
        print(b_cost, steps)
        if steps > 0 and steps % REPORT_STEPS == 0:
            do_report()
        return b_cost, steps
    with tf.Session() as session:
        session.run(init)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
        for curr_epoch in xrange(num_epochs):
            print("Epoch.......", curr_epoch)
            train_cost = train_ler = 0
            for batch in xrange(BATCHES):
                start = time.time()
                c, steps = do_batch()
                train_cost += c * BATCH_SIZE
                seconds = time.time() - start
                print("Step:", steps, ", batch seconds:", seconds)
            train_cost /= TRAIN_SIZE
            train_inputs, train_targets, train_seq_len = get_next_batch(BATCH_SIZE)
            val_feed = {inputs: train_inputs,
                        targets: train_targets,
                        seq_len: train_seq_len}
            val_cost, val_ler, lr, steps = session.run([cost, acc, learning_rate, global_step], feed_dict=val_feed)
            log = "Epoch {}/{}, steps = {}, train_cost = {:.3f}, train_ler = {:.3f}, val_cost = {:.3f}, val_ler = {:.3f}, time = {:.3f}s, learning_rate = {}"
            print(log.format(curr_epoch + 1, num_epochs, steps, train_cost, train_ler, val_cost, val_ler,
                             time.time() - start, lr))
##################### Prediction code ######################
def detect(test_inputs, test_targets, test_seq_len):
    logits, inputs, targets, seq_len, W, b = get_train_model()
    decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len, merge_repeated=False)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restore variables from disk.
        # saver.restore(sess, "models/ocr.model-0.95-94999")
        saver.restore(sess, tf.train.latest_checkpoint('.'))
        print("Model restored.")
        # feed_dict = {inputs: test_inputs, targets: test_targets, seq_len: test_seq_len}
        feed_dict = {inputs: test_inputs, seq_len: test_seq_len}
        dd = sess.run(decoded[0], feed_dict=feed_dict)
        # return decode_sparse_tensor(dd)
        original_list = decode_sparse_tensor(test_targets)
        detected_list = decode_sparse_tensor(dd)
        true_number = 0
        print("T/F: original(length) <-------> detected(length)")
        print("detected_list =", detected_list)
        for idx, number in enumerate(original_list):
            detect_number = detected_list[idx]
            print("Ground truth:")
            print(number)
            print("Prediction:")
            print(detect_number)
            if len(number) == len(detect_number):
                hit = True
                for idy, value in enumerate(number):
                    if value != detect_number[idy]:
                        hit = False
                        break
                if hit:
                    true_number = true_number + 1
        accuracy = true_number * 1.0 / len(original_list)
        print("Test Accuracy:", accuracy)
        return accuracy
##################### Test section ######################
if __name__ == '__main__':
    inputs, sparse_targets, seq_len = get_next_batch(1)
    print(detect(inputs, sparse_targets, seq_len))
    # train()
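# Usage note (my addition): detect() restores the latest checkpoint from the working
# directory, so run train() first (swap the comments above) until a model has been saved.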
##################### Code for generating the string images ######################
# -*- coding: utf-8 -*-
#!/usr/bin/env python2
"""
Generator for ID-card style text + digit images
@author: pengyuanjie
"""
import numpy as np
import freetype
import matplotlib.image as mpimg
import copy
import random
import cv2
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
class put_chinese_text(object):
    def __init__(self, ttf):
        self._face = freetype.Face(ttf)

    def draw_text(self, image, pos, text, text_size, text_color):
        '''
        Draw (possibly Chinese) text on an image with a TTF font.
        :param image: image (numpy.ndarray) to draw text on
        :param pos: where to draw the text
        :param text: the content; Chinese text should be unicode
        :param text_size: text size
        :param text_color: text color
        :return: image
        '''
        self._face.set_char_size(text_size * 64)
        metrics = self._face.size
        ascender = metrics.ascender / 64.0
        # descender = metrics.descender / 64.0
        # height = metrics.height / 64.0
        # linegap = height - ascender + descender
        ypos = int(ascender)
        if not isinstance(text, unicode):
            text = text.decode('utf-8')
        img = self.draw_string(image, pos[0], pos[1] + ypos, text, text_color)
        return img
    def draw_string(self, img, x_pos, y_pos, text, color):
        '''
        Draw a string.
        :param x_pos: text x-position on img
        :param y_pos: text y-position on img
        :param text: text (unicode)
        :param color: text color
        :return: image
        '''
        prev_char = 0
        pen = freetype.Vector()
        pen.x = x_pos << 6  # multiply by 64 (FreeType 26.6 fixed point)
        pen.y = y_pos << 6
        hscale = 1.0
        matrix = freetype.Matrix(int(hscale) * 0x10000L, int(0.2 * 0x10000L),
                                 int(0.0 * 0x10000L), int(1.1 * 0x10000L))
        cur_pen = freetype.Vector()
        pen_translate = freetype.Vector()
        image = copy.deepcopy(img)
        for cur_char in text:
            self._face.set_transform(matrix, pen_translate)
            self._face.load_char(cur_char)
            kerning = self._face.get_kerning(prev_char, cur_char)
            pen.x += kerning.x
            slot = self._face.glyph
            bitmap = slot.bitmap
            cur_pen.x = pen.x
            cur_pen.y = pen.y - slot.bitmap_top * 64
            self.draw_ft_bitmap(image, bitmap, cur_pen, color)
            pen.x += slot.advance.x
            prev_char = cur_char
        return image
    def draw_ft_bitmap(self, img, bitmap, pen, color):
        '''
        Draw each char.
        :param bitmap: bitmap
        :param pen: pen
        :param color: pen color, e.g. (0, 0, 255) - red
        :return: image
        '''
        x_pos = pen.x >> 6
        y_pos = pen.y >> 6
        cols = bitmap.width
        rows = bitmap.rows
        glyph_pixels = bitmap.buffer
        for row in range(rows):
            for col in range(cols):
                if glyph_pixels[row * cols + col] != 0:
                    img[y_pos + row][x_pos + col][0] = color[0]
                    img[y_pos + row][x_pos + col][1] = color[1]
                    img[y_pos + row][x_pos + col][2] = color[2]
class gen_id_card(object):
    def __init__(self):
        # self.words = open('AllWords.txt', 'r').read().split(' ')
        self.words = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
                      'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
        self.number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        self.char2num_dic = {'0': '0', '1': '1', '2': '2', '3': '3', '4': '4', '5': '5', '6': '6', '7': '7',
                             '8': '8', '9': '9', 'a': '11', 'b': '12', 'c': '13', 'd': '14', 'e': '15', 'f': '16',
                             'g': '17', 'h': '18', 'i': '19', 'j': '20', 'k': '21', 'l': '22', 'm': '23', 'n': '24',
                             'o': '25', 'p': '26', 'q': '27', 'r': '28', 's': '29', 't': '30', 'u': '31', 'v': '32',
                             'w': '33', 'x': '34', 'y': '35', 'z': '36', 'A': '37', 'B': '38', 'C': '39', 'D': '40',
                             'E': '41', 'F': '42', 'G': '43', 'H': '44', 'I': '45', 'J': '46', 'K': '47', 'L': '48',
                             'M': '49', 'N': '50', 'O': '51', 'P': '52', 'Q': '53', 'R': '54', 'S': '55', 'T': '56',
                             'U': '57', 'V': '58', 'W': '59', 'X': '60', 'Y': '61', 'Z': '62'}
        self.num2char_dic = {"0": "0", "1": "1", "2": "2", "3": "3", "4": "4", "5": "5", "6": "6", "7": "7",
                             "8": "8", "9": "9", "11": "a", "12": "b", "13": "c", "14": "d", "15": "e", "16": "f",
                             "17": "g", "18": "h", "19": "i", "20": "j", "21": "k", "22": "l", "23": "m", "24": "n",
                             "25": "o", "26": "p", "27": "q", "28": "r", "29": "s", "30": "t", "31": "u", "32": "v",
                             "33": "w", "34": "x", "35": "y", "36": "z", "37": "A", "38": "B", "39": "C", "40": "D",
                             "41": "E", "42": "F", "43": "G", "44": "H", "45": "I", "46": "J", "47": "K", "48": "L",
                             "49": "M", "50": "N", "51": "O", "52": "P", "53": "Q", "54": "R", "55": "S", "56": "T",
                             "57": "U", "58": "V", "59": "W", "60": "X", "61": "Y", "62": "Z"}
        self.char_set = self.number
        # self.char_set = self.words + self.number
        self.len = len(self.char_set)
        self.max_size = 18
        self.ft = put_chinese_text('shuideteshuziti.ttf')
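    # Note (my observation): char_set is restricted to self.number, so only the
    # '0'-'9' entries of char2num_dic/num2char_dic are exercised by this generator.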
    def gen_text(self, is_ran=True):
        text = ''
        vecs = np.zeros((self.max_size * self.len))
        # The only change from random_text below: the length is chosen at random
        if is_ran == True:
            size = random.randint(1, self.max_size)
        else:
            size = self.max_size
        for i in range(size):
            c = random.choice(self.char_set)
            vec = self.char2vec(c)
            text = text + c
            vecs[i * self.len:(i + 1) * self.len] = np.copy(vec)
        return text, vecs
    # Generate a random string of fixed length
    # Returns the text and the corresponding vector
    def random_text(self):
        text = ''
        vecs = np.zeros((self.max_size * self.len))
        # size = random.randint(1, self.max_size)
        size = self.max_size
        for i in range(size):
            c = random.choice(self.char_set)
            vec = self.char2vec(c)
            text = text + c
            vecs[i * self.len:(i + 1) * self.len] = np.copy(vec)
        return text, vecs
    # Render the generated text into an image; return the pixels, label and vector
    def gen_image(self):
        text, vec = self.gen_text()
        img = np.zeros([32, 256, 3])
        color_ = (255, 255, 255)  # white
        pos = (0, 0)
        text_size = 21
        image = self.ft.draw_text(img, pos, text, text_size, color_)
        # cv2.imshow('image', image[:, :, 2])
        # cv2.waitKey(0)
        # Return a single channel only; color carries no information for recognition
        return image[:, :, 2], text, vec
    # Convert a single character to a one-hot vector
    def char2vec(self, c):
        vec = np.zeros((self.len))
        for j in range(self.len):
            if self.char_set[j] == c:
                vec[j] = 1
        return vec
    # Convert a vector back to text
    def vec2text(self, vecs):
        text = ''
        v_len = len(vecs)
        for i in range(v_len):
            if vecs[i] == 1:
                text = text + self.char_set[i % self.len]
        return text
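# Round-trip check for char2vec/vec2text (my addition; assumes the .ttf font exists,
# since gen_id_card() loads it on construction):
if False:
    g = gen_id_card()
    t, v = g.gen_text()
    assert g.vec2text(v) == t  # char2vec and vec2text are inverses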
##################### Training results ######################
Original article reference: http://www.jianshu.com/p/45828b18f133