Tensorflow 数据读取

TF官网上给出了三种读取数据的方式:

  1. Preloaded data: 预加载数据
  2. Feeding: Python 产生数据,再把数据喂给后端
  3. Reading from file:从文件中直接读取
    (Ps: 此处参考博客 详解TF数据读取有三种方式(next_batch))
    (Pps: 文中的代码均基于Python3.6版本)

TF的核心是用C++写的,运行快,但是调用不灵活。结合Python和TF,将计算的核心算子和运行框架用C++写,然后以API的形式提供给Python调用。Python的主要工作是设计计算图(模型及数据),将设计好的Graph提供给后端执行。简而言之,TF是Run,Pyhton的角色是Design。

一. Preloaded Data

  • constant,常量
  • variable,初始化或者后面更新均可

这种数据读取方式只适合小数据,通常在程序中定义某固定值,如循环次数等,而很少用来读取训练数据。

import tensorflow as tf
# 设计Graph
a = tf.constant([1, 2, 3])
b = tf.Variable([1, 2, 4])
c = tf.add(a, b)


二. Feeding

Feeding的方式在设计Graph的时候留占位符,在真正Run的时候向占位符中传递数据,喂给后端训练。

#!/usr/bin/env python3
# _*_coding:utf-8 _*_

import tensorflow as tf
# 设计Graph
a = tf.placeholder(tf.int16) 
b = tf.placeholder(tf.int16)
c = tf.add(a, b)
# 用Python产生数据
li1 = [2, 3, 4] # li1:<type:'list'>: [2, 3, 4]
li2 = [4, 0, 1]
# 打开一个session --> 喂数据 --> 计算y
with tf.Session() as sess:
  print(sess.run(c, feed_dict={a: li1, b: li2})) # [6, 3, 5]

这里tf.placeholder代表占位符,先定一下变量a的类型。在实际运行的时候,通过feed_dict来指定a在计算中的实际值。

这种数据读取方式非常灵活,而且易于理解,但是在读取大数据时会非常吃力。



三. Read from file

