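First, the theoretical basis of pruning deep learning models:
Deep learning models contain a large number of redundant nodes (over-parameterization). Only a small fraction of the weights (5-10%) take part in the main computation; in other words, training only a small subset of the weights can reach performance close to that of the original network. Once training is finished and the model moves to inference, we do not need that many parameters.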
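Second, there are two ways to prune a model: pruning neurons, or pruning weights (which leaves the original network structure intact). Pruning neurons affects the model more, while pruning weights has a smaller effect on accuracy. Of course, a model cannot be pruned blindly, because accuracy may drop so sharply that the result is unusable. In some cases accuracy actually improves after pruning, which suggests the original model was overfitting and pruning acted as regularization.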
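To make the difference between the two modes concrete, here is a minimal pure-Python sketch on a toy dense layer. The matrix values, the thresholds, and the helper names `prune_weights` and `prune_neurons` are illustrative assumptions, not part of any library; ranking neurons by the L1 norm of their weights is just one possible criterion.

```python
# Toy dense layer: 3 inputs -> 4 neurons; W[i][j] connects input i to neuron j.
W = [
    [0.90, -0.02, 0.40, 0.01],
    [-0.70, 0.03, -0.05, 0.02],
    [0.60, -0.01, 0.80, -0.03],
]

def prune_weights(W, threshold):
    """Weight pruning: zero small entries, keep the matrix shape."""
    return [[w if abs(w) >= threshold else 0.0 for w in row] for row in W]

def prune_neurons(W, threshold):
    """Neuron pruning: drop whole columns whose L1 norm is below threshold."""
    n_out = len(W[0])
    keep = [j for j in range(n_out)
            if sum(abs(row[j]) for row in W) >= threshold]
    return [[row[j] for j in keep] for row in W]

W_weight_pruned = prune_weights(W, threshold=0.1)
W_neuron_pruned = prune_neurons(W, threshold=0.1)

# Weight pruning keeps the 3 x 4 shape; neuron pruning removes columns.
print(len(W_weight_pruned), "x", len(W_weight_pruned[0]))
print(len(W_neuron_pruned), "x", len(W_neuron_pruned[0]))
```

Weight pruning leaves every layer shape untouched, so downstream layers need no changes; neuron pruning shrinks the layer's output, so the next layer's input dimension has to shrink with it.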
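The core question of pruning is how to trim the model effectively while minimizing the loss of accuracy. That is, how to decide which parameters are important and which are not, and then remove the unimportant ones (those with little effect on the model). The pruning problem thus becomes a problem of estimating parameter importance.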
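A simple approach that works well in practice is magnitude-based weight pruning: rank importance by the absolute value of the parameters (or of the feature outputs); see "Pruning Filters for Efficient ConvNets".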
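The idea can be sketched in a few lines of plain Python: sort the weights by absolute value and zero the smallest fraction. The function name `magnitude_prune` and the sample weights are assumptions made for illustration; real toolkits apply the same idea per layer on tensors.

```python
def magnitude_prune(weights, sparsity):
    """Zero the `sparsity` fraction of weights with the smallest absolute value."""
    n_prune = int(len(weights) * sparsity)
    # Indices ordered from smallest to largest magnitude.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    to_zero = set(order[:n_prune])
    return [0.0 if i in to_zero else w for i, w in enumerate(weights)]

weights = [0.8, -0.05, 0.3, 0.01, -0.6, 0.02, 0.1, -0.9, 0.04, 0.5]
pruned = magnitude_prune(weights, sparsity=0.5)
print(pruned)
```

The sign of a weight plays no role here: only its magnitude decides whether it survives.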
Building on this simple and effective idea, the TensorFlow Model Optimization toolkit provides the prune_low_magnitude API function directly. Example code follows:
import tempfile
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
# Load dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
# Normalize the input image from 0-255 to 0-1
train_images, test_images = train_images / 255.0, test_images / 255.0
# expand to channel last: HWC
train_images = np.expand_dims(train_images, -1)
test_images = np.expand_dims(test_images, -1)
# Train a model without pruning
inputs = keras.Input(shape=(28,28,1))
x = keras.layers.Conv2D(12,(3,3),activation='relu')(inputs)
x = keras.layers.MaxPooling2D(pool_size=(2,2))(x)
x = keras.layers.Flatten()(x)
outputs = keras.layers.Dense(10)(x)
model_no_pruning = keras.Model(inputs, outputs)
model_no_pruning.summary()
# Train the digit classification model
model_no_pruning.compile(optimizer='adam',
                         loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                         metrics=['accuracy'])
model_no_pruning.fit(
    train_images,
    train_labels,
    epochs=4,
    validation_split=0.1,
)
loss, no_pruning_acc = model_no_pruning.evaluate(test_images, test_labels, verbose=0)
print('model_no_pruning test accuracy:', no_pruning_acc)
_, keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model_no_pruning, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)
model_no_pruning.save("model_no_pruning.h5")
import tensorflow_model_optimization as tfmot
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
batch_size = 128
epochs = 2
validation_split = 0.1
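# Pruning fine-tunes on train_images, so count the training images left after the validation split
num_images = train_images.shape[0] * (1 - validation_split)
# Total number of optimizer steps over all pruning epochs
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs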
# Define model for pruning
pruning_params = {
    # Ramp sparsity from 50% to 80% over the fine-tuning run
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.5,
        final_sparsity=0.8,
        begin_step=0,
        end_step=end_step
    )
}
model_for_pruning = prune_low_magnitude(model_no_pruning, **pruning_params)
model_for_pruning.summary()
model_for_pruning.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
logdir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]
model_for_pruning.fit(
    train_images,
    train_labels,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=validation_split,
    callbacks=callbacks
)
loss, pruning_acc = model_for_pruning.evaluate(
    test_images,
    test_labels,
    verbose=0
)
model_for_pruning.summary()
print('model_no_pruning test accuracy:', no_pruning_acc)
print('model_for_pruning test accuracy:', pruning_acc)
print("logdir:", logdir)
_, pruning_keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model_for_pruning, pruning_keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)
print('Saved pruning model to:', pruning_keras_file)
model_for_pruning.save("model_for_pruning.h5")
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
_, pruned_keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)
model_for_export.save("model_for_export.h5")
model_for_export.summary()
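The PolynomialDecay schedule used in the script raises the target sparsity from 0.5 to 0.8 between begin_step and end_step. The sketch below reproduces that ramp in plain Python so the numbers are easy to inspect. It assumes the cubic form (power of 3) that, to my understanding, is the tfmot default; `polynomial_decay_sparsity` is a hypothetical helper, not a library function, and the 60000 figure is the size of the MNIST training set.

```python
import math

def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Target sparsity at `step` (assumed form of the polynomial schedule)."""
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))  # clamp outside the pruning window
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power

# Same settings as the script: 60000 MNIST training images, 10% held out
# for validation, batch size 128, 2 epochs of pruning fine-tuning.
num_train_images = 60000
steps_per_epoch = math.ceil(num_train_images * (1 - 0.1) / 128)
end_step = steps_per_epoch * 2

for step in (0, end_step // 4, end_step // 2, end_step):
    target = polynomial_decay_sparsity(step, 0.5, 0.8, 0, end_step)
    print(f"step {step:4d}: target sparsity {target:.3f}")
```

Under this form, sparsity rises quickly at first and flattens as it approaches the final value, so most of the fine-tuning steps are spent near the final sparsity.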
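Conclusion
- The pruned model is much smaller than the model before pruning (most visibly once the saved file is compressed).
- The pruned model shows no loss of accuracy:
model_no_pruning test accuracy: 0.9779000282287598
model_for_pruning test accuracy: 0.9797000288963318
- prune_low_magnitude prunes weights, so the network architecture is unchanged; this can be checked with Netron.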
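One caveat on model size: zeroed weights are still stored as ordinary floats, so the saving generally appears only after the file is compressed. The standard-library sketch below illustrates the effect on synthetic data rather than on the real model; the random weights, the 80% sparsity level, and the helper `compressed_size` are all assumptions made for illustration.

```python
import random
import struct
import zlib

def compressed_size(values):
    """Bytes needed after packing as float32 and compressing with zlib."""
    raw = struct.pack(f"{len(values)}f", *values)
    return len(zlib.compress(raw))

random.seed(0)
dense = [random.uniform(-1.0, 1.0) for _ in range(10000)]

# Zero the 80% smallest-magnitude weights, mirroring final_sparsity=0.8.
threshold = sorted(abs(w) for w in dense)[int(0.8 * len(dense))]
sparse = [w if abs(w) >= threshold else 0.0 for w in dense]

dense_size = compressed_size(dense)
sparse_size = compressed_size(sparse)
print("dense :", dense_size, "bytes after compression")
print("sparse:", sparse_size, "bytes after compression")
```

For the real models, the same comparison can be made by compressing the saved baseline file and the file written from model_for_export (for example with gzip) and comparing the resulting sizes.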