变量:创建、初始化、保存和加载

官方教程

class tf.Variable

一个变量通过调用run()来维护图中的状态,通过创建一个类实例添加一个变量到图中。
Variable()构造函数接受一个变量的初始值,该初始值可以是任何类型、任何shape的张量。初始值定义了变量的类型和shape,变量的shape通常是固定的,初始值是可以通过assign methods来改变的。
如果你想稍后改变变量的shape,需要把assign methods中的validate_shape指定为False,示例:

tf.assign(
    ref,
    value,
    validate_shape=None,
    use_locking=None,
    name=None
)
#参数
ref:一个可变张量,应该来自一个变量节点,可能还未初始化
---------------------------------------------------------------------------------------------
value:一个张量,必需和ref类型相同,要分配给变量的值
---------------------------------------------------------------------------------------------
validate_shape:可选bool。默认为True。如果为true,则操作将验证“值”的形状与要分配的张量的形状相匹配。如果为false,'ref'将呈现'value'的形状
---------------------------------------------------------------------------------------------
use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
---------------------------------------------------------------------------------------------
name:操作的名称(可选)
#return
张量ref,该张量具有一个新的值

就像任何其他Tensor变量一样,创建的变量Variable()可以用作图中其他Ops的输入。此外,所有为Tensor该类重载的运算符都将转化为变量,因此您还可以通过对变量进行算术将节点添加到图中

import tensorflow as tf

# Create a variable.
w = tf.Variable(<initial-value>, name=<optional-name>)

# Use the variable in the graph like any Tensor.
y = tf.matmul(w, ...another variable or tensor...)

# The overloaded operators are available too.
z = tf.sigmoid(w + b)

# Assign a new value to the variable with `assign()` or a related method.
w.assign(w + 1.0)
w.assign_add(1.0)

启动计算图时,必须先明确初始化变量,然后才能运行使用其值的Ops。您可以通过运行初始化操作初始化变量,从保存文件中恢复变量,或者只需运行assign一个赋值给变量的操作。实际上,变量初始化器op只是一个assignOp,它将变量的初始值赋给变量本身

# Launch the graph in a session.
with tf.Session() as sess:
    # Run the variable initializer.
    sess.run(w.initializer)
    # ...you now can run ops that use the value of 'w'...

通用的初始化模式是用一个方便的函数initialize_all_variables()来初始化所有的变量,在启动计算图后run这个op:

# Add an Op to initialize all variables.
init_op = tf.initialize_all_variables()

# Launch the graph in a session.
with tf.Session() as sess:
    # Run the Op that initializes all variables.
    sess.run(init_op)
    # ...you can now run any Op that uses variable values...

如果您需要创建一个初始值依赖于另一个变量的变量,该初始值可通过initialized_value()函数获得,这确保变量按正确的顺序初始化
所有变量都会自动收集到创建它们的图形中。 默认情况下,构造函数将新变量添加到Graph集合GraphKeys.VARIABLES, 函数all_variables()返回该集合的内容
当我们建立模型的时候,我们需要参数变量进行更好的区分和收集,比如可训练参数weights和global step等超参数;变量构造函数提供了一个trainable=<bool>参数开关,如果 True,新变量也被添加到计算图集合 GraphKeys.TRAINABLE_VARIABLES中,trainable_variables()返回此集合的内容。各种Optimizer类将该集合收集的内容列表作为默认优化变量

Creating a variable.
tf.Variable.__init__(initial_value, trainable=True, collections=None, validate_shape=True, name=None)

用initial_value初始化值来创建一个新变量,新变量被添加到集合中,默认添加到[GraphKeys.VARIABLES]中.
如果trainable参数为True,变量也会添加到GraphKeys.TRAINABLE_VARIABLES中

#参数
initial_value:一个tensor(变量的初始值)或者一个可转换为tensor的python对象,除非validate_shape设置为False,否则必须指定形状 
---------------------------------------------------------------------------------------------
trainable:如果True,将变量添加到集GraphKeys.TRAINABLE_VARIABLES,各种Optimizer类将该集合收集的内容列表作为默认优化变量
---------------------------------------------------------------------------------------------
collections: List of graph collections keys. The new variable is added to these collections. Defaults to [GraphKeys.VARIABLES].
---------------------------------------------------------------------------------------------
validate_shape:如果False,允许变量初始化为未知形状的值。如果True,initial_value的shape必须是已知的
---------------------------------------------------------------------------------------------
name:可选的,变量的名称;默认为'Variable'并且会自动扩展
#返回
A Variable

#Raises:
ValueError: If the initial value does not have a shape and `validate_shape` is `True`.

对于变量默认名称自动扩展,在此举个例子:
import tensorflow as tf
with tf.name_scope('zzh') as scope:
    a = tf.Variable(tf.random_normal([10], stddev=0.35))
    b = tf.Variable(tf.random_normal([10], stddev=0.35))
    print(a.op.name, b.op.name)