官网上给出的例子是从csv等文件中读取数据,这里都会涉及到队列的概念, 我们首先简单介绍一下Queue读取数据的原理,便于后面代码的理解。(参考 Blog

读取数据其实是为了后续的计算,以图片为例,假设我们的硬盘中有一个图片数据集0001.jpg,0002.jpg,0003.jpg……我们只需要把它们读取到内存中,然后提供给GPU或是CPU进行计算就可以了。这听起来很容易,但事实远没有那么简单。事实上,我们必须要把数据先读入后才能进行计算,假设读入用时0.1s,计算用时0.9s,那么就意味着每过1s,GPU都会有0.1s无事可做,这就大大降低了运算的效率。

队列的存在就是为了使计算的速度不完全受限于数据读取的速度,保证有足够多的数据喂给计算。如图所示,将数据的读入和计算分别放在两个线程中,读入的数据保存为内存中的一个队列,负责计算的线程可以源源不断地从内存队列中读取数据。这样就解决了GPU因为IO而空闲的问题。

Tensorflow中在内存队列之前又添加了一个文件名队列,这是因为机器学习中一般会设定epoch。对于一个数据集来说,运行一个epoch就是将这个数据集中的样本数据全部计算一遍。如图所示,当数据集结束后可以做一个标注,以此来告诉计算机这个epoch结束了。

文件名队列,我们用tf.train.string_input_producer()函数创建文件名队列。

tf.train.string_input_producer(
    string_tensor,     # 文件名列表
    num_epochs=None,   # epoch的个数,None代表无限循环
    shuffle=True,      # 一个epoch内的样本(文件)顺序是否打乱
    seed=None,         # 当shuffle=True时才用,应该是指定一个打乱顺序的入口
    capacity=32,       # 设置队列的容量
    shared_name=None,
    name=None,
    cancel_op=None)

ps: 在Tensorflow中,内存队列不需要我们自己建立,后续只需要使用reader从文件名队列中读取数据就可以。

tf.train.string_input_produecer()会将一个隐含的QueueRunner添加到全局图中(类似的操作还有tf.train.shuffle_batch()等)。由于没有显式地返回QueueRunner()来调用create_threads()启动线程,这里使用了tf.train.start_queue_runners()方法直接启动tf.GraphKeys.QUEUE_RUNNERS集合中的所有队列线程。

在我们使用tf.train.string_input_producer创建文件名队列后,整个系统其实还是处于“停滞状态”的,也就是说,我们文件名并没有真正被加入到队列中(如下图所示)。此时如果我们开始计算,因为内存队列中什么也没有,计算单元就会一直等待,导致整个系统被阻塞。

而使用tf.train.start_queue_runners()之后,才会启动填充队列的线程,这时系统就不再“停滞”。此后计算单元就可以拿到数据并进行计算,整个程序也就跑起来了,这就是函数tf.train.start_queue_runners的用处。

在读取文件的整个过程中会涉及到:

  • 文件名队列创建: tf.train.string_input_producer()
  • 文件阅读器: tf.TFRecordReader()
  • 文件解析器:tf.parse_single_example() 或者decode_csv()
  • Batch_size:tf.train.shuffle_batch()
  • 填充进程:tf.train.start_queue_runners()

下面我们用python生成数据,并将数据转换成tfrecord格式,然后读取tfrecord文件。在这过程中,我们会介绍几种不同的从文件读取数据的方法。

生成数据:

#!/usr/bin/env python3 
# _*_coding:utf-8 _*_

import os
import numpy as np
'''
二分类问题,样本数据是形如1,2,5,8,9(1*5)的随机数,对应标签是0或1
arg:
    data_filename: 路径下的文件名 'data/data_train.txt'
    size: 设定生成样本数据的size=(10000, 5),其中10000是样本个数,5是单个样本的特征。
'''
gene_data = 'data/data_train.txt'
size = (100000, 5)
def generate_data(gene_data, size):
    if not os.path.exists(gene_data):
        np.random.seed(9)
        x_data = np.random.randint(0, 10, size=size)
        # 这里设置标签值一半样本是0,一半样本是1
        y1_data = np.ones((size[0]//2, 1), int) # 这里需要注意python3和python2的区别。
        y2_data = np.zeros((size[0]//2, 1), int) # python2用/得到整数,python3要用//。否则会报错“'float' object cannot be interpreted as an integer”
        y_data = np.append(y1_data, y2_data)
        np.random.shuffle(y_data)

        # 将样本和标签以1 2 3 6 8/1的形式来保存
        xy_data = str('')
        for xy_row in range(len(x_data)):
            x_str = str('')
            for xy_col in range(len(x_data[0])):
                if not xy_col == (len(x_data[0])-1):
                    x_str =x_str+str(x_data[xy_row, xy_col])+' '
                else:
                    x_str = x_str + str(x_data[xy_row, xy_col])
            y_str = str(y_data[xy_row])
            xy_data = xy_data+(x_str+'/'+y_str + '\n')
        #print(xy_data[1])

        # write to txt 保存成txt格式
        write_txt = open(gene_data, 'w')
        write_txt.write(xy_data)
        write_txt.close()
    return
# generate_data(gene_data=gene_data, size=size) # 取消注释后可以直接生成数据

从txt文件中读取数据,并转换成TFrecord格式

tfrecord数据文件是一种将数据和标签统一存储的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储等。

TFRecord 文件中的数据是通过 tf.train.Example() 以 Protocol Buffer(协议缓冲区) 的格式存储。Protocol Buffer是Google的一种数据交换的格式,他独立于语言,独立于平台,以二进制的形式存在,能更好的利用内存,方便复制和移动。
tf.train.Example()包含Features字段,通过feature将数据和label进行统一封装, 然后将example协议内存块转化为字符串。tf.train.Features()是字典结构,包括字符串格式的key,可以自己定义key。与key对应的是value值,这里需要注意的是,feature的value值只支持列表,可以是字符串(Byteslist),浮点数列表(Floatlist)和整型数列表(int64list),所以,在给value赋值时一定要注意类型将数据转换为这三种类型的列表。

  • 类型为标量:如0,1标签,转为列表。 tf.train.Int64List(value=[label])
  • 类型为数组:sample = [1, 2, 3],tf.train.Int64List(value=sample)
  • 类型为矩阵:sample = [[1, 2, 3], [1, 2 ,3]],
    两种方式:
    转成list类型:将张量fatten成list(向量)
    转成string类型:将张量用.tostring()转换成string类型。
    同时要记得保存形状信息,在读取后恢复shape。
'''
读取txt中的数据,并将数据保存成tfrecord文件
arg:
    txt_filename: 是txt保存的路径+文件名 'data/data_train.txt'
    tfrecord_path:tfrecord文件将要保存的路径及名称 'data/test_data.tfrecord'
'''
def txt_to_tfrecord(txt_filename=gene_data, tfrecord_path=tfrecord_path):
    # 第一步:生成TFRecord Writer
    writer = tf.python_io.TFRecordWriter(tfrecord_path)

    # 第二步:读取TXT数据,并分割出样本数据和标签
    file = open(txt_filename)
    for data_line in file.readlines(): # 每一行
        data_line = data_line.strip('\n') # 去掉换行符
        sample = []
        spls = data_line.split('/', 1)[0]# 样本
        for m in spls.split(' '):
            sample.append(int(m))
        label = data_line.split('/', 1)[1]# 标签
        label = int(label)
        # print('sample:', sample, 'labels:', label)

        # 第三步: 建立feature字典,tf.train.Feature()对单一数据编码成feature
        feature = {'sample': tf.train.Feature(int64_list=tf.train.Int64List(value=sample)),
                   'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}
        # 第四步:可以理解为将内层多个feature的字典数据再编码,集成为features
        features = tf.train.Features(feature = feature)
        # 第五步:将features数据封装成特定的协议格式
        example = tf.train.Example(features=features)
        # 第六步:将example数据序列化为字符串
        Serialized = example.SerializeToString()
        # 第七步:将序列化的字符串数据写入协议缓冲区
        writer.write(Serialized)
    # 记得关闭writer和open file的操作
    writer.close()
    file.close()
    return
# txt_to_tfrecord(txt_filename=gene_data, tfrecord_path=tfrecord_path)

所以在上面的程序中我们涉及到了读取txt文本数据,并将数据写成tfrecord文件。在网络训练过程中数据的读取通常是对tfrecord文件的操作。

TF读取tfrecord文件有两种方式:一种是Queue方式,就是上面介绍的队列,另外一种是用dataset来读取。先介绍Queue读取文件数据的方法

1. Queue方式

Queue读取数据可以分为两种:tf.parse_single_example()和tf.parse_example()

(1). tf.parse_single_example()读取数据

tf.parse_single_example(
    serialized,  # 张量
    features,  # 对应写入的features
    name=None,
    example_names=None)
'''
用tf.parse_single_example()读取并解析tfrecord文件
args: 
      filename_queue: 文件名队列
      shuffle_batch: 判断在batch的时候是否要打乱顺序
      if_enq_many: 设定batch中的参数enqueue_many,评估该参数的作用
'''
# 第一步: 建立文件名队列,可设置Epoch次数
filename_queue = tf.train.string_input_producer([tfrecord_path], num_epochs=3)

def read_single(filename_queue, shuffle_batch, if_enq_many):
    # 第二步: 建立阅读器
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # 第三步:根据写入时的格式建立相对应的读取features
    features = {
        'sample': tf.FixedLenFeature([5], tf.int64),# 如果不是标量,一定要在这里说明数组的长度
        'label': tf.FixedLenFeature([], tf.int64)
    }
    # 第四步: 用tf.parse_single_example()解析单个EXAMPLE PROTO
    Features = tf.parse_single_example(serialized_example, features)

    # 第五步:对数据进行后处理
    sample = tf.cast(Features['sample'], tf.float32)
    label = tf.cast(Features['label'], tf.float32)
    # 第六步:生成Batch数据 generate batch
    if shuffle_batch:  # 打乱数据顺序,随机取样
        sample_single, label_single = tf.train.shuffle_batch([sample, label],
                                                 batch_size=2,
                                                 capacity=200000,
                                                 min_after_dequeue=10000,
                                                 num_threads=1,
                                                 enqueue_many=if_enq_many)# 主要是为了评估enqueue_many的作用
    else:  # # 如果不打乱顺序则用tf.train.batch(), 输出队列按顺序组成Batch输出
        sample_single, label_single = tf.train.batch([sample, label],
                                                batch_size=2,
                                                capacity=200000,
                                                min_after_dequeue=10000,
                                                num_threads=1,
                                                enqueue_many = if_enq_many)
    return sample_single, label_single
x1_samples, y1_labels = read_single(filename_queue=filename_queue, 
shuffle_batch=False, if_enq_many=False)
x2_samples, y2_labels = read_single(filename_queue=filename_queue, 
shuffle_batch=True, if_enq_many=False)
print(x1_samples, y1_labels) # 因为是tensor,这里还处于构造tensorflow计算图的过程,输出仅仅是shape等,不会是具体的数值。
# 如果想得到具体的数值,必须建立session,是tensor在计算图中流动起来,也就是用session.run()的方式得到具体的数值。
# 定义初始化变量范围
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)  # 初始化
    # 如果tf.train.string_input_producer([tfrecord_path], num_epochs=3)中num_epochs不为空的化,必须要初始化local变量
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()  # 管理线程
    threads = tf.train.start_queue_runners(coord=coord)  # 文件名开始进入文件名队列和内存
    for i in range(1):
        # Queue + tf.parse_single_example()读取tfrecord文件
        X1, Y1 = sess.run([x1_samples, y1_labels])
        print('X1: ', X1, 'Y1: ', Y1) # 这里就可以得到tensor具体的数值
        X2, Y2 = sess.run([x2_samples, y2_labels])
        print('X2: ', X2, 'Y2: ', Y2) # 这里就可以得到tensor具体的数值
    coord.request_stop()
    coord.join(threads)

Ps: 如果建立文件名tf.train.string_input_producer([tfrecord_path], num_epochs=3)时, 设置num_epochs为具体的值(不是None)。在初始化的时候必须对local_variables进行初始化sess.run(tf.local_variables_initializer())。否则会报错:
OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 2, current size 0)

上面第六步batch前取到的是单个样本数据,在实际训练中通常用批量数据来更新参数,设置批量读取数据的时候有按顺序读取数据的tf.train.batch()和打乱数据出列顺序的tf.train.shuffle_batch()。假设文本中的数据如图所示:

设置batch_size=2, shuffle_batch=True和False时的输出分别为:

X11:  [[5. 6. 8. 6. 1.] [6. 4. 8. 1. 8.]] Y11:  [1. 1.] #用tf.train.batch()
X21:  [[0. 4. 3. 7. 8.] [5. 0. 2. 8. 7.]] Y21:  [0. 1.] # 用tf.train.shuffle_batch()

这里需要对tf.train.shuffle_batch()和tf.train.batch()的参数进行说明

tf.train.shuffle_batch(
    tensors,
    batch_size, # 设置batch_size的大小
    capacity,  # 设置队列中最大的数据量,容量。一般要求capacity > min_after_dequeue + num_threads*batch_size
    min_after_dequeue, # 队列中最小的数据量作为随机取样的缓冲区。越大,数据混合越充分,认为采样到的数据更具有随机性。
    # 但是这个值设置太大在初始启动时,需要给队列喂足够多的数据,启动慢,而且占用内存。
    num_threads=1, # 设置线程数
    seed=None,
    enqueue_many=False, # Whether each tensor in tensor_list is a single example. 在下面单独说明
    shapes=None,
    allow_smaller_final_batch=False, # (Optional) Boolean. If True, allow the final batch to be smaller if there are insufficient items left in the queue.
    shared_name=None,
    name=None)
tf.train.batch(
    tensors,
    batch_size,
    num_threads=1,
    capacity=32,
    enqueue_many=False,
    shapes=None,
    dynamic_pad=False,
    allow_smaller_final_batch=False,
    shared_name=None,
    name=None)  # 注意:这里没有min_after_dequeue这个参数

读取数据的目的是为了训练网络,而使用Batch训练网络的原因可以解释为:

深度学习的优化说白了就是梯度下降。每次的参数更新有两种方式。

  • 第一种,遍历全部数据集算一次损失函数,然后算函数对各个参数的梯度,更新梯度。这种方法每更新一次参数都要把数据集里的所有样本都看一遍,计算量开销大,计算速度慢,不支持在线学习,这称为Batch gradient descent,批梯度下降。
  • 另一种,每看一个数据就算一下损失函数,然后求梯度更新参数,这个称为随机梯度下降,stochastic gradient descent。这个方法速度比较快,但是收敛性能不太好,可能在最优点附近晃来晃去,hit不到最优点。两次参数的更新也有可能互相抵消掉,造成目标函数震荡的比较剧烈。
    为了克服两种方法的缺点,现在一般采用的是一种折中手段,mini-batch gradient decent,小批的梯度下降,这种方法把数据分为若干个批,按批来更新参数,这样,一个批中的一组数据共同决定了本次梯度的方向,下降起来就不容易跑偏,减少了随机性。另一方面因为批的样本数与整个数据集相比小了很多,计算量也不是很大。

个人理解:大Batch_size一是会受限于计算机硬件,另一方面将会降低梯度下降的随机性。 而小Batch_size收敛速度慢

这里用代码对enqueue_many这个参数进行理解

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np

tensor_list = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]

with tf.Session() as sess:
    x1 = tf.train.batch(tensor_list, batch_size=3, enqueue_many=False)
    x2 = tf.train.batch(tensor_list, batch_size=3, enqueue_many=True)
    x3 = tf.train.shuffle_batch(tensor_list, batch_size=3, capacity = 1000, min_after_dequeue=100, num_threads=1, enqueue_many=False)
    x4 = tf.train.shuffle_batch(tensor_list, batch_size=3, capacity = 1000, min_after_dequeue=100, num_threads=1, enqueue_many=True)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    print("x1 batch:" + "-" * 10)
    print(sess.run(x1))

    print("x2 batch:" + "-" * 10)
    print(sess.run(x2))
    print("x2 batch:" + "-" * 10)
    print(sess.run(x2))

    print("x3 batch:" + "-" * 10)
    print(sess.run(x3))

    print("x4 batch:" + "-" * 10)
    print(sess.run(x4))

    coord.request_stop()
    coord.join(threads)

输出如下:

由以上输出可以看出,当enqueue_many=False(默认值)时,输出为batch_size*tensor.shape,把输入tensors看作一个样本,Batch就是对第一个维度的数据进行重复采样,将tensor扩展一个维度。
当enqueue_many=True时,tensor是一个样本,batch_size只是调整样本中的维度。这里tensor的维度保持不变,只是在最后一个维度上根据batch_size调整了大小。而最后一个维度内的顺序是乱序的。
对于shuffle_batch,注意到,第1维(矩阵每一行)上的数据是打乱的,所以从[1, 2, 3, 4]中取到了[2, 4, 4]。
如果输入的样本是一个3x6的矩阵。设置batch_size=5,enqueue_many = False时,tensor会被扩展为3x6x5的张量, 并且。当enqueue_many = True时,tensor是3x5,第二个维度上截取size。
这里比较疑惑的是shuffle在这里感觉没有任何作用???

(2). tf.parse_example()读取数据

'''
用tf.parse_example()批量读取数据,据说比tf.parse_single_exaple()读取数据的速度快(没有验证)
args:
      filename_queue: 文件名队列
      shuffle_batch: 是否批量读取数据
      if_enq_many: batch时enqueue_many参数的设定,这里主要用于评估该参数的作用
'''
# 第一步: 建立文件名队列
filename_queue = tf.train.string_input_producer([tfrecord_path])
def read_parse(filename_queue, shuffle_batch, if_enq_many):
    # 第二步: 建立阅读器
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # 第三步: 设置shuffle_batch
    if shuffle_batch:
        batch = tf.train.shuffle_batch([serialized_example],
                               batch_size=3,
                               capacity=10000,
                               min_after_dequeue=1000,
                               num_threads=1,
                               enqueue_many=if_enq_many)# 主要是为了评估enqueue_many的作用

    else:
        batch = tf.train.batch([serialized_example],
                               batch_size=3,
                               capacity=10000,
                               num_threads=1,
                               enqueue_many=if_enq_many)
        # 第四步:根据写入时的格式建立相对应的读取features
    features = {
        'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是标量,一定要在这里说明数组的长度
        'label': tf.FixedLenFeature([], tf.int64)
    }
    # 第五步: 用tf.parse_example()解析多个EXAMPLE PROTO
    Features = tf.parse_example(batch, features)

    # 第六步:对数据进行后处理
    samples_parse= tf.cast(Features['sample'], tf.float32)
    labels_parse = tf.cast(Features['label'], tf.float32)
    return samples_parse, labels_parse

x2_samples, y2_labels = read_parse(filename_queue=filename_queue, shuffle_batch=True, if_enq_many=False)
print(x2_samples, y2_labels)
# 定义初始化变量范围
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)  # 初始化
    coord = tf.train.Coordinator()  # 管理线程
    threads = tf.train.start_queue_runners(coord=coord)  # 文件名开始进入文件名队列和内存
    for i in range(1):
        X2, Y2 = sess.run([x2_samples, y2_labels])
        print('X2: ', X2, 'Y2: ', Y2)

    coord.request_stop()
    coord.join(threads)

调试的时候这里碰到一个bug,提示:return处local variable 'samples_parse' referenced before assignment。网上给的解决办法基本是python在自上而下执行的时候无法区分变量是全局变量还是局部变量。实际上是我在写第四步/第五步的时候多了缩进,导致没有定义features。(⚠️:python对缩进敏感)

⚠️ 阅读器 + 样本

根据以上例子,假设txt中的数据只有2个样本,如下图所示:

在建立文件名队列时,加入这两个txt文档的文件名

# 第一步: 建立文件名队列
filename_queue = tf.train.string_input_producer([tfrecord_path, tfrecord_path1])

(1). 单个阅读器 + 单个样本

batch_size=1 (注意:这里先将num_threads设置为1)

sample_single, label_single = tf.train.batch([sample, label],
                                                 batch_size=1,
                                                 capacity=10000,     
                                                 num_threads=1,
                                                 enqueue_many=if_enq_many)
    for i in range(5):
        X14, Y14 = sess.run([x14_samples, y14_labels])
        print('X14: ', X14, 'Y14: ', Y14)

打印输出结果为:

('X14: ', array([[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([0.], dtype=float32))
('X14: ', array([[8., 3., 5., 3., 6.]], dtype=float32), 'Y14: ', array([0.], dtype=float32))
('X14: ', array([[5., 6., 8., 6., 1.]], dtype=float32), 'Y14: ', array([1.], dtype=float32))
('X14: ', array([[6., 4., 8., 1., 8.]], dtype=float32), 'Y14: ', array([1.], dtype=float32))
('X14: ', array([[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([0.], dtype=float32))

(2). 单个阅读器 + 多个样本

batch_size = 3
输出结果为:

('X14: ', array([[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.]], dtype=float32), 'Y14: ', array([0., 0., 1.], dtype=float32))
('X14: ', array([[6., 4., 8., 1., 8.],[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.]], dtype=float32), 'Y14: ', array([1., 1., 1.], dtype=float32))
('X14: ', array([[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.],[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([0., 0., 0.], dtype=float32))
('X14: ', array([[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.]], dtype=float32), 'Y14: ', array([0., 1., 1.], dtype=float32))
('X14: ', array([[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.]], dtype=float32), 'Y14: ', array([0., 0., 1.], dtype=float32))

(3). 多个阅读器 + 多个样本

多阅读器需要用tf.train.batch_join()或者tf.train.shuffle_batch_join(),对程序作稍微的修改

example_list = [[sample, label] for _ in range(2)]  # Reader设置为2
sample_single, label_single = tf.train.batch_join(example_list, batch_size=3)

输出结果为:

('X14: ', array([[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.],[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([1., 1., 0.], dtype=float32))
('X14: ', array([[8., 3., 5., 3., 6.],[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.]], dtype=float32), 'Y14: ', array([0., 0., 0.], dtype=float32))
('X14: ', array([[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.],[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([1., 1., 0.], dtype=float32))
('X14: ', array([[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.]], dtype=float32), 'Y14: ', array([0., 1., 1.], dtype=float32))
('X14: ', array([[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.]], dtype=float32), 'Y14: ', array([0., 0., 1.], dtype=float32))

从输出结果来看,单个阅读器+多个样本多个阅读器+多个样本在结果呈现时并没有什么区别,至于对运行速度的影响还有待验证。

附上对阅读器进行测试的完整代码:

# -*- coding: UTF-8 -*-
# !/usr/bin/python3
# Env: python3.6
import tensorflow as tf
import numpy as np
import os

data_filename1 = 'data/data_train1.txt'  # 生成txt数据保存路径
data_filename2 = 'data/data_train2.txt'  # 生成txt数据保存路径
tfrecord_path1 = 'data/test_data1.tfrecord'  # tfrecord1文件保存路径
tfrecord_path2 = 'data/test_data2.tfrecord'  # tfrecord2文件保存路径

##############################  读取txt文件,并转为tfrecord文件 ###########################
# every line of data is just as follow: 1 2 3 4 5/1. train data: 1 2 3 4 5, label: 1
def txt_to_tfrecord(txt_filename, tfrecord_path):
    # 第一步:生成TFRecord Writer
    writer = tf.python_io.TFRecordWriter(tfrecord_path)

    # 第二步:读取TXT数据,并分割出样本数据和标签
    file = open(txt_filename)
    for data_line in file.readlines():  # 每一行
        data_line = data_line.strip('\n')  # 去掉换行符
        sample = []
        spls = data_line.split('/', 1)[0]  # 样本
        for m in spls.split(' '):
            sample.append(int(m))
        label = data_line.split('/', 1)[1]  # 标签
        label = int(label)

        # 第三步: 建立feature字典,tf.train.Feature()对单一数据编码成feature
        feature = {'sample': tf.train.Feature(int64_list=tf.train.Int64List(value=sample)),
                   'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}
        # 第四步:可以理解为将内层多个feature的字典数据再编码,集成为features
        features = tf.train.Features(feature=feature)
        # 第五步:将features数据封装成特定的协议格式
        example = tf.train.Example(features=features)
        # 第六步:将example数据序列化为字符串
        Serialized = example.SerializeToString()
        # 第七步:将序列化的字符串数据写入协议缓冲区
        writer.write(Serialized)
    # 记得关闭writer和open file的操作
    writer.close()
    file.close()
    return
txt_to_tfrecord(txt_filename=data_filename1, tfrecord_path=tfrecord_path1)
txt_to_tfrecord(txt_filename=data_filename2, tfrecord_path=tfrecord_path2)


# 第一步: 建立文件名队列
filename_queue = tf.train.string_input_producer([tfrecord_path1, tfrecord_path2])
def read_single(filename_queue, shuffle_batch, if_enq_many):
    # 第二步: 建立阅读器
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # 第三步:根据写入时的格式建立相对应的读取features
    features = {
        'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是标量,一定要在这里说明数组的长度
        'label': tf.FixedLenFeature([], tf.int64)
    }
    # 第四步: 用tf.parse_single_example()解析单个EXAMPLE PROTO
    Features = tf.parse_single_example(serialized_example, features)

    # 第五步:对数据进行后处理
    sample = tf.cast(Features['sample'], tf.float32)
    label = tf.cast(Features['label'], tf.float32)

    # 第六步:生成Batch数据 generate batch
    if shuffle_batch:  # 打乱数据顺序,随机取样
        sample_single, label_single = tf.train.shuffle_batch([sample, label],
                                                             batch_size=1,
                                                             capacity=10000,
                                                             min_after_dequeue=1000,
                                                             num_threads=1,
                                                             enqueue_many=if_enq_many)  # 主要是为了评估enqueue_many的作用
    else:  # # 如果不打乱顺序则用tf.train.batch(), 输出队列按顺序组成Batch输出

        ###################### multi reader, multi samples, please code as below     ###############################
        '''
        example_list = [[sample,label] for _ in range(2)]  # Reader设置为2

        sample_single, label_single = tf.train.batch_join(example_list, batch_size=3)
        '''
        #######################  single reader, single sample,  please set batch_size = 1   #########################
        #######################  single reader, multi samples,  please set batch_size = batch_size    ###############
        sample_single, label_single = tf.train.batch([sample, label],
                                                     batch_size=1,
                                                     capacity=10000,
                                                     num_threads=1,
                                                     enqueue_many=if_enq_many)

    return sample_single, label_single

x1_samples, y1_labels = read_single(filename_queue, shuffle_batch=False, if_enq_many=False)

# 定义初始化变量范围
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)  # 初始化
    # 如果tf.train.string_input_producer([tfrecord_path], num_epochs=30)中num_epochs不为空的化,必须要初始化local变量
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()  # 管理线程
    threads = tf.train.start_queue_runners(coord=coord)  # 文件名开始进入文件名队列和内存
    for i in range(5):
        # Queue + tf.parse_single_example()读取tfrecord文件
        X1, Y1 = sess.run([x1_samples, y1_labels])
        print('X1: ', X1, 'Y1: ', Y1)
        # Queue + tf.parse_example()读取tfrecord文件

    coord.request_stop()
    coord.join(threads)

2. Dataset + TFrecrods读取数据

这是目前官网上比较推荐的一种方式,相对于队列读取文件的方法,更为简单。
Dataset API:将数据直接放在graph中进行处理,整体对数据集进行上述数据操作,使代码更加简洁

Dataset直接导入比较简单,这里只是简单介绍:

dataset = tf.data.Dataset.from_tensor_slices([1,2,3]) # 输入必须是list

我们重点看dataset读取tfrecord文件的过程 (关于pipeline的相关信息可以参见博客)

def _parse_function(example_proto): # 解析函数
    # 创建解析字典
    dics = {  
        'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是标量,一定要在这里说明数组的长度
        'label': tf.FixedLenFeature([], tf.int64)}
    # 把序列化样本和解析字典送入函数里得到解析的样本
    parsed_example = tf.parse_single_example(example_proto, dics)
    # 对样本数据类型的变换
    # 这里得到的样本数据都是向量,如果写数据的时候对数据进行过reshape操作,可以在这里根据保存的reshape信息,对数据进行还原。
    parsed_example['sample'] = tf.cast(parsed_example['sample'], tf.float32)
    parsed_example['label'] = tf.cast(parsed_example['label'], tf.float32)

    # 返回所有feature
    return parsed_example
'''
read_dataset:
arg: tfrecord_path是需要读取的tfrecord文件路径,如tfrecord_path = ['test.tfrecord', 'test2.tfrecord'],同上面Queue方式相同,可以同时读取多个文件
'''
def read_dataset(tfrecord_path = tfrecord_path):
    # 第一步:声明 tf.data.TFRecordDataset
    # The tf.data.TFRecordDataset class enables you to stream over the contents of one or more TFRecord files as part of an input pipeline
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    # 第二步:解析样本数据。 tfrecord文件记录的是序列化的样本,因此需要对样本进行解析。
    # 个人理解:这个解析的过程,是通过上面_parse_function函数建立feature的字典。
    # 而dataset.map()是对dataset的统一操作,map操作可以理解为在每一个元素上应用一个函数,所以其输入是一个函数。
    new_dataset = dataset.map(_parse_function)
    # 创建获取数据集中样本的迭代器
    iterator = new_dataset.make_one_shot_iterator()
    # 获得下一个样本
    next_element = iterator.get_next()
    return next_element

next_element = read_dataset()
# 建立session,打印输出,查看数据是否正确
# 定义初始化变量范围
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init) # 初始化
    coord = tf.train.Coordinator() # 管理线程
    threads = tf.train.start_queue_runners(coord=coord) # 文件名开始进入文件名队列和内存
    for i in range(5):
        print('dataset:', sess.run([next_element['sample'],
                                    next_element['label']]))

    coord.request_stop()
    coord.join(threads)

输出结果如下:

('dataset:', [array([5., 6., 8., 6., 1.], dtype=float32), 1.0])
('dataset:', [array([6., 4., 8., 1., 8.], dtype=float32), 1.0])
('dataset:', [array([5., 1., 0., 8., 8.], dtype=float32), 0.0])
('dataset:', [array([8., 2., 6., 8., 1.], dtype=float32), 0.0])
('dataset:', [array([8., 3., 5., 3., 6.], dtype=float32), 0.0])

PS: 这里需要特别特别注意的是当sample 或者 label不是标量,而且长度事先无法获得的时候怎么创建解析函数。
此时 tf.FixedLenFeature(shape=(), dtype=tf.float32)的 shape 无法指定。

举例来说: sample.shape=[2,3], 在写入tfrecord的时候要对矩阵reshape,同时保存值和shape. 如果已经知道sample的长度,在解析函数中可以用上面的tf.FixedLenFeature([6,1], dtype=tf.float32)来解析。一定一定不能用tf.FixedLenFeature([6], dtype=tf.float32)。这样无法还原sample的值,而且会报出各种奇葩错误。如果不知道sample的shape,可以用tf.VarLenFeature(dtype=tf.float32)。由于变长得到的是稀疏矩阵,解析后需要进行转为密集矩阵的处理。

parsed_example['sample'] = tf.sparse_tensor_to_dense(parsed_example['sample'])

上面的代码输出是每次取一个样本,按顺序一个样本一个样本出列。如果需要打乱顺序,用.shuffle(buffer_size= ) 来打乱顺序。其中buffer_size设置成大于数据集汇总样本数量的值,以保证样本顺序充分打乱。

打乱样本出列顺序

def read_dataset(tfrecord_path = tfrecord_path):
    # 声明读tfrecord文件
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    # 建立解析函数
    new_dataset = dataset.map(_parse_function)
    # 打乱样本顺序
    shuffle_dataset = new_dataset.shuffle(buffer_size=20000)
    # 数据提前进入队列
    prefetch_dataset = batch_dataset.prefetch(2000) # 会快很多
    # 建立迭代器
    iterator = prefetch_dataset.make_one_shot_iterator()
    # 获得下一个样本
    next_element = iterator.get_next()
    return next_element

输出的结果是:

('dataset:', [array([5., 1., 1., 7., 5.], dtype=float32), 0.0])
('dataset:', [array([8., 0., 8., 2., 7.], dtype=float32), 1.0])
('dataset:', [array([6., 5., 9., 1., 2.], dtype=float32), 1.0])
('dataset:', [array([9., 9., 4., 0., 5.], dtype=float32), 0.0])
('dataset:', [array([1., 9., 9., 2., 9.], dtype=float32), 0.0])

再运行一次,取到的数据也完全不一样。已打乱顺序,单样本输出。

批量输出样本:.batch( batch_size )

def read_dataset(tfrecord_path = tfrecord_path):
    # 声明阅读器
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    # 建立解析函数
    new_dataset = dataset.map(_parse_function)
    # 打乱样本顺序
    shuffle_dataset = new_dataset.shuffle(buffer_size=20000)
    # batch输出
    batch_dataset = shuffle_dataset.batch(2)
    # 数据提前进入队列
    prefetch_dataset = batch_dataset.prefetch(2000)
    # 建立迭代器
    iterator = prefetch_dataset.make_one_shot_iterator()
    # 获得下一个样本
    next_element = iterator.get_next()
    return next_element

输出结果如下:

('dataset:', [array([[1., 4., 6., 2., 5.], [3., 7., 6., 6., 9.]], dtype=float32), array([0., 0.], dtype=float32)])
('dataset:', [array([[8., 2., 2., 6., 3.], [7., 5., 3., 0., 3.]], dtype=float32), array([0., 1.], dtype=float32)])
('dataset:', [array([[2., 8., 9., 5., 7.], [0., 5., 1., 5., 5.]], dtype=float32), array([1., 0.], dtype=float32)])
('dataset:', [array([[0., 8., 1., 6., 0.], [7., 3., 8., 8., 1.]], dtype=float32), array([0., 0.], dtype=float32)])
('dataset:', [array([[2., 4., 9., 8., 9.], [3., 5., 9., 6., 0.]], dtype=float32), array([1., 0.], dtype=float32)])

Epoch: 使用.repeat(num_epochs) 来指定遍历几遍数据集
关于Epoch次数,在Queue读取文件的方式中,是在创建文件名队列时设定的

filename_queue = tf.train.string_input_producer([tfrecord_path], num_epochs=3)

根据博客中的实验可知,先取出(样本总数✖️num_Epoch)的数据,打乱顺序,按照batch_size,无放回的取样,保证每个样本都被访问num_Epoch次。

三种读取方式的完整代码

# -*- coding: UTF-8 -*-
# !/usr/bin/python3
# Env: python3.6
import tensorflow as tf
import numpy as np
import os

# path
data_filename = 'data/data_train.txt'  # 生成txt数据保存路径
size = (10000, 5)
tfrecord_path = 'data/test_data.tfrecord'  # tfrecord文件保存路径

#################### 生成txt数据 10000个样本。########################
def generate_data(data_filename=data_filename, size=size):
    if not os.path.exists(data_filename):
        np.random.seed(9)
        x_data = np.random.randint(0, 10, size=size)
        y1_data = np.ones((size[0] // 2, 1), int)  # 一半标签是0,一半是1
        y2_data = np.zeros((size[0] // 2, 1), int)
        y_data = np.append(y1_data, y2_data)
        np.random.shuffle(y_data)

        xy_data = str('')
        for xy_row in range(len(x_data)):
            x_str = str('')
            for xy_col in range(len(x_data[0])):
                if not xy_col == (len(x_data[0]) - 1):
                    x_str = x_str + str(x_data[xy_row, xy_col]) + ' '
                else:
                    x_str = x_str + str(x_data[xy_row, xy_col])
            y_str = str(y_data[xy_row])
            xy_data = xy_data + (x_str + '/' + y_str + '\n')

        # write to txt
        write_txt = open(data_filename, 'w')
        write_txt.write(xy_data)
        write_txt.close()
    return

################  读取txt文件,并转为tfrecord文件 ###########################
# every line of data is just as follow: 1 2 3 4 5/1. train data: 1 2 3 4 5, label: 1
def txt_to_tfrecord(txt_filename=data_filename, tfrecord_path=tfrecord_path):
    # 第一步:生成TFRecord Writer
    writer = tf.python_io.TFRecordWriter(tfrecord_path)

    # 第二步:读取TXT数据,并分割出样本数据和标签
    file = open(txt_filename)
    for data_line in file.readlines():  # 每一行
        data_line = data_line.strip('\n')  # 去掉换行符
        sample = []
        spls = data_line.split('/', 1)[0]  # 样本
        for m in spls.split(' '):
            sample.append(int(m))
        label = data_line.split('/', 1)[1]  # 标签
        label = int(label)
        print('sample:', sample, 'labels:', label)

        # 第三步: 建立feature字典,tf.train.Feature()对单一数据编码成feature
        feature = {'sample': tf.train.Feature(int64_list=tf.train.Int64List(value=sample)),
                   'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}
        # 第四步:可以理解为将内层多个feature的字典数据再编码,集成为features
        features = tf.train.Features(feature=feature)
        # 第五步:将features数据封装成特定的协议格式
        example = tf.train.Example(features=features)
        # 第六步:将example数据序列化为字符串
        Serialized = example.SerializeToString()
        # 第七步:将序列化的字符串数据写入协议缓冲区
        writer.write(Serialized)
    # 记得关闭writer和open file的操作
    writer.close()
    file.close()
    return


###############   用Queue方式中的tf.parse_single_example解析tfrecord  #########################

# 第一步: 建立文件名队列
filename_queue = tf.train.string_input_producer([tfrecord_path], num_epochs=30)


def read_single(filename_queue, shuffle_batch, if_enq_many):
    # 第二步: 建立阅读器
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # 第三步:根据写入时的格式建立相对应的读取features
    features = {
        'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是标量,一定要在这里说明数组的长度
        'label': tf.FixedLenFeature([], tf.int64)
    }
    # 第四步: 用tf.parse_single_example()解析单个EXAMPLE PROTO
    Features = tf.parse_single_example(serialized_example, features)

    # 第五步:对数据进行后处理
    sample = tf.cast(Features['sample'], tf.float32)
    label = tf.cast(Features['label'], tf.float32)

    # 第六步:生成Batch数据 generate batch
    if shuffle_batch:  # 打乱数据顺序,随机取样
        sample_single, label_single = tf.train.shuffle_batch([sample, label],
                                                             batch_size=2,
                                                             capacity=10000,
                                                             min_after_dequeue=1000,
                                                             num_threads=1,
                                                             enqueue_many=if_enq_many)  # 主要是为了评估enqueue_many的作用
    else:  # # 如果不打乱顺序则用tf.train.batch(), 输出队列按顺序组成Batch输出
        '''
        example_list = [[sample,label] for _ in range(2)]  # Reader设置为2

        sample_single, label_single = tf.train.batch_join(example_list, batch_size=1)
        '''

        sample_single, label_single = tf.train.batch([sample, label],
                                                     batch_size=1,
                                                     capacity=10000,
                                                     num_threads=1,
                                                     enqueue_many=if_enq_many)

    return sample_single, label_single


#############   用Queue方式中的tf.parse_example解析tfrecord  ##################################

def read_parse(filename_queue, shuffle_batch, if_enq_many):
    # 第二步: 建立阅读器
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # 第三步: 设置shuffle_batch
    if shuffle_batch:
        batch = tf.train.shuffle_batch([serialized_example],
                                       batch_size=3,
                                       capacity=10000,
                                       min_after_dequeue=1000,
                                       num_threads=1,
                                       enqueue_many=if_enq_many)  # 主要是为了评估enqueue_many的作用

    else:
        batch = tf.train.batch([serialized_example],
                               batch_size=3,
                               capacity=10000,
                               num_threads=1,
                               enqueue_many=if_enq_many)
        # 第四步:根据写入时的格式建立相对应的读取features
    features = {
        'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是标量,一定要在这里说明数组的长度
        'label': tf.FixedLenFeature([], tf.int64)
    }
    # 第五步: 用tf.parse_example()解析多个EXAMPLE PROTO
    Features = tf.parse_example(batch, features)

    # 第六步:对数据进行后处理
    samples_parse = tf.cast(Features['sample'], tf.float32)
    labels_parse = tf.cast(Features['label'], tf.float32)
    return samples_parse, labels_parse


############### 用Dataset读取tfrecord文件  ###############################################

# 定义解析函数
def _parse_function(example_proto):
    dics = {  # 这里没用default_value,随后的都是None
        'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是标量,一定要在这里说明数组的长度
        'label': tf.FixedLenFeature([], tf.int64)}
    # 把序列化样本和解析字典送入函数里得到解析的样本
    parsed_example = tf.parse_single_example(example_proto, dics)

    parsed_example['sample'] = tf.cast(parsed_example['sample'], tf.float32)
    parsed_example['label'] = tf.cast(parsed_example['label'], tf.float32)
    # 返回所有feature
    return parsed_example


def read_dataset(tfrecord_path=tfrecord_path):
    # 声明阅读器
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    # 建立解析函数,其中num_parallel_calls指定并行线程数
    new_dataset = dataset.map(_parse_function, num_parallel_calls=4)
    # 打乱样本顺序
    shuffle_dataset = new_dataset.shuffle(buffer_size=20000)
    # 设置epoch次数为10,这里需要注意的是目前看来只支持先shuffle再repeat的方式
    repeat_dataset = shuffle_dataset.repeat(10) 
    # batch输出
    batch_dataset = repeat_dataset.batch(2)
    # 数据提前进入队列
    prefetch_dataset = batch_dataset.prefetch(2000)
    # 建立迭代器
    iterator = prefetch_dataset.make_one_shot_iterator()
    # 获得下一个样本
    next_element = iterator.get_next()
    return next_element


##################   建立graph ####################################

# 生成数据
# generate_data()
# 读取数据转为tfrecord文件
# txt_to_tfrecord()
# Queue + tf.parse_single_example()读取tfrecord文件
x1_samples, y1_labels = read_single(filename_queue, shuffle_batch=True, if_enq_many=False)
# Queue + tf.parse_example()读取tfrecord文件
x2_samples, y2_labels = read_parse(filename_queue, shuffle_batch=True, if_enq_many=False)
# Dataset读取数据
next_element = read_dataset()

# 定义初始化变量范围
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)  # 初始化
    # 如果tf.train.string_input_producer([tfrecord_path], num_epochs=30)中num_epochs不为空的化,必须要初始化local变量
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()  # 管理线程
    threads = tf.train.start_queue_runners(coord=coord)  # 文件名开始进入文件名队列和内存
    for i in range(1):
        # Queue + tf.parse_single_example()读取tfrecord文件
        X1, Y1 = sess.run([x1_samples, y1_labels])
        print('X1: ', X1, 'Y1: ', Y1)
        # Queue + tf.parse_example()读取tfrecord文件
        X2, Y2 = sess.run([x2_samples, y2_labels])
        print('X2: ', X2, 'Y2: ', Y2)
        # Dataset读取数据
        print('dataset:', sess.run([next_element['sample'],
                                    next_element['label']]))
        #这里需要注意,每run一次,迭代器会取下一个样本。
        # 如果是 a= sess.run(next_element['sample'])
        #             b = sess.run(next_element['label']),
        # 则a样本对应的标签值不是b,b是下一个样本对应的标签值。

    coord.request_stop()
    coord.join(threads)

另外,关于dataset加速的用法,可以参见官网说明

Dataset+TFRecord读取变长数据

使用dataset中的padded_batch方法来进行

padded_batch(
    batch_size,
    padded_shapes,
    padding_values=None    #默认使用各类型数据的默认值,一般使用时可忽略该项
)

参数padded_shapes 指明每条记录中各成员要pad成的形状,成员若是scalar,则用[ ],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);
例如tfrecord文件中的key是fea, e.g.fea.shape=[568, 366], 二维,长度变化。fea_shape=[568,366],一维, label=[1, 0, 2,0,3,0]一维,长度变化。
再读取变长数据的时候映射函数应为:

def _parse_function(example_proto):
    dics = {
        'fea': tf.VarLenFeature(dtype=tf.float32),
        'fea_shape': tf.FixedLenFeature(shape=(2,), dtype=tf.int64),
        'label': tf.VarLenFeature(dtype=tf.float32)}

    parsed_example = tf.parse_single_example(example_proto, dics)
    parsed_example['fea'] = tf.sparse_tensor_to_dense(parsed_example['fea'])
    parsed_example['label'] = tf.sparse_tensor_to_dense(parsed_example['label'])
    parsed_example['label'] = tf.cast(parsed_example['label'], tf.int32)
    parsed_example['fea'] = tf.reshape(parsed_example['fea'], parsed_example['fea_shape'])
    return parsed_example

利用tf.VarLenFeature()代替tf.FixedLenFeature(),在后处理中要注意用tf.sparse_tensor_to_dense()将读取的变长数据转为稠密矩阵。

def dataset():
    tf_lst = get_tf_list(tf_file_lst)
    dataset = tf.data.TFRecordDataset(tf_lst)
    new_dataset = dataset.map(_parse_function)
    shuffle_dataset = new_dataset.shuffle(buffer_size=20000)
    repeat_dataset = shuffle_dataset.repeat(10)
    prefetch_dataset = repeat_dataset.prefetch(2000)
    batch_dataset = prefetch_dataset.padded_batch(2, padded_shapes={'fea': [None, None], 'fea_shape': [None], 'label': [None]})
    iterator = batch_dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    return next_element

这里padded_shapes={'fea': [None, None], 'fea_shape': [None], 'label': [None]}
如果报错 All elements in a batch must have the same rank as the padded shape for component1: expected rank 2 but got element with rank 1请仔细查看padded_shapes中设置的维度是否正确。如果padded_shapes={'fea': [None, None], 'fea_shape': [None, None], 'label': [None]}即fea_shape本来的rank应该是1,但是在pad的时候设置了2,所以报错。

如果报错The two structures don't have the same sequence type. Input structure has type <class 'tuple'>, while shallow structure has type <class 'dict'>.,则可能是padded_shapes定义的格式不对,如定义成了padded_shapes=([None, None],[None],[None]),请按照字典格式定义pad的方式。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,839评论 6 482
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,543评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 153,116评论 0 344
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,371评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,384评论 5 374
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,111评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,416评论 3 400
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,053评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,558评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,007评论 2 325
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,117评论 1 334
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,756评论 4 324
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,324评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,315评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,539评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,578评论 2 355
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,877评论 2 345

推荐阅读更多精彩内容