如何使用prune_low_magnitude实现模型剪枝?

首先,理解深度学习模型剪枝的理论基础:
深度学习模型,存在着大量冗余地节点(过参数化,Over-parameters),仅仅只有少部分(5-10%)权值参与着主要的计算,也就是说,仅仅训练小部分的权值参数就可以达到和原来网络相近的性能。一旦训练完成到了推理阶段,我们并不需要这么多的参数。

其次,对模型剪枝有两种方式,剪神经元,或者剪权重(不破坏原来的网络结构)。剪神经元对模型的影响较大,剪权重对模型的精度影响较小。当然模型不能瞎剪,因为这样精度可以会下降得很厉害以至无法接受,也有情况会在pruning后精度提高的,这说明原模型过似合(overfit)了,pruning起到了regularization的作用。

剪权重对模型的精度影响较小

剪枝的核心问题是:如何有效地裁剪模型且最小化精度的损失?即如何判断哪个是重要的参数,哪个是不重要的参数,然后把不重要的(对模型影响小)参数剪掉。这样,剪枝问题又转换为如何评估参数重要性的问题。

一个简单且实践有效的思路是:magnitude-based weight pruning,即按照参数(或特征输出)绝对值大小来评估重要性,《Pruning Filters for Efficient ConvNets》

基于上述这个简单有效的思路,TensorFlow Model Optimization工具包,直接提供了prune_low_magnitude API函数来实现。范例代码如下:

import tempfile
import os 

import tensorflow as tf 
import numpy as np 

from tensorflow import keras 

# Load dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

# Normalize the input image from 0-255 to 0-1
train_images, test_images = train_images / 255.0, test_images / 255.0

# expand to channel last: HWC
train_images = np.expand_dims(train_images, -1)
test_images  = np.expand_dims(test_images, -1)

# Train a model without pruning
inputs = keras.Input(shape=(28,28,1))
x = keras.layers.Conv2D(12,(3,3),activation='relu')(inputs)
x = keras.layers.MaxPooling2D(pool_size=(2,2))(x)
x = keras.layers.Flatten()(x)
outputs = keras.layers.Dense(10)(x)

model_no_pruning = keras.Model(inputs, outputs)

model_no_pruning.summary()

# Train the digit classification model
model_no_pruning.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model_no_pruning.fit(
    train_images,
    train_labels,
    epochs=4,
    validation_split=0.1,
)

loss, no_pruning_acc = model_no_pruning.evaluate(test_images, test_labels, verbose=0)
print('model_no_pruning test accuracy:', no_pruning_acc)

_, keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model_no_pruning, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)
model_no_pruning.save("model_no_pruning.h5")


import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

batch_size = 128
epochs = 2
validation_split = 0.1

num_images = test_images.shape[0] * (1-validation_split)
end_step   = np.ceil(num_images/batch_size).astype(np.int32) * epochs

# Define model for pruning
pruning_params = {
    'pruning_shedule':tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.5,
        final_sparsity=0.8,
        begin_step=0,
        end_step=end_step
    )
}

model_for_pruning = prune_low_magnitude(model_no_pruning, **pruning_params)

model_for_pruning.summary()

model_for_pruning.compile(
    optimizer = "adam",
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

logdir = tempfile.mkdtemp()

callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(
    train_images,
    train_labels,
    batch_size = batch_size,
    epochs = epochs,
    validation_split = validation_split,
    callbacks = callbacks
)

loss, pruning_acc = model_for_pruning.evaluate(
    test_images,
    test_labels,
    verbose=0
)
model_for_pruning.summary()
print('model_no_pruning test accuracy:', no_pruning_acc)
print('model_for_pruning test accuracy:', pruning_acc)
print("logdir:", logdir)

_, pruning_keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model_for_pruning, pruning_keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)
print('Saved pruning model to:', pruning_keras_file)
model_for_pruning.save("model_for_pruning.h5")

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
_, pruned_keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)
model_for_export.save("model_for_export.h5")
model_for_export.summary()

结论

  • 剪枝后的模型,比剪枝前的小了很多
    剪枝后的模型,比剪枝前的小了很多
  • 剪枝后的模型,精度没有损失

model_no_pruning test accuracy: 0.9779000282287598
model_for_pruning test accuracy: 0.9797000288963318

  • prune_low_magnitude 剪的是权重,对网络构架无影响,用Netron可以查看。
    剪枝前后模型比较
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 205,033评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,725评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,473评论 0 338
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,846评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,848评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,691评论 1 282
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,053评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,700评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 42,856评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,676评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,787评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,430评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,034评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,990评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,218评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,174评论 2 352
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,526评论 2 343

推荐阅读更多精彩内容