#输出
zzh/Variable zzh/Variable_1        #自动扩展为独一无二的名称
tf.Variable.initialized_value()

返回一个已经初始化变量的值,在用已经初始化变量的值来初始化另一个变量时,可用到此操作
# Initialize 'v' with a random tensor.
v = tf.Variable(tf.truncated_normal([10, 40]))
# Use `initialized_value` to guarantee that `v` has been
# initialized before its value is used to initialize `w`.
# The random values are picked only once.
w = tf.Variable(v.initialized_value() * 2.0)
tf.Variable.assign(value, use_locking=False)
给一个变量赋一个新值,和tf.assign()的作用一样
#参数
value:一个tensor
use_locking: If True, use locking during the assignment.
tf.Variable.assign_add(delta, use_locking=False)
给一个变量加上一个值
#参数
delta:tensor,要加的值
use_locking: If True, use locking during the assignment.
tf.Variable.assign_sub(delta, use_locking=False)
从一个变量中减去一个值
#参数
delta:tensor,要减去的值
use_locking: If True, use locking during the assignment.
#返回
一个tensor,该值为操作完成后的新值
tf.Variable.scatter_sub(sparse_delta, use_locking=False)
Subtracts `IndexedSlices` from this variable.
This is essentially a shortcut for `scatter_sub(self, sparse_delta.indices, sparse_delta.values)`.

# Args:
sparse_delta: `IndexedSlices` to be subtracted from this variable.
use_locking: If `True`, use locking during the operation.

# Returns:
A `Tensor` that will hold the new value of this variable after the scattered subtraction has completed.

# Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`
tf.Variable.count_up_to(limit)
一个计数器,根据计数条件是否满足控制流程;它有两个主要参数,ref,limit,表示
每次都在 ref 的基础上递增,直到等于 limit;超出限制的话就抛出异常 `OutOfRangeError`.

If no error is raised, the Op outputs the value of the variable before the increment.
This is essentially a shortcut for `count_up_to(self, limit)`.

# Args:
limit: value at which incrementing the variable raises an error.

# Returns:
A `Tensor` that will hold the variable value before the increment. If no other 
Op modifies this variable, the values produced will all be distinct.
-------------------------------------------------------------------------------------------------
例子:
x = tf.Variable(1, name='X', dtype=tf.int32)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for i in range(4):
        print(sess.run(x))
        print(sess.run(x.count_up_to(5)))
#输出
1
1
2
2
3
3
4
4

stack overflow的解释

tf.Variable.eval(session=None)
在会话中,计算并返回此变量的值,这不是一个计算图构造方法,它不会将操作添加到图形中。
This convenience method requires a session where the graph containing this variable has been launched. If no session is passed, the default session is used.

v = tf.Variable([1, 2])
init = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)
    # Usage passing the session explicitly.
    print v.eval(sess)
    # Usage with the default session.  The 'with' block
    # above makes 'sess' the default session.
    print v.eval()
#参数
session:用于评估此变量的会话;如果没有,则使用默认会话
#返回:
该变量的一个副本————numpy ndarray
其他属性
 `tf.Variable.name`
The name of this variable.
-------------------------------------------------------------------------
`tf.Variable.dtype`
The `DType` of this variable.
-------------------------------------------------------------------------
`tf.Variable.get_shape()`
The `TensorShape` of this variable.
# Returns:
A `TensorShape`.
-------------------------------------------------------------------------
 `tf.Variable.device`
The device of this variable.
-------------------------------------------------------------------------
 `tf.Variable.initializer`
The initializer operation for this variable.
-------------------------------------------------------------------------
 `tf.Variable.graph`
The `Graph` of this variable.
-------------------------------------------------------------------------
 `tf.Variable.op`
The `Operation` of this variable.

Variable helper functions

tensorflow也提供了一些列函数来管理计算图中的变量集合

tf.all_variables()

Variable()构造函数会自动地将变量添加到计算图集合中,默认是raphKeys.VARIABLES,而该函数可以非常方便的返回集合的内容
返回值:Variable对象列表

tf.trainable_variables()

返回所有可训练的参数变量(也就是设置了trainable=True的变量)
此类变量会被添加到GraphKeys.TRAINABLE_VARIABLES集合中
返回值:Variable对象列表

tf.initialize_all_variables()

这个就不做介绍了
不知道为什么tf1.7提示我要用老版本。。。。。
Returns an Op that initializes all variables.
This is just a shortcut for `initialize_variables(all_variables())`

#Returns:
An Op that initializes all variables in the graph.

tf.initialize_variables(var_list, name='init')

返回初始化变量列表的Op,该函数的功能是初始化var_list里面的变量并返回op,
可以理解为初始化指定的部分变量;如果var_list为空,代码不报错;
#参数
var_list:需要初始化的Variable对象列表
name:可选的,返回操作的名称
#返回
An Op that run the initializers of all the specified variables.

tf.assert_variables_initialized(var_list=None)

功能:检查变量是否被初始化
如果有任何一个变量未初始化,则返回的op会触发异常
#参数
var_list:待检查的变量对象列表,默认是all_variables().
#返回
An Op, or None if there are no variables.

