为了使用TensorFlow,我们需要理解以下5点:
- 用图表示计算
- 在Session的上下文中执行图运算
- 用Tensor表示数据
- 用Variable维护图的状态
- feed给op输入,fetch到op的输出
Overview
pass
什么是计算图
TF程序可以被结构化为两部分:构造图(construction phase)、执行图(execution phase)。在构造图阶段,组装一个图;执行阶段在Session中执行ops。
例如,我们一般在构造阶段创建一个图来表示和训练神经网络(设定训练的方法),在执行阶段,不断重复地执行图中训练相关的ops。
如何构建图
import tensorflow as tf
# Create a Constant op that produces a 1x2 matrix. The op is
# added as a node to the default graph.
#
# The value returned by the constructor represents the output
# of the Constant op.
matrix1 = tf.constant([[3., 3.]])
# Create another Constant that produces a 2x1 matrix.
matrix2 = tf.constant([[2.],[2.]])
# Create a Matmul op that takes 'matrix1' and 'matrix2' as inputs.
# The returned value, 'product', represents the result of the matrix
# multiplication.
product = tf.matmul(matrix1, matrix2)
上面的代码中没有指明图的类型,所以是默认图。对于我们的例子,默认图足够用了。在我们上面创建的默认图中,有三个结点:2个constant op、1个matmul op。为了执行我们定义好的op,需要在session中启动图。
在Session中启动图
构造好一个图之后,我们需要先创建一个Session对象,然后启动图运算。创建Session的时候,如果不设定任何参数,就创建默认图。
完整的Session API见于Session Class.
# Launch the default graph.
sess = tf.Session()
# To run the matmul op we call the session 'run()' method, passing 'product'
# which represents the output of the matmul op. This indicates to the call
# that we want to get the output of the matmul op back.
#
# All inputs needed by the op are run automatically by the session. They
# typically are run in parallel.
#
# The call 'run(product)' thus causes the execution of three ops in the
# graph: the two constants and matmul.
#
# The output of the op is returned in 'result' as a numpy `ndarray` object.
result = sess.run(product)
print(result)
# ==> [[ 12.]]
# Close the Session when we're done.
sess.close()
注意执行完毕后要关闭Session,释放资源。当然也可以使用Python的with
语句创建一个Session块,执行完毕后,由块负责自动关闭Session。
with tf.Session() as sess:
result = sess.run([product])
print(result)
TF自动根据图的定义翻译到可执行op,然后把计算任务分配到不同的计算资源上,例如CPU、GPU。我们不必显式指定CPU或GPU。如果有可利用的GPU,TF会默认使用第一个GPU,并分配尽可能多的计算任务。
如果想使用机器上的其他GPU(非第一个GPU),需要显式地把op分配到指定的GPU上。使用with ... Device
可以指定op所需的CPU或GPU:
with tf.Session() as sess:
with tf.device("/gpu:1"):
matrix1 = tf.constant([[3., 3.]])
matrix2 = tf.constant([[2.],[2.]])
product = tf.matmul(matrix1, matrix2)
...
Device是通过字符串指定的,现在支持的Device包括:
*"/cpu:0"
: 机器上的CPU.
*"/gpu:0"
: 1号GPU,如果你有一个的话。
*"/gpu:1"
: 2号GPU,以此类推。
更多关于TF和GPU的信息见这里使用GPU
在分布式Session中启动图
pass
Tensors
TF的程序只是用Tensor这种数据结构,ops之间传递的数据也都是Tensor。你可以把TF的tensor想象成多维数组或列表。tensor有静态类型(staic type)、秩(rank)、形状(shape). Rank、Shape和Type有更加详细的介绍。
Variables
图运算过程中的状态需要Variable维护,下面的例子是把Variable用作计数器使用。详细信息见Variable
# Create a Variable, that will be initialized to the scalar value 0.
state = tf.Variable(0, name="counter")
# Create an Op to add one to `state`.
one = tf.constant(1)
new_value = tf.add(state, one)
update = tf.assign(state, new_value)
# Variables must be initialized by running an `init` Op after having
# launched the graph. We first have to add the `init` Op to the graph.
init_op = tf.initialize_all_variables()
# Launch the graph and run the ops.
with tf.Session() as sess:
# Run the 'init' op
sess.run(init_op)
# Print the initial value of 'state'
print(sess.run(state))
# Run the op that updates 'state' and print 'state'.
for _ in range(3):
sess.run(update)
print(sess.run(state))
# output:
# 0
# 1
# 2
# 3
代码中的assign()和add()是图表达式中的一部分,所以只有当真正run()
的时候才会执行计算。
通常,我们会把统计模型中的参数看作是多个Variable。例如,把神经网络中的权重看成是Variable tensor。训练过程中通过重复运行训练图更新权重tensor。
Fetches
为了得到op运算的结果,只需要调用Session对象的run()
方法,需要获取的结果作为run()
的参数。
input1 = tf.constant([3.0])
input2 = tf.constant([2.0])
input3 = tf.constant([5.0])
intermed = tf.add(input2, input3)
mul = tf.mul(input1, intermed)
with tf.Session() as sess:
result = sess.run([mul, intermed])
print(result)
# output:
# [array([ 21.], dtype=float32), array([ 7.], dtype=float32)]
Feeds
上面的例子都是借助constant或者Variable类型,把tensor输入到计算图中。TF提供了更加直接的方式,即feed机制,方便把tensor直接输入到计算图中任何op。
使用feed的方式是,在你需要feed的op处调用run()
方法,并提供输入数据,类似于run(op_name, feed_data)
,其中要用到tf.placeholder()
占位符:
input1 = tf.placeholder(tf.float32)
input2 = tf.placeholder(tf.float32)
output = tf.mul(input1, input2)
with tf.Session() as sess:
print(sess.run([output], feed_dict={input1:[7.], input2:[2.]}))
# output:
# [array([ 14.], dtype=float32)]
运行的时候,如果不给占位符提供feed就会报错。较大规模的feed例子见MNIST fully-connected教程。