学习Tensorflow从数据集开始,先弄清楚Tensorflow的Keras框架提供的数据集。Keras提供了7个数据集,类型比较全面,基本上满足Tensorflow的入门学习。对数据集掌握如下几个方面
1. 数据集的原始来源与说明;
2. 数据集的加载;
3. 数据集的格式;
一. Tensorflow数据集概述
- 从Tensorflow的数据集开始掌握Tensorflow2.0与Keras。
1.1. Tensorflow提供的数据集模块
- boston_housing 模块:
- 波士顿住房价格回归数据集。
- cifar10 模块:
- CIFAR10小图像分类数据集。
- cifar100 模块:
- CIFAR100小图像分类数据集。
- fashion_mnist 模块:
- 时尚mnist数据集。
- imdb 模块:
- IMDB情绪分类数据集。
- mnist 模块:
- mnist手写数字数据集。
- reuters 模块:
- 路透社主题分类数据集。
1.2. 数据集模块函数-load_data
tf.keras.datasets.fashion_mnist.load_data()
返回的数据集格式如下:
(x_train, y_train), (x_test, y_test)
- 上面的数据集因为google的缘故,所以有的数据集无法访问。但是可以通过sklearn库获取。
- 能通过tensorflow模块访问的如下,其他则无法访问。
二. CIFAR10小图像分类数据集
2.1. CIFAR-10数据集介绍
- CIFAR-10数据集是一个包含60000张图片的数据集。
- 其中每张照片为32*32的彩色照片,每个像素点包括RGB三个数值,数值范围 0 ~ 255。
- 所有照片分属10个不同的类别,分别是 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'
- 其中五万张图片被划分为训练集,剩下的一万张图片属于测试集。
2.2. 数据集的读取
2.2.1. 使用tensorflow模块加载
- 数据加载与数据格式
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras.datasets.cifar10 as cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train.shape, y_train.shape, x_test.shape, y_test.shape
((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))
- 数据可视化
%matplotlib inline
import matplotlib.pyplot as plt
plt.imshow(x_train[0])
plt.show()
y_train[0] # 青蛙:frog。
array([6], dtype=uint8)
- 类别说明
- 官网介绍地址:
http://www.cs.toronto.edu/~kriz/cifar.html
- 官网介绍地址:
-类别索引: 类别下标从0开始
- 0:airplane
- 1:automobile
- 2:bird
- 3:cat
- 4:deer
- 5:dog
- 6:frog
- 7:horse
- 8:ship
- 9:truck
2.2.2. 从解压的本地文件读取
- 可以从
http://www.cs.toronto.edu/~kriz/cifar.html
下载tar.gz压缩格式的CIFAR10数据集。 -
下载解压缩的数据文件如下:
- 读取方式可以从官网获取例子代码
import pickle # python的序列化归档模块
with open('01datasets/cifar-10-batches-py/data_batch_1', 'rb') as fo: # 打开文件
dict = pickle.load(fo, encoding='bytes') # 读取文件
# 显示数据结构
print('数据项:', dict.keys())
print('数据批次:', dict[b'batch_label'])
# 只读取第一个图像数据的信息
print('标签:', dict[b'labels'][0])
print('数据:', dict[b'data'][0])
print('文件名:', dict[b'filenames'][0])
数据项: dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
数据批次: b'training batch 1 of 5'
标签: 6
数据: [ 59 43 50 ... 140 84 72]
文件名: b'leptodactylus_pentadactylus_s_000004.png'
- 可视化图像
%matplotlib inline
import matplotlib.pyplot as plt
print(dict[b'data'].shape)
print(type(dict[b'data'][0])) # 显示数据的数据类型
img = dict[b'data'][0].reshape((3, 32, 32))
# 坐标轴交换方法
# img = img.swapaxes(0, 1) # 0->1
# img = img.swapaxes(1, 2) # 0->2
# 矩阵维度转换方法
img = img.transpose(1, 2, 0) # 这是正常的图像顺序
plt.imshow(img)
plt.show()
(10000, 3072)
<class 'numpy.ndarray'>
- meta文件的读取
import pickle # python的序列化归档模块
with open('01datasets/cifar-10-batches-py/batches.meta', 'rb') as fo: # 打开文件
dict = pickle.load(fo, encoding='bytes') # 读取文件
print(dict.keys())
print(dict[b'num_cases_per_batch']) # 每个文件中的样本数量
print(dict[b'label_names']) # 标签名(标签的编号是按照类别的字典序的序号)
print(dict[b'num_vis']) # 每个图像的数据个数
dict_keys([b'num_cases_per_batch', b'label_names', b'num_vis'])
10000
[b'airplane', b'automobile', b'bird', b'cat', b'deer', b'dog', b'frog', b'horse', b'ship', b'truck']
3072
2.2.3. 直接从下载的压缩文件中读取
- Python提供了.tar.gz归档文件读取方式。
- 读取归档的文件列表
import tarfile
import os.path
# 要读取的压缩归档文件
list_files = 'data_batch_1'
with tarfile.open('01datasets/cifar-10-python.tar.gz', mode='r') as fz:
# 得到归档文件中的文件列表
filenames = fz.getnames()
# 遍历文件列表
for filename in filenames:
print(filename)
cifar-10-batches-py
cifar-10-batches-py/data_batch_4
cifar-10-batches-py/readme.html
cifar-10-batches-py/test_batch
cifar-10-batches-py/data_batch_3
cifar-10-batches-py/batches.meta
cifar-10-batches-py/data_batch_2
cifar-10-batches-py/data_batch_5
cifar-10-batches-py/data_batch_1
- 读归档文件内容
import tarfile
import os.path
# 要读取的压缩归档文件
list_files = 'data_batch_1'
with tarfile.open('01datasets/cifar-10-python.tar.gz', mode='r') as fz:
# 得到归档文件中的文件列表
filenames = fz.getnames()
# 遍历文件列表
for filename in filenames:
print('文件名:',filename)
# 把文件解析成不同组件
base_name = os.path.basename(filename)
# 只读取指定的文件
if base_name == list_files:
# 抽取文件内容
buffer_reader = fz.extractfile(filename)
print('抽取的返回值类型:', type(buffer_reader))
# 可以使用序列化工具实现反序列化(buffer_reader的类型是ExFileOObject对象,就是一个打开的文件,使用load)
dict = pickle.load(buffer_reader, encoding='bytes')
break
print('-------------------------')
print('读取的数据字典key:', dict.keys())
文件名: cifar-10-batches-py
文件名: cifar-10-batches-py/data_batch_4
文件名: cifar-10-batches-py/readme.html
文件名: cifar-10-batches-py/test_batch
文件名: cifar-10-batches-py/data_batch_3
文件名: cifar-10-batches-py/batches.meta
文件名: cifar-10-batches-py/data_batch_2
文件名: cifar-10-batches-py/data_batch_5
文件名: cifar-10-batches-py/data_batch_1
抽取的返回值类型: <class 'tarfile.ExFileObject'>
-------------------------
读取的数据字典key: dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
- 读取的数据可视化
%matplotlib inline
import matplotlib.pyplot as plt
img = dict[b'data'][0].reshape((3, 32, 32))
# 坐标轴交换方法
img = img.swapaxes(0, 1) # 0->1
img = img.swapaxes(1, 2) # 0->2
# 矩阵维度转换方法
# img = img.transpose(1, 2, 0) # 这是正常的图像顺序
plt.imshow(img)
plt.show()
三. CIFAR100小图像分类数据集
3.1. CIFAR100数据集说明
-
CIFAR100数据集与CIFAR10数据集一样,差别在于类别与图像数不同:
- CIFAR100一共100类数据;
- 100类图像又分成20个大类;
- 每类数据600个图像;
- CIFAR100一共100类数据;
-
每个类别的说明:
- 详细的资料可以参考官网:
http://www.cs.toronto.edu/~kriz/cifar.html
- 详细的资料可以参考官网:
3.2. 数据集的读取
3.2.1. 使用tensorflow模块加载
- 数据加载与数据格式
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras.datasets.cifar100 as cifar100
(x_train, y_train), (x_test, y_test) = cifar100.load_data()
x_train.shape, y_train.shape, x_test.shape, y_test.shape
((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))
- 数据可视化
%matplotlib inline
import matplotlib.pyplot as plt
plt.imshow(x_train[0])
plt.show()
y_train[0] # 牛,返回的是细分类别的标签(就是类别名的字典排序的序号)
array([19])
- 类别说明
- 100种类别有点多,这里不意义对应罗列。正如上面cifar10的规律,其中细分类别的标签是细分类别的名字的字典序的序号;
- 直接从meta文件中读取;
- 大的类别在数据中没有;
with open('01datasets/cifar-100-python/meta', 'rb') as fo: # 打开文件
dict = pickle.load(fo, encoding='bytes') # 读取文件
print(dict.keys())
print('细分类型:', dict[b'fine_label_names'])
print('大类型:', dict[b'coarse_label_names'])
dict_keys([b'fine_label_names', b'coarse_label_names'])
细分类型: [b'apple', b'aquarium_fish', b'baby', b'bear', b'beaver', b'bed', b'bee', b'beetle', b'bicycle', b'bottle', b'bowl', b'boy', b'bridge', b'bus', b'butterfly', b'camel', b'can', b'castle', b'caterpillar', b'cattle', b'chair', b'chimpanzee', b'clock', b'cloud', b'cockroach', b'couch', b'crab', b'crocodile', b'cup', b'dinosaur', b'dolphin', b'elephant', b'flatfish', b'forest', b'fox', b'girl', b'hamster', b'house', b'kangaroo', b'keyboard', b'lamp', b'lawn_mower', b'leopard', b'lion', b'lizard', b'lobster', b'man', b'maple_tree', b'motorcycle', b'mountain', b'mouse', b'mushroom', b'oak_tree', b'orange', b'orchid', b'otter', b'palm_tree', b'pear', b'pickup_truck', b'pine_tree', b'plain', b'plate', b'poppy', b'porcupine', b'possum', b'rabbit', b'raccoon', b'ray', b'road', b'rocket', b'rose', b'sea', b'seal', b'shark', b'shrew', b'skunk', b'skyscraper', b'snail', b'snake', b'spider', b'squirrel', b'streetcar', b'sunflower', b'sweet_pepper', b'table', b'tank', b'telephone', b'television', b'tiger', b'tractor', b'train', b'trout', b'tulip', b'turtle', b'wardrobe', b'whale', b'willow_tree', b'wolf', b'woman', b'worm']
大类型: [b'aquatic_mammals', b'fish', b'flowers', b'food_containers', b'fruit_and_vegetables', b'household_electrical_devices', b'household_furniture', b'insects', b'large_carnivores', b'large_man-made_outdoor_things', b'large_natural_outdoor_scenes', b'large_omnivores_and_herbivores', b'medium_mammals', b'non-insect_invertebrates', b'people', b'reptiles', b'small_mammals', b'trees', b'vehicles_1', b'vehicles_2']
3.2.2. 从解压的本地文件读取
- 下载地址:
http://www.cs.toronto.edu/~kriz/cifar.html
-
解压缩的文件如下:
- 读取数据
import pickle # python的序列化归档模块
with open('01datasets/cifar-100-python/train', 'rb') as fo: # 打开文件
dict = pickle.load(fo, encoding='bytes') # 读取文件
# 显示数据结构
print('数据项:', dict.keys())
print('数据批次:', dict[b'batch_label'])
# 只读取第一个图像数据的信息
print('标签:', dict[b'fine_labels'][0]) # 细分类型
print('标签:', dict[b'coarse_labels'][0]) # 大类型
print('数据:', dict[b'data'][0])
print('文件名:', dict[b'filenames'][0])
数据项: dict_keys([b'filenames', b'batch_label', b'fine_labels', b'coarse_labels', b'data'])
数据批次: b'training batch 1 of 1'
标签: 19
标签: 11
数据: [255 255 255 ... 10 59 79]
文件名: b'bos_taurus_s_000507.png'
-
类别说明
0:aquatic mammals
1:fish
2:flowers
3:food containers
4:fruit and vegetables
5:household electrical devices
6:household furniture
7:insects
8:large carnivores
9:large man-made outdoor things
10:large natural outdoor scenes
11:large omnivores and herbivores
12:medium-sized mammals
13:non-insect invertebrates
14:people
15:reptiles
16:small mammals
17:trees
18:vehicles 1
19:vehicles 2
其中coarse_labels是大类的类名的字典序的序号。
数据可视化
%matplotlib inline
import matplotlib.pyplot as plt
img = dict[b'data'][0].reshape((3, 32, 32))
# 坐标轴交换方法
img = img.swapaxes(0, 1) # 0->1
img = img.swapaxes(1, 2) # 0->2
# 矩阵维度转换方法
# img = img.transpose(1, 2, 0) # 这是正常的图像顺序
plt.imshow(img)
plt.show()
3.2.3. 直接从压缩文件中读取
(略)
四. mnist手写数字数据集
4.1. mnist手写数字数据集说明
- 数据集一共70000个样本,包含60000个训练样本与10000个测试集样本。
- 返回的数据结构是元组:
- 元组第一个元素也是元组,
- 60000个训练样本,元组格式(训练集,标签集)
- 元组第一个元素也是元组
- 10000个训练样本,元组格式(训练集,标签集)
- 元组第一个元素也是元组,
4.2. 数据集的读取
4.2.1. 使用tensorflow模块加载数据集
- 数据集加载代码
import tensorflow as tf
import tensorflow.keras.datasets.mnist as mnist
(data_train, label_train), (data_test, label_test) = mnist.load_data()
print(data_train.shape, label_train.shape) # 训练样本集
print(data_test.shape, label_test.shape) # 测试样本集
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
- 数据集可视化
%matplotlib inline
import matplotlib.pyplot as plt
plt.imshow(data_train[0], cmap='gray')
plt.show()
4.2.2. 从本地加载数据
4.2.2.1. 下载本地文件说明
-
下载地址
http://yann.lecun.com/exdb/mnist/
-
下载的文件说明:
- 训练数据:train-images.idx3-ubyte
- 训练标签:train-labels.idx1-ubyte
- 测试数据:t10k-images.idx3-ubyte
- 测试标签:t10k-labels.idx1-ubyte
数据集种图像文件与标签文件的格式
4.2.2.2. 加载图像数据集
- 加载图像数据集的头部meta信息
import struct
with open('./01datasets/minist/t10k-images.idx3-ubyte', 'br') as fd:
# 读取图像的信息
header_buf = fd.read(16) # 16字节,4个int整数
# 按照字节解析头信息(具体参考python SL的struct帮助)
magic, nums, width, height = struct.unpack('>iiii', header_buf) # 解析成四个整数:>表示大端字节序,i表示4字节整数
print('magic number:', magic) # 魔法字一般用来表示文件类型与格式类型
print('图像数量:', nums)
print('宽度:', width)
print('高度:', height)
magic number: 2051
图像数量: 10000
宽度: 28
高度: 28
- 加载图像数据-循环方式
import struct
import numpy as np
imgs = [] # 格式1
with open('./01datasets/minist/t10k-images.idx3-ubyte', 'br') as fd:
# 读取图像的信息
header_buf = fd.read(16) # 16字节,4个int整数
# 按照字节解析头信息(具体参考python SL的struct帮助)
magic, nums, width, height = struct.unpack('>iiii', header_buf) # 解析成四个整数:>表示大端字节序,i表示4字节整数
# 保存成ndarray对象
np_imgs = np.empty((nums, height, width)) # 格式2
# 循环读取图像
for i in range(nums):
# 图像缓冲
img_buf = fd.read(width * height)
# 解析图像为
img = struct.unpack('>'+str(width * height)+'B', img_buf)
imgs.append(img)
# ndarrary保存格式
np_imgs[i] = np.array(img).reshape(height, width)
- 可视化图像
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
# 显示读取的格式1图像
ax1 = plt.subplot(121)
ax1.imshow(np.array(imgs[0]).reshape(height, width), cmap='gray')
# 显示读取的格式2图像
ax2 = plt.subplot(122)
ax2.imshow(np_imgs[0], cmap='gray')
plt.show()
- 加载图像-使用nparray的load功能
import struct
with open('./01datasets/minist/t10k-images.idx3-ubyte', 'br') as fd:
# 读取图像的信息
header_buf = fd.read(16) # 16字节,4个int整数
# 按照字节解析头信息(具体参考python SL的struct帮助)
magic, nums, width, height = struct.unpack('>iiii', header_buf) # 解析成四个整数:>表示大端字节序,i表示4字节整数
# 保存成ndarray对象
imgs = np.fromfile(fd, dtype=np.uint8)
# 返回的所有剩余数据对象的格式
print(type(imgs))
print(imgs.shape)
# 重新reshape一下。
imgs = imgs.reshape(nums, height, width)
print(imgs.shape)
# 可视化读取的图像,验证是否正确
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
plt.imshow(imgs[0], cmap='gray')
plt.show()
<class 'numpy.ndarray'>
(7840000,)
(10000, 28, 28)
4.2.2.3. 加载标签数据
- 读取标签头
# t10k-labels.idx1-ubyte
import struct
with open('./01datasets/minist/t10k-labels.idx1-ubyte', 'br') as fd:
# 读取头,8字节
header_buf = fd.read(8) # 16字节,4个int整数
# 解析头信息
magic, nums = struct.unpack('>ii' ,header_buf) # 解析成2个整数
print('魔法字:', magic)
print('标签个数', nums)
魔法字: 2049
标签个数 10000
- 读取标签- 循环字节方式
import struct
labels = []
with open('./01datasets/minist/t10k-labels.idx1-ubyte', 'br') as fd:
# 读取头,8字节
header_buf = fd.read(8) # 16字节,4个int整数
# 解析头信息
magic, nums = struct.unpack('>ii' ,header_buf) # 解析成2个整数
# 循环读取标签,每个标签一个字节
for i in range(nums):
label_buf = fd.read(1) # 读取一个字节
labels.append(struct.unpack('>B', label_buf)) # 记得返回的是元组
print(labels[0])
(7,)
- 读取标签-ndarray的fromfile函数
import struct
with open('./01datasets/minist/t10k-labels.idx1-ubyte', 'br') as fd:
# 读取头,8字节
header_buf = fd.read(8) # 16字节,4个int整数
# 解析头信息
magic, nums = struct.unpack('>ii' ,header_buf) # 解析成2个整数
# 循环读取标签,每个标签一个字节
labels = np.fromfile(fd, np.uint8)
print(labels[0])
7
4.2.3. 图片格式的数据集
- 还可以下载到原始的图片格式的手写数字数据集。
- 图片的处理与加载速度慢一点,处理麻烦点,需要可以手工处理;
五. 时尚mnist数据集
5.1. 时尚mnist数据集介绍
该数据集与mnist手写数字数据集一样,也是70000样本,分成10类,差别就是类别是单件服饰等物品。
-
该数据集因为
https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
无法访问,所以我们可以从其他地方获取。https://github.com/zalandoresearch/fashion-mnist
5.2. 数据集的读取
5.2.1. 使用tensorflow读取
- 这种方式因为数据集在google的官网,该官网在中国目前无法访问,所以下面代码执行一般会报网络错误。
- 错误提示为:
Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz: None -- [Errno 65] No route to host
%matplotlib inline
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras.datasets.fashion_mnist as fashion_mnist
# (data_train, label_train), (data_test, label_test) = fashion_mnist.load_data()
# print(data_train.shape, label_train.shape) # 训练样本集
# print(data_test.shape, label_test.shape) # 测试样本集
# plt.imshow(data_train[0], cmap='gray')
# plt.show()
5.2.2. 从本地加载数据
5.2.2.1. 下载本地文件
-
下载地址:
https://github.com/zalandoresearch/fashion-mnist/tree/master/data/fashion
https://www.kaggle.com/zalando-research/fashionmnist
-
下载文件说明:
- 训练图像文件:train-images-idx3-ubyte
- 训练标签文件:train-labels-idx1-ubyte
- 测试图像文件:t10k-images-idx3-ubyte
-
测试标签文件:t10k-labels-idx1-ubyte
- 读取图像数据并可视化
import struct
with open('./01datasets/fashion-mnist/t10k-images-idx3-ubyte', 'br') as fd:
# 读取图像的信息
header_buf = fd.read(16) # 16字节,4个int整数
# 按照字节解析头信息(具体参考python SL的struct帮助)
magic, nums, width, height = struct.unpack('>iiii', header_buf) # 解析成四个整数:>表示大端字节序,i表示4字节整数
# 保存成ndarray对象
imgs = np.fromfile(fd, dtype=np.uint8)
# 重新reshape一下。
imgs = imgs.reshape(nums, height, width)
print(imgs.shape)
# 可视化读取的图像,验证是否正确
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
plt.imshow(imgs[0], cmap='gray')
plt.show()
(10000, 28, 28)
- 读取标签数据
import struct
with open('./01datasets/fashion-mnist/t10k-labels-idx1-ubyte', 'br') as fd:
# 读取头,8字节
header_buf = fd.read(8) # 16字节,4个int整数
# 解析头信息
magic, nums = struct.unpack('>ii' ,header_buf) # 解析成2个整数
# 循环读取标签,每个标签一个字节
labels = np.fromfile(fd, np.uint8)
print(labels[0])
9
- 类别说明
- 关于fashion-mnist的说明在如下网站有详细介绍:
https://www.kaggle.com/zalando-research/fashionmnist
- 0:T恤(T-shirt/top)
- 1:裤子(Trouser)
- 2:套头衫(Pullover)
- 3:连衣裙(Dress)
- 4:外套(Coat)
- 5:凉鞋(Sandal)
- 6:衬衫(Shirt)
- 7:运动鞋(Sneaker)
- 8:包(Bag)
- 9:靴子(Ankle boot)
- 关于fashion-mnist的说明在如下网站有详细介绍:
六. IMDB情绪分类数据集
6.1. IMDB数据集说明
MDB数据集包含来自互联网的50000条严重两极分化的评论,该数据被分为用于训练的25000条评论和用于测试的25000条评论,训练集和测试集都包含50%的正面评价和50%的负面评价。
该数据集已经经过预处理:评论(单词序列)已经被转换为整数序列,其中每个整数代表字典中的某个单词。
6.2. 数据集读取
6.2.1. 使用tensorflow模块加载
- 该加载方式因为无法访问google官方站点,所以不能使用,可以在网络其他地方下载。
import tensorflow.keras.datasets.imdb as imdb
# (x_train, y_train), (x_test, y_test) = imdb.load_data()
6.2.2. 从本地文件加载数据
- 文件下载
- 下载keras支持的格式
-
https://download.csdn.net/download/luffysman/10959289
- 或者通过网络搜索
-
- 下载地址完整的文本数据
http://ai.stanford.edu/~amaas/data/sentiment/
- 下载keras支持的格式
-
文件说明
- 词汇表索引文件:imdb_word_index.json
- IMDB数据集文件:imdb.npz
-
使用数据集文件
- 因为数据集使用二进制存放,需要知道数据集的格式,目前没有数据集格式说明,所以使用keras的本地机制来加载IMDB数据集。
- 使用方法:
- 下载好文件;
- 拷贝到默认工作目录:
'~/.keras/datasets/'
,因为该目录是默认加载路径; -
在load_data中使用path参数指定加载的文件
- 加载数据集代码
import tensorflow.keras.datasets.imdb as imdb
(x_train, y_train), (x_test, y_test) = imdb.load_data(path='imdb.npz')
x_train.shape, y_train.shape, x_test.shape, y_test.shape,
((25000,), (25000,), (25000,), (25000,))
- 数据格式
- 数据是处理好的整数格式。
print(x_train[0])
[1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 22665, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 21631, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 19193, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 10311, 8, 4, 107, 117, 5952, 15, 256, 4, 31050, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 12118, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32]
- 加载词汇索引数据集
- 返回字典类型:key是单词,values是索引
word_dict = imdb.get_word_index(path='imdb_word_index.json')
word_dict['woods']
1408
- 把情绪分类数据还原为文本
x_ints = x_train[0]
x_texts = []
for item in x_ints:
x_texts.append(list (word_dict.keys()) [list (word_dict.values()).index (item)])
print(x_texts)
['the', 'as', 'you', 'with', 'out', 'themselves', 'powerful', 'lets', 'loves', 'their', 'becomes', 'reaching', 'had', 'journalist', 'of', 'lot', 'from', 'anyone', 'to', 'have', 'after', 'out', 'atmosphere', 'never', 'more', 'room', 'titillate', 'it', 'so', 'heart', 'shows', 'to', 'years', 'of', 'every', 'never', 'going', 'villaronga', 'help', 'moments', 'or', 'of', 'every', 'chest', 'visual', 'movie', 'except', 'her', 'was', 'several', 'of', 'enough', 'more', 'with', 'is', 'now', 'current', 'film', 'as', 'you', 'of', 'mine', 'potentially', 'unfortunately', 'of', 'you', 'than', 'him', 'that', 'with', 'out', 'themselves', 'her', 'get', 'for', 'was', 'camp', 'of', 'you', 'movie', 'sometimes', 'movie', 'that', 'with', 'scary', 'but', 'pratfalls', 'to', 'story', 'wonderful', 'that', 'in', 'seeing', 'in', 'character', 'to', 'of', '70s', 'musicians', 'with', 'heart', 'had', 'shadows', 'they', 'of', 'here', 'that', 'with', 'her', 'serious', 'to', 'have', 'does', 'when', 'from', 'why', 'what', 'have', 'critics', 'they', 'is', 'you', 'that', "isn't", 'one', 'will', 'very', 'to', 'as', 'itself', 'with', 'other', 'tricky', 'in', 'of', 'seen', 'over', 'landed', 'for', 'anyone', 'of', "gilmore's", 'br', "show's", 'to', 'whether', 'from', 'than', 'out', 'themselves', 'history', 'he', 'name', 'half', 'some', 'br', 'of', "'n", 'odd', 'was', 'two', 'most', 'of', 'mean', 'for', '1', 'any', 'an', 'boat', 'she', 'he', 'should', 'is', 'thought', 'frog', 'but', 'of', 'script', 'you', 'not', 'while', 'history', 'he', 'heart', 'to', 'real', 'at', 'barrel', 'but', 'when', 'from', 'one', 'bit', 'then', 'have', 'two', 'of', 'script', 'their', 'with', 'her', 'nobody', 'most', 'that', 'with', "wasn't", 'to', 'with', 'armed', 'acting', 'watch', 'an', 'for', 'with', 'heartfelt', 'film', 'want', 'an']
- 字典的key与value交换
# 可以把字典的key与value交换
new_word_dict = {v : k for k, v in word_dict.items()}
print(new_word_dict[1])
the
七. 路透社主题分类数据集
7.1. 路透社主题分类数据集说明
- 数据集来源于路透社的 11,228 条新闻文本,总共分为 46 个主题。
- 与 IMDB 数据集一样,每条新闻都被编码为一个词索引的序列。
7.2. 加载数据集
7.2.1. 使用tensorflow加载数据集
- 因为google无法访问的缘故,所以直接从网络下载会出现失败。
import tensorflow.keras.datasets.reuters as reuters
# (x_train, y_train), (x_test, y_test) = reuters.load_data()
# x_train.shape, y_train.shape, x_test.shape, y_test.shape,
7.2.2. 使用本地数据集
- 下载
- 完整原始数据集下载:
http://www.daviddlewis.com/resources/testcollections/reuters21578/
- npz文件下载:使用网络搜索
- 完整原始数据集下载:
- 拷贝到用户主目录下:
-
~/.keras/datasets/
-
- 加载数据集代码
import tensorflow.keras.datasets.reuters as reuters
(x_train, y_train), (x_test, y_test) = reuters.load_data(path='reuters.npz')
x_train.shape, y_train.shape, x_test.shape, y_test.shape,
((8982,), (8982,), (2246,), (2246,))
- 数据集显示
print(x_train[0])
[1, 27595, 28842, 8, 43, 10, 447, 5, 25, 207, 270, 5, 3095, 111, 16, 369, 186, 90, 67, 7, 89, 5, 19, 102, 6, 19, 124, 15, 90, 67, 84, 22, 482, 26, 7, 48, 4, 49, 8, 864, 39, 209, 154, 6, 151, 6, 83, 11, 15, 22, 155, 11, 15, 7, 48, 9, 4579, 1005, 504, 6, 258, 6, 272, 11, 15, 22, 134, 44, 11, 15, 16, 8, 197, 1245, 90, 67, 52, 29, 209, 30, 32, 132, 6, 109, 15, 17, 12]
- 把路透社主题数据集还原为文本
word_dict = imdb.get_word_index(path='reuters_word_index.json')
x_ints = x_train[0]
x_texts = []
for item in x_ints:
x_texts.append(list (word_dict.keys()) [list (word_dict.values()).index (item)])
print(x_texts)
['the', 'kazuo', 'operandi', 'in', 'out', 'i', 'several', 'to', 'have', 'always', 'place', 'to', 'catholic', 'plot', 'with', 'women', 'horror', 'made', 'can', 'br', "don't", 'to', 'film', 'characters', 'is', 'film', 'does', 'for', 'made', 'can', 'great', 'you', 'lead', 'he', 'br', 'what', 'of', 'good', 'in', 'believable', 'or', 'comedy', 'work', 'is', 'old', 'is', 'first', 'this', 'for', 'you', '10', 'this', 'for', 'br', 'what', 'it', 'christians', 'ideas', "they're", 'is', 'although', 'is', 'different', 'this', 'for', 'you', 'while', 'has', 'this', 'for', 'with', 'in', 'between', 'military', 'made', 'can', 'very', 'all', 'comedy', 'at', 'an', 'say', 'is', 'being', 'for', 'movie', 'that']
八. 波士顿住房价格数据集
- 该数据集与上面一样。下面直接上代码(记得下载数据集文件,并拷贝到用户主目录下:
~/.keras/datasets/
)。
8.1. 波士顿住房价格数据集说明
- 该数据集是一个回归数据集。
- 每个类的观察值数量是均等的,共有 506 个观察,13 个输入变量和1个输出变量。每条数据包含房屋以及房屋周围的详细信息。
- 输入数据:
- CRIM:城镇人均犯罪率。
- ZN:住宅用地超过 25000 sq.ft. 的比例。
- INDUS:城镇非零售商用土地的比例。
- CHAS:查理斯河空变量(如果边界是河流,则为1;否则为0)。
- NOX:一氧化氮浓度。
- RM:住宅平均房间数。
- AGE:1940 年之前建成的自用房屋比例。
- DIS:到波士顿五个中心区域的加权距离。
- RAD:辐射性公路的接近指数。
- TAX:每 10000 美元的全值财产税率。
- PTRATIO:城镇师生比例。
- B:1000(Bk-0.63)^ 2,其中 Bk 指代城镇中黑人的比例。
- LSTAT:人口中地位低下者的比例。
- MEDV:自住房的平均房价,以千美元计。
- 输出数据:
- 预测平均值的基准性能的均方根误差(RMSE)是约 9.21 千美元。
- 输入数据:
8.2. 加载数据集
- 因为该数据集也是在google.com中,所以直接使用本地数据集。
- 也可以直接使用sklearn获取boston房价数据集。
import tensorflow.keras.datasets.boston_housing as boston_housing
(x_train, y_train), (x_test, y_test) = boston_housing.load_data(path='boston_housing.npz')
x_train.shape, y_train.shape, x_test.shape, y_test.shape,
((404, 13), (404,), (102, 13), (102,))
九. 备注
- 在调用的时候,容易出现的问题
ValueError: Object arrays cannot be loaded when allow_pickle=False
- 出问题的原因是numpy的版本不匹配造成的。
- 本文种,路透社主题分类数据集加载需要的numpy版本是1.16.2。在1.16.3环境中会报错。
ValueError: Object arrays cannot be loaded when allow_pickle=False
-
在sklearn能获取的数据集