Saving and Restoring Variables

class tf.train.Saver

保存和重载变量,Saver类将ops保存到checkpoints中并从中重载,Checkpoints是一个二进制格式的文件,主要包含从变量名到tensor值的映射关系
当你创建一个Saver对象时,你可以选择性地为检查点文件中的变量挑选变量名。默认情况下,将每个变量Variable.name属性的值
保存的checkpoint文件可以是多个,并且文件名会自动扩充,so,我们可以在训练的不同阶段进行保存模型的状态;为了避免保存无限个文件,可以设置global_step参数来进行控制,该参数不是必需的

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
-----------------------------------------------------------------------------------------------
另外,还提供了一些可选参数来管理如何保存:
max_to_keep:保留的最近检查点文件的最大数量。当新文件被创建时,旧文件被
删除;如果无或0,则保留所有检查点文件,默认为5
-----------------------------------------------------------------------------------------------
keep_checkpoint_every_n_hours:每N个小时保存一个文件,适用于长时间训练
的实验;默认值为10000,也就意味着默认禁止该功能
示例
# Create a saver.
saver = tf.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
for step in xrange(1000000):
    sess.run(..training_op..)
    if step % 1000 == 0:
        # Append the step number to the checkpoint name:
        saver.save(sess, 'my-model', global_step=step)

In addition to checkpoint files, savers keep a protocol buffer on disk with the list of recent checkpoints. This is used to manage numbered checkpoint files and by latest_checkpoint(), which makes it easy to discover the path to the most recent checkpoint. That protocol buffer is stored in a file named 'checkpoint' next to the checkpoint files.

If you create several savers, you can specify a different filename for the protocol buffer file in the call to save().

tf.train.Saver.__ init__(var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None)
构造函数
var_list参数指定要保存和重载的变量,它可以传入一个字典或者列表,来看例子吧:

v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
参数
var_list:要保存和重载的变量,它可以传入一个字典或者列表
-----------------------------------------------------------------------------------------------------
reshape:允许变量以不同的shape保存和加载
-----------------------------------------------------------------------------------------------------
sharded:The optional sharded argument, if True, instructs the saver to 
shard checkpoints per device.
-----------------------------------------------------------------------------------------------------
max_to_keep:保留的最近检查点文件的最大数量
-----------------------------------------------------------------------------------------------------
keep_checkpoint_every_n_hours:每N个小时保存一个文件,适用于长时间训练
的实验;默认值为10000,也就意味着默认禁止该功能
-----------------------------------------------------------------------------------------------------
name:syring类型,Optional name to use as a prefix when adding operations.
-----------------------------------------------------------------------------------------------------
restore_sequentially: A Bool, which if true, causes restore of different variables 
to happen sequentially within each device. This can lower memory usage when restoring very large models.
saver_def: Optional SaverDef proto to use instead of running the builder. This 
is only useful for specialty code that wants to recreate a Saver object for a previously built Graph that had a Saver. The saver_def proto should be the one returned by the as_saver_def() call of the Saver that was created for that Graph.
builder: Optional SaverBuilder to use if a saver_def was not provided. Defaults 
to BaseSaverBuilder().

Raises:
TypeError: If `var_list` is invalid.
ValueError: If any of the keys or values in `var_list` is not unique.

tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None)

#Args:
sess: A Session to use to save the variables.
--------------------------------------------------------------------------------------
save_path: string. Path to the checkpoint filename. If the saver is `sharded`, this 
is the prefix of the sharded checkpoint filename.
--------------------------------------------------------------------------------------
global_step: If provided the global step number is appended to `save_path` 
to create the checkpoint filename. The optional argument can be a Tensor, a 
Tensor name or an integer.
--------------------------------------------------------------------------------------
latest_filename: Optional name for the protocol buffer file that will contains the list
 of most recent checkpoint filenames. That file, kept in the same directory as 
the checkpoint files, is automatically managed by the saver to keep track of 
recent checkpoints. Defaults to 'checkpoint'.
#返回
checkpoint文件的路径

tf.train.Saver.restore(sess, save_path)
加载之前保存的变量,这些变量不需要初始化,因为重载变量行为本身就是一种初始化
Other utility methods

Sharing Variables

重点:

  • tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None, trainable=True, collections=None)
  • tf.variable_scope(name_or_scope, reuse=None, initializer=None)
    讲的很详细
    随后介绍tf.name_scope(),比较一下两者的区别·······························待续
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 199,711评论 5 468
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 83,932评论 2 376
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 146,770评论 0 330
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 53,799评论 1 271
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 62,697评论 5 359
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,069评论 1 276
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,535评论 3 390
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,200评论 0 254
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,353评论 1 294
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,290评论 2 317
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,331评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,020评论 3 315
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,610评论 3 303
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,694评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,927评论 1 255
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,330评论 2 346
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 41,904评论 2 341

推荐阅读更多精彩内容