大型元素集
-
源数据集 (Source Datasets)
创建数据集的最简单方法 list:
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) for element in dataset: print(element)
处理文本文件:
dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])
处理TFRecord文件
dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])
创建一个匹配规则的所有文件的数据集
dataset = tf.data.Dataset.list_files("/path/*.txt") # doctest: +SKIP
-
转换 (Transformations)
有了数据集后,您可以对准备的数据进行转换
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) dataset = dataset.map(lambda x: x*2) list(dataset.as_numpy_iterator()) # 输出 [2, 4, 6]
-
小 tips
数据集里的元素说明
element_spec
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).element_spec TensorSpec(shape=(), dtype=tf.int32, name=None)
方法
-
apply
apply( transformation_func ) 参数:transformation_func 一个方法名(此方法接收一个dataset参数 并返回处理后的dataset) 返回值:apply的返回值 即参数 transformation_func 方法的返回值
将转换函数应用于此数据集
dataset = tf.data.Dataset.range(100) def dataset_fn(ds): return ds.filter(lambda x: x < 5) dataset = dataset.apply(dataset_fn) list(dataset.as_numpy_iterator()) # Output: [0, 1, 2, 3, 4]
-
as_numpy_iterator
返回一个迭代器,该迭代器将数据集的所有元素转换为numpy
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) for element in dataset.as_numpy_iterator(): print(element) # Output: # 1 # 2 # 3 dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) print(list(dataset.as_numpy_iterator())) # Output: [1, 2, 3]
-
as_numpy_iterator()
将保留数据集元素的嵌套结构dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]), 'b': [5, 6]}) list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5}, {'a': (2, 4), 'b': 6}] # Output: True
-
batch
将数据集进行分批处理
batch( batch_size, drop_remainder=False ) # param1: batch_size 每批次包含几个元素 # param2: drop_remainder ds_length/batch_size 不能被整除时 是否删掉最后一个批次 # return: 一个 Dataset
dataset = tf.data.Dataset.range(8) dataset = dataset.batch(3) list(dataset.as_numpy_iterator()) # Output: [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
dataset = tf.data.Dataset.range(8) dataset = dataset.batch(3, drop_remainder=True) list(dataset.as_numpy_iterator()) # Output: [array([0, 1, 2]), array([3, 4, 5])]
-
cache
缓存数据集中的元素 可缓存在内存或指定文件内
cache( filename='' ) # param: filename 文件名 如未提供此参数 则默认缓存到内存 # return: 一个 Dataset
dataset = tf.data.Dataset.range(5) dataset = dataset.map(lambda x: x**2) dataset = dataset.cache() # 缓存到内存中 # The first time reading through the data will generate the data using `range` and `map`. list(dataset.as_numpy_iterator()) # Subsequent iterations read from the cache. list(dataset.as_numpy_iterator())
dataset = tf.data.Dataset.range(5) dataset = dataset.cache("/path/to/file") # doctest: +SKIP 缓存到文件中 list(dataset.as_numpy_iterator()) # doctest: +SKIP dataset = tf.data.Dataset.range(10) dataset = dataset.cache("/path/to/file") # Same file! # doctest: +SKIP list(dataset.as_numpy_iterator()) # doctest: +SKIP
-
concatenate
通过将给定数据集与此数据集连接来创建一个 Dataset
注意 两个数据集的 结构 和 数据类型 必须一致
concatenate( dataset )
a = tf.data.Dataset.range(1, 4) # ==> [ 1, 2, 3 ] b = tf.data.Dataset.range(4, 8) # ==> [ 4, 5, 6, 7 ] ds = a.concatenate(b) list(ds.as_numpy_iterator()) # Output: [1, 2, 3, 4, 5, 6, 7]
-
enumerate
枚举此数据集的元素
enumerate( start=0 ) # param: start 表示枚举的起始值 # return: A Dataset
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) dataset = dataset.enumerate(start=5) for element in dataset.as_numpy_iterator(): print(element) # Output: # (5, 1) # (6, 2) # (7, 3)
# The nested structure of the input dataset determines the structure of # elements in the resulting dataset. dataset = tf.data.Dataset.from_tensor_slices([(7, 8), (9, 10)]) dataset = dataset.enumerate() for element in dataset.as_numpy_iterator(): print(element) # Output: # (0, array([7, 8], dtype=int32)) # (1, array([ 9, 10], dtype=int32))
-
filter
过滤数据集 进行条件筛选
filter( predicate ) # param: predicate 将数据集元素映射到布尔值的函数。 # return: A Dataset
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) dataset = dataset.filter(lambda x: x < 3) list(dataset.as_numpy_iterator()) # Output: [1, 2] # `tf.math.equal(x, y)` is required for equality comparison def filter_fn(x): return tf.math.equal(x, 1) dataset = dataset.filter(filter_fn) list(dataset.as_numpy_iterator()) # Output: [1]
-
flat_map
flat_map( map_func ) # param: map_func 映射数据集里每一个元素的方法 # return: A Dataset
根据
map_func
方法映射并展开数据集dataset = Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) dataset = dataset.flat_map(lambda x: Dataset.from_tensor_slices(x)) list(dataset.as_numpy_iterator()) # Output: [1, 2, 3, 4, 5, 6, 7, 8, 9]
-
from_generator
根据生成器创建一个数据集
@staticmethod from_generator( generator, output_types, output_shapes=None, args=None ) # param1: generator 一个可调用的生成器 其返回值的结构和类型必须与param2, param3一致 该生成器函数的参数数量需要与param4一致 # param2: 确定生成器返回的每个值的数据类型 # param3: 确定生成器返回的每个值的结构形状 # param4: 一个元组 传递给生成器作为参数
import itertools def gen(): for i in itertools.count(1): yield (i, [1] * i) dataset = tf.data.Dataset.from_generator( gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None]))) list(dataset.take(3).as_numpy_iterator()) # Output: [(1, array([1])), (2, array([1, 1])), (3, array([1, 1, 1]))]