cifar10 官方例子详解

本文按照程序执行的顺序进行详解。执行: python cifar10_train.py 进行训练:

1、首先进入 cifar10_train.py 的 main() 函数。先 调用 cifar10.py 的 maybe_download_and_extract() 下载数据

-------- cifar10.maybe_download_and_extract() --------

2、maybe_download_and_extract() 中先拼接出数据文件的路径 filepath (exp: /tmp/cifar10_data/cifar-10-binary.tar.gz)

3、初始调用时数据文件不存在需要下载,调用 urllib.request.urlretrieve 进行下载

4、调用 tarfile.open(filepath, 'r:gz').extractall(dest_directory) 对下载的压缩包进行解压,后返回  cifar10_train.py 的 main() 函数

-------- main() --------

5、重置 train_dir:如果存在,则先删掉,再创建;通过调用 tensorflow.python.platform 中的 gfile 里的 Exists 、 DeleteRecursively 、 MakeDirs 方法来实现

6、开始训练,进入 train()

-------- train() --------

7、调用 tf.Graph().as_default() 创建一个图,并作为以下所有操作默认的图。通过 with,将以下所有的操作都限定在该图中;

8、创建变量 global_step ,初始化为 0 ,后续作为 train_op 操作的输入参数

9、调用 cifar10.distorted_inputs() 获取 images (128, 24, 24, 3) 和 labels (128,)

-------- cifar10.distorted_inputs() --------

10、得到数据所在目录 data_dir (exp: /tmp/cifar10_data/cifar-10-batches-bin),再调用 cifar10_input.distorted_inputs(data_dir,batches_size) 返回数据

-------- cifar10_input.distorted_inputs() --------

11、得到数据文件路径数组 filenames ,包含有 5 个数据文件的路径:

filenames = ['/tmp/cifar10_data/cifar-10-batches-bin/data_batch_1.bin',

                    '/tmp/cifar10_data/cifar-10-batches-bin/data_batch_2.bin',

                    '/tmp/cifar10_data/cifar-10-batches-bin/data_batch_3.bin',

                    '/tmp/cifar10_data/cifar-10-batches-bin/data_batch_4.bin',

                    '/tmp/cifar10_data/cifar-10-batches-bin/data_batch_5.bin']

12、filename_queue = tf.train.string_input_producer(filenames) 生成一个文件队列对象,然后将该队列传入 read_cifar10(filename_queue) 读取数据

-------- cifar10_input.read_cifar10() --------

13、定义要返回的对象 result = CIFAR10Record() 

result.height = 32   #图片高度

result.width = 32    #图片宽度

result.depth = 3    #图片深度,RGB 三色

13、cifar10 的数据文件为二进制文件,其中每个记录的长度是固定的,1 个字节的标签,然后 3072 字节的图像数据。 record_bytes 即为每个记录的长度 —— 1+3072=3073 ,通过 FixedLengthRecordReader(record_bytes=record_bytes) 生成一个阅读操作器 reader

14、result.key, value = reader.read(filename_queue) 给 reader 传入I/O类型的参数filename_queue,返回一个 tensor(我们现在写的这些读取的代码,仅仅是在画 graph,在操作 run 执行并不会真的执行。仅代表 graph 中的一个节点)

15、record_bytes = tf.decode_raw(value, tf.uint8) 操作将一个字符串转化为一个 unit8 张量

16、将张量 record_bytes 中的第一个字符——标签取出,转化为 int32 类型,赋值给 result.label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32)

17、再从张量 record_bytes 中取出第二部分——图片数据,原始图片数据是 [depth, height, width] ,需要通过 tf.transpose 函数转化为 [height, width, depth] 。此时可得到一张 unit8 格式的 [height, width, depth] 图片矩阵,赋值给 result.uint8image 

18、最终返回的 result 是一个 CIFAR10Record 对象,包含:

    height: 图片高度

    width: 图片宽度

    depth: 图片通道

    key: 描述 filename 和 record number 的 Tensor

    label: a int32 Tensor with the label in the range 0..9.

    uint8image: a [height, width, depth] uint8 Tensor with the image data

-------- cifar10_input.distorted_inputs() --------

19、继续对返回的数据进行处理、预处理。首先是 将图片格式从 unit8 转为 float32

        reshaped_image = tf.cast(read_input.uint8image, tf.float32)

20、对图片进行扩充,通过 随机裁剪 —— tf.random_crop 、左右翻转 —— tf.image.random_flip_left_right 、亮度变化 —— tf.image.random_brightness 、对比度变化 —— tf.image.random_contrast 、归一化处理 —— tf.image.per_image_standardization

21、再调用 _generate_image_and_label_batch ,将多个图片 Tensor 合并成 batch-Tensor 

-------- cifar10_input._generate_image_and_label_batch() --------

