Based on https://zhuanlan.zhihu.com/p/27288913, I rewrote the tf.Graph construction.
global_step = tf.Variable(0, trainable=False)  # global step counter (not passed to the optimizer below)
# placeholder
images = tf.placeholder(tf.float32, [BATCH_SIZE, 32, 32, 3], name='images')
labels = tf.placeholder(tf.int32, (BATCH_SIZE,), name='labels')
print("Done Initializing Training Placeholders")
labels are not one-hot encoded; each entry is the class index itself.
The first dimension of every placeholder is fixed to BATCH_SIZE.
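For illustration (a hedged example, assuming 10 classes): a batch of three labels is fed as a plain integer vector, not as the one-hot matrix that tf.one_hot would build.
import numpy as np
batch_labels = np.array([3, 0, 7], dtype=np.int32)  # fed directly into the labels placeholder
# The equivalent one-hot matrix (NOT used here) would be 3 x 10 with a single 1 per row.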
# Build a Graph that computes the logits predictions from the placeholder
logits = CNN(images)
# Calculate loss
loss = cal_loss(logits, labels)
# Build a Graph that trains the model with one batch of examples and
# updates the model parameters.
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
logits has shape (batch_size, 10), i.e. one score per class (the dense per-class layout, as opposed to the sparse labels).
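The CNN() function itself lives in the linked repository and is not reproduced in this post; a rough, hypothetical sketch of a network that maps a (batch_size, 32, 32, 3) input to (batch_size, 10) logits could look like this (the layer sizes are assumptions, not the repo's exact architecture):
import tensorflow as tf

def CNN(images):
    # Hypothetical stand-in: two conv/pool blocks followed by two dense layers
    net = tf.layers.conv2d(images, 32, 5, padding='same', activation=tf.nn.relu)
    net = tf.layers.max_pooling2d(net, 2, 2)            # -> 16x16x32
    net = tf.layers.conv2d(net, 64, 5, padding='same', activation=tf.nn.relu)
    net = tf.layers.max_pooling2d(net, 2, 2)            # -> 8x8x64
    net = tf.reshape(net, [-1, 8 * 8 * 64])
    net = tf.layers.dense(net, 384, activation=tf.nn.relu)
    return tf.layers.dense(net, 10)                     # (batch_size, 10) logits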
In cal_loss, logits has shape (batch_size, 10) while labels has shape (batch_size,), so the loss is computed with tf.nn.sparse_softmax_cross_entropy_with_logits.
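cal_loss is also defined in the repository; a minimal sketch consistent with the shapes above (the reduction to a scalar mean is an assumption):
import tensorflow as tf

def cal_loss(logits, labels):
    # Sparse integer labels: no one-hot conversion needed
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    return tf.reduce_mean(cross_entropy)  # average loss over the batch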
Training:
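The loop below assumes a session has already been created and the variables initialized (not shown in the original snippet); roughly:
sess = tf.Session()
sess.run(tf.global_variables_initializer())
nb_batches = data_length // BATCH_SIZE  # data_length = number of training examples (assumed)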
import time
from datetime import datetime

for step in range(1000):
    start_time = time.time()
    # Current batch number
    batch_nb = step % nb_batches
    # Current batch start and end indices
    start, end = utils.batch_indices(batch_nb, data_length, BATCH_SIZE)
    # Prepare dictionary to feed the session with
    feed_dict = {images: X_train[start:end],
                 labels: y_train[start:end]}
    # Run training step
    _, loss_value = sess.run([train_step, loss], feed_dict=feed_dict)
    duration = time.time() - start_time
    # Echo loss once in a while
    if step % 20 == 0:
        num_examples_per_step = BATCH_SIZE
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)
        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print(format_str % (datetime.now(), step, loss_value,
                            examples_per_sec, sec_per_batch))
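utils.batch_indices comes from the original repository's helper module; a plausible implementation, inferred from how it is called here (an assumption, not the repo's exact code):
def batch_indices(batch_nb, data_length, batch_size):
    # Start/end indices of batch number batch_nb, clamped to the dataset size
    start = int(batch_nb * batch_size)
    end = int((batch_nb + 1) * batch_size)
    if end > data_length:
        # Shift the window back so the last batch is still full-sized
        start = data_length - batch_size
        end = data_length
    return start, end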
Evaluation:
import math
import numpy as np

newbatch = int(math.ceil(1000 / float(BATCH_SIZE)))
preds = np.zeros((1000, NUM_CLASS), dtype=np.float32)
# The 1000 test examples are pushed through the network in batches of BATCH_SIZE (64)
for cnt in range(newbatch):
    # Compute batch start and end indices
    start, end = utils.batch_indices(cnt, 1000, BATCH_SIZE)
    # Prepare feed dictionary (only images are needed to compute logits)
    feed_dict = {images: X_test[start:end]}
    # sess.run returns a list; take its first (and only) element
    preds[start:end, :] = sess.run([logits], feed_dict=feed_dict)[0]
precision = accuracy(preds, y_test)
print('Precision of teacher after training: ' + str(precision))
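accuracy() is likewise a repo utility; a minimal sketch of what it is assumed to compute (argmax of the predictions compared against the true labels):
import numpy as np

def accuracy(preds, labels):
    # Fraction of examples where the highest-scoring class equals the true label
    correct = np.argmax(preds, axis=1) == np.asarray(labels).reshape(-1)
    return float(np.mean(correct))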
With the learning rate set to 0.1, accuracy reaches 60%.
With the learning rate set to 0.05, accuracy reaches 65%.
Link: https://github.com/yingtaomj/cnn-classification