【Tool】Keras 实战 II: VGG16图片分类

上篇文章中我们自己设计了一个神经网络,从头开始训练用于图片分类。由于我们只使用了5000张图片,我们只取得了80%左右的准确率。这篇文章中,我们使用VGG16作为我们的base_model,在这个基础上进行训练。keras.applications模块中,有几种训练好的base model,可以直接用来进行迁移学习。通过设计include_top=False,我们可以获得不含全连接层的基础网络。通过在后面加入自己的custom layers,我们将其可以用于不同的分类任务。

# finetune from the base model VGG16
base_model = VGG16(include_top=False, weights='imagenet', input_shape=(150, 150, 3))
base_model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 150, 150, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 150, 150, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 150, 150, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 75, 75, 64)        0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 75, 75, 128)       73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 75, 75, 128)       147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 37, 37, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 37, 37, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 37, 37, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 37, 37, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 18, 18, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 18, 18, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 18, 18, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 18, 18, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 9, 9, 512)         0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 9, 9, 512)         2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 9, 9, 512)         2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 9, 9, 512)         2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 4, 4, 512)         0         
=================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0
_________________________________________________________________

此时我们有两种做法,一种是使用base_model作为特征提取器,不参与训练,只训练自己加入的全连接层,第二种是base_model也参加训练,此时我们训练的是一个end-to-end model。第二种方法要更难训练一点,我们先看看第一种。

VGG16 as feature extractor

keras中通过设置layers.trainable,我们可以控制哪些层是可以训练的,哪些层是不可以训练的。基础代码和上一篇文章一样。区别就是如何使用base_model和新加入的层作为自己的model。

import os
import numpy as np
from keras.models import Sequential, Model
from keras import layers
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.applications.vgg16 import VGG16
from keras.utils.np_utils import to_categorical
from scipy.misc import imread, imresize
import matplotlib.pyplot as plt
imgs = []
labels = []
img_shape =(150,150)
# image generator
files = os.listdir('data/test')
# read 1000 files for the generator
for img_file in files[:1000]:
    img = imread('data/test/' + img_file).astype('float32')
    img = imresize(img, img_shape)
    imgs.append(img)

imgs = np.array(imgs)
train_gen = ImageDataGenerator(
     # rescale = 1./255,
     featurewise_center=True,
     featurewise_std_normalization=True,
     rotation_range=20,
     width_shift_range=0.2,
     height_shift_range=0.2,
     horizontal_flip=True)
val_gen = ImageDataGenerator(
     # rescale = 1./255,
     featurewise_center=True,
     featurewise_std_normalization=True)

train_gen.fit(imgs)
val_gen.fit(imgs)

# 4500 training images 
train_iter = train_gen.flow_from_directory('data/train',class_mode='binary',
                                            target_size=img_shape,   batch_size=16)
# 501 validation images
val_iter = val_gen.flow_from_directory('data/val', class_mode='binary', 
                                        target_size=img_shape, batch_size=16)

'''
# image generator debug
for x_batch, y_batch in img_iter:
    print(x_batch.shape)
    print(y_batch.shape)
    plt.imshow(x_batch[0])
    plt.show()
'''

# finetune from the base model VGG16
base_model = VGG16(include_top=False, weights='imagenet', input_shape=(150, 150, 3))
base_model.summary()

out = base_model.layers[-1].output
out = layers.Flatten()(out)
out  = layers.Dense(1024, activation='relu')(out)
# 因为前面输出的dense feature太多了,我们这里加入dropout layer来防止过拟合
out = layers.Dropout(0.5)(out)
out = layers.Dense(512, activation='relu')(out)
out = layers.Dropout(0.3)(out)
out = layers.Dense(1, activation='sigmoid')(out)
tuneModel = Model(inputs=base_model.input, outputs = out)
for layer in tuneModel.layers[:19]: # freeze the base model only use it as feature extractors
    layer.trainable = False
tuneModel.compile(loss='binary_crossentropy', optimizer=optimizers.RMSprop(lr=1e-4),
        metrics=['acc'])

history = tuneModel.fit_generator(
        generator=train_iter,
        steps_per_epoch=100,
        epochs=100,
        validation_data=val_iter,
        validation_steps=32
        )

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1,101)
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'r', label='Validation acc')
plt.legend()
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.legend()
plt.show()
# 输出
Epoch 1/100
100/100 [==============================] - 677s 7s/step - loss: 0.4214 - acc: 0.8113 - val_loss: 0.1659 - val_acc: 0.9311
Epoch 2/100
100/100 [==============================] - 786s 8s/step - loss: 0.2618 - acc: 0.8900 - val_loss: 0.1576 - val_acc: 0.9351

可以看到两个epoch之后就基本达到93%的accuracy,感觉像magic,在自己数据和计算资源有限的情况下finetune确实是一种很有效的提升效果的方式啊。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,053评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,527评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,779评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,685评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,699评论 5 366
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,609评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,989评论 3 396
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,654评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,890评论 1 298
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,634评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,716评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,394评论 4 319
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,976评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,950评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,191评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 44,849评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,458评论 2 342

推荐阅读更多精彩内容