tf.data.Dataset中的cache操作主要完成:
- 第一次迭代数据集时,其元素将缓存在指定文件或内存
- 后续迭代将使用缓存数据
使用注意事项:
- 必须完整地迭代输入数据集。 否则,后续迭代将不会使用缓存数据
- 缓存将在数据集的每次迭代期间生成完全相同的元素,所以请确保在调用cache后调用shuffle
cache操作带来的性能提升
import tensorflow as tf
import time
dataset = tf.data.Dataset.range(10000)
def h(x):
x = tf.cast(x, dtype=tf.float32)
x = tf.math.sin(x)
return x
dataset = dataset.map(h)
dataset = dataset.cache()
start = time.time()
sum = 0
for i in dataset.as_numpy_iterator():
sum += i
end1 = time.time()
sum = 0
for i in dataset.as_numpy_iterator():
sum += i
end2 = time.time()
print("no cache time:", (end1 - start))
print("cache time:", (end2 - end1))
no cache time: 0.6333067417144775
cache time: 0.5754616260528564