关于在《TensorFlow实战》这本书第5章中出现的cifar10,好多人都pip install cifar10发现失败:
no module named cifar10与no modulw named cifar10_input,这是因为你需要下载一个tensorflow 的models,具体链接放在这里https://github.com/tensorflow/models/tree/r1.13.0,新的tensorflow由于变成2.x版了,所以没有models这个包,我这里用的是19年的branch。
下载完成后,将它解压缩并命名为models然后放到你安装tensorflow的那个目录下
紧接着要修改里边的文件
一共要修改两个地方:
①
删去那两行,用这两行代替
②
OK了你可以用了!
华丽的分割线(以上是2020年11月16日更新的,以下是2019年的)
之前看《TensorFlow实战》的时候就卡在了第五章“TensorFlow实现卷积神经网络”,原因是这里的cifar10数据集导入不进去。
cifar10.maybe_download_and_extract()
这里,如果你尝试的话会出现
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-12-02a754d7036a> in <module>()
----> 1 cifar10.maybe_download_and_extract()
AttributeError: module 'cifar10' has no attribute 'maybe_download_and_extract'
找不到 maybe_download_and_extract() 方法。。。什么鬼!
这可咋整,然后我还从官网上下载了cifar-10-batches-py,170多M,然后data_dir。。还是不行,后来彻底放弃了,这憨批TensorFlow。
时隔俩月,我今天又头铁,查了无数资料,没有提到 module 'cifar10' has no attribute 'maybe_download_and_extract'这种错误的???怎么说??网上冲浪的各位难道用的都是2016版TensorFlow-model???
那肯定是model更新了(https://github.com/tensorflow/models.git)models模块在这里有需要可以去下载。
一筹莫展之际,书上的下一行代码引起了我注意:
images_train, labels_train = cifar10_input.distorted_inputs(
data_dir=data_dir, batch_size=batch_size)
OK,数据可以从这个线索去找,于是我打开models下的cifar10_input.py文件查找到distorted_inputs函数如下:
def distorted_inputs(batch_size):
"""Construct distorted input for CIFAR training using the Reader ops.
Args:
batch_size: Number of images per batch.
Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
"""
return _get_images_labels(batch_size, tfds.Split.TRAIN, distords=True)
OK,那我再去找_get_images_labels,也是cifar10_input.py中(奇怪为什么同一个功能要用两个函数):
def _get_images_labels(batch_size, split, distords=False):
"""Returns Dataset for given split."""
dataset = tfds.load(name='cifar10', split=split)
scope = 'data_augmentation' if distords else 'input'
with tf.name_scope(scope):
dataset = dataset.map(DataPreprocessor(distords), num_parallel_calls=10)
# Dataset is small enough to be fully loaded on memory:
dataset = dataset.prefetch(-1)
dataset = dataset.repeat().batch(batch_size)
iterator = dataset.make_one_shot_iterator()
images_labels = iterator.get_next()
images, labels = images_labels['input'], images_labels['target']
tf.summary.image('images', images)
return images, labels
可以看到datasets 是由tfds的load方法加载的
tfds,也就是import tensorflow_datasets as tfds
找到cifar10_download_and_extract.py这个文件,发现
DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
这是cifar的下载路径,然后我打开,发现500服务器拒绝访问,难怪下载不了
修改方案:DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'改为
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
去掉s再访问发现是可行的(当场摔电脑)
保存,运行
images_train, labels_train = cifar10_input.distorted_inputs(
data_dir=data_dir, batch_size=batch_size)
等下载完成后程序自动解压。此时在C:\Users\***\tensorflow_datasets文件夹下可以找到“cifar10”和“download”这两个文件夹。
OK完成,可以进行接下来的学习了。
2019年6月20日编辑
今天我尝试了一下,发现改成'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'也能正常下载了,不知道为什么,昨天访问的时候是500服务器错误。
需要注意的一点是,无论有没有http(s),我们都需要搭VPN才能下载。毕竟你懂的。