22、通过调用 tf.train.shuffle_batch 对样本进行乱序批处理,大致原理是,将样本的 Tensor 按顺序压到一个队列 RandomShuffleQueue 中,直到样本个数达到 capacity ,然后需要的时候随机从中取出 batch_size 个样本

images, label_batch = tf.train.shuffle_batch(

        [image, label],

        batch_size=batch_size,

        num_threads=num_preprocess_threads,

        capacity=min_queue_examples + 3 * batch_size,

        min_after_dequeue=min_queue_examples)

23、最后通过 tf.summary.image('images', images) 将 images 保存到 tensorflow board 中

24、最后返回的是 一个 batch_size 的 images 和 labels 的 Tensor ,返回给 cifar10_input.distorted_inputs() ,再返回给 cifar10.distorted_inputs() ,再返回给 cifar10_train.train()

-------- 数据处理结束 --------

-------- cifar10_train.train() --------

25、数据处理之后,就是要在 graph 中增加 关于神经网络 model 的 操作,调用 cifar10.inference(images)

-------- cifar10.inference() --------

26、在 该 model 中,共有 conv1 、pool1 、norm1 、 conv2 、pool2 、norm2 、 local3 、 local4 、 softmax 层,最终返回 softmax 层。

27、conv1 层,先创建一个 scope ,主要是对改成的变量进行统一命名:

        with tf.variable_scope('conv1') as scope:

28、conv1 层, 再调用 _variable_with_weight_decay 初始化卷积核参数

-------- cifar10._variable_with_weight_decay() --------

29、初始化参数 ,初始化使用 tf.truncated_normal_initializer 即截断正太分布,使用 tf.get_variable 来初始化参数,得到初始化参数 var ,并且在 wd > 0 时,会增加 L2范式稀疏化: weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss') ,wd 为衰减系数。然后使用 tf.add_to_collection('losses',weight_decay) 将 weight_decay 作为以 losses 为标签进行收集。这个例子里,只有全连接层对稀疏性有需求

-------- cifar10.inference() --------

30、conv1 层,先调用 tf.nn.conv2d ,再调用 tf.nn.bias_add ,最后调用 tf.nn.relu 生成第一个卷积层操作。该卷积层输入层数为 3,输出层数为 64

31、调用 _activation_summary 将 conv1 层输出到 tensorboard 中

32、pool1 层,通过 tf.nn.max_pool 生成池化层

33、norm1 层,通过 tf.nn.lrn ,对第一个 pooling 层进行局部响应归一化

34、conv2、pool2、norm2 类似

35、local3 层,为全连接层,高度为 384,将 pool2 展开成1纬,然后乘以权重、加上偏移,激活函数用 relu ,调用 _activation_summary 将 local3层输出到 tensorboard 中

36、local4 层,为全连接层,高度为 192,类似 local3

37、softmax 层,输出层,最终将该层返回

-------- 创建模型结束 --------

38、接下来定义 loss 操作,进入 cifar10.loss(logits, labels) ,输入 返回输出层 和 标准输出

-------- cifar10.loss() --------

39、传入的 labels 是 (batch_size,) ,需要转化为 (batch_size, 10),即将 labels = [3,5] 转化成 dense_labels = [[0,0,0,1,0,0,0,0,0,0],[0,0,0,0,1,0,0,0,0,0]]。需要先构造出 concated = [[0,3],[1,5]],代表在最终的 (batch_size, 10) 矩阵中 1 所在的坐标,然后再通过 tf.sparse_to_dense 得到最终的矩阵。其中 concated 可以通过 tf.concat([indices, sparse_labels],1) 得到,其中 indices 是 [[0],[1]] ,sparse_labels 是 [[3],[5]] 

40、计算交叉熵: tf.nn.softmax_cross_entropy_with_logits ,得到的是 [batch_size] 的张量,再调用 tf.reduce_mean 求平均,得到 cross_entropy_mean 平均交叉熵

41、最终将 cross_entropy_mean 跟 L2 的范式部分相加得到总的损失函数,并返回

-------- 创建损失函数结束 --------

42、接下来定义 train 操作,进入 train(total_loss, global_step) ,输入 总损失函数 和 step

-------- cifar10.train() --------

43、首先调用 lr = tf.train.exponential_decay 生成一个随着 steps 指数衰减的 learning_rate 

44、opt = tf.train.GradientDescentOptimizer(lr) 生成梯度递减操作 opt

45、grads = opt.compute_gradients(total_loss) 

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,530评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 86,403评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,120评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,770评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,758评论 5 367
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,649评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,021评论 3 398
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,675评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,931评论 1 299
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,659评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,751评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,410评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,004评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,969评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,203评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,042评论 2 350
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,493评论 2 343

推荐阅读更多精彩内容