Preface
Within today's deep-learning frameworks, speech recognition can take one of two routes: CNNs or RNNs. Since CNNs dominate current research, this experiment takes the CNN route.
The experiment has two parts:
1. Leverage the CNN's strength on images: run MFCC feature extraction on the audio, normalize the features, and export them as images;
2. Retrain MobileNet on those images.
Audio MFCC processing
import numpy as np
import matplotlib.pyplot as plt
import librosa
import librosa.display
import os
import sklearn
%matplotlib inline
root = "/home/test/Downloads/audio-cats-and-dogs/cats_dogs"
def getwavfiles(path):
    """Collect the full paths of all .wav files under path."""
    wavs = []
    for dirpath, _, files in os.walk(path):
        for file in files:
            if file.endswith(".wav"):
                wavs.append(os.path.join(dirpath, file))
    return wavs

def waveplot(file):
    """Load an audio file and plot its waveform."""
    x, fs = librosa.load(file)
    librosa.display.waveplot(x, sr=fs)
    return x, fs

def mfcc(x, fs):
    """Compute the MFCC matrix for signal x sampled at fs."""
    mfccs = librosa.feature.mfcc(x, sr=fs)
    # librosa.display.specshow(mfccs, sr=fs, x_axis='time')
    return mfccs

def scalemfcc(mfccs):
    """Standardize each MFCC coefficient (row) to zero mean and unit variance."""
    mfccs = sklearn.preprocessing.scale(mfccs, axis=1)
    # print(mfccs.mean(axis=1))
    # print(mfccs.var(axis=1))
    return mfccs

def drawspec(mfccs, savename):
    """Render the MFCC matrix as an image and save it to savename."""
    plt.figure(figsize=(10, 4))
    librosa.display.specshow(mfccs, x_axis='time')
    # plt.colorbar(format='%+2.0f dB')
    plt.title('Mel spectrogram')
    plt.tight_layout()
    plt.savefig(savename)

def mfccfigs(path):
    """List the .png paths that correspond to the .wav files under path."""
    figs = []
    for dirpath, _, files in os.walk(path):
        for file in files:
            if file.endswith(".wav"):
                figs.append(os.path.join(dirpath, file.replace(".wav", ".png")))
    return figs

wavs = getwavfiles(root)
for wav in wavs:
    x, fs = waveplot(wav)
    mfccs = mfcc(x, fs)
    mfccs = scalemfcc(mfccs)
    drawspec(mfccs, wav.replace(".wav", ".png"))
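To see what the normalization step does without pulling in librosa or sklearn, here is a minimal sketch in plain numpy of what `sklearn.preprocessing.scale(mfccs, axis=1)` computes on a toy matrix (the matrix values are made up for illustration):

```python
import numpy as np

# A toy 3x4 "MFCC" matrix: 3 coefficients over 4 frames.
mfccs = np.array([[ 1.0, 2.0, 3.0, 4.0],
                  [10.0, 10.0, 10.0, 10.0],
                  [-5.0, 0.0, 5.0, 0.0]])

# scale(..., axis=1) standardizes each row independently:
mean = mfccs.mean(axis=1, keepdims=True)
std = mfccs.std(axis=1, keepdims=True)
std[std == 0] = 1.0              # guard rows with zero variance
scaled = (mfccs - mean) / std

print(scaled.mean(axis=1))       # each row mean is ~0
print(scaled.std(axis=1))        # each non-constant row std is ~1
```

This is why the commented-out `mean(axis=1)` / `var(axis=1)` checks in the script above print values near 0 and 1 after scaling.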
Retrain MobileNet
Retraining assumes that:
1. you have downloaded the tensorflow source code and checked out the branch matching the tensorflow version installed on your machine;
2. you have downloaded the retrain.py script, since the latest tensorflow source no longer includes it;
3. you have prepared your own training data.
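On point 3: retrain.py expects `--image_dir` to contain one sub-directory per class label, with that label's images inside. A hypothetical helper for sorting the spectrogram PNGs generated above into such a layout, assuming each file name starts with its label (e.g. `cat_1.png`, `dog_2.png` — adjust to your dataset's actual naming):

```python
import os
import shutil

def sort_pngs(src, dst, labels=("cat", "dog")):
    """Copy spectrogram PNGs from src into dst/<label>/ sub-directories,
    the layout retrain.py expects. Assumes file names start with the label."""
    for label in labels:
        os.makedirs(os.path.join(dst, label), exist_ok=True)
    for dirpath, _, files in os.walk(src):
        for f in files:
            if f.endswith(".png"):
                label = next((l for l in labels if f.startswith(l)), None)
                if label:
                    shutil.copy(os.path.join(dirpath, f),
                                os.path.join(dst, label, f))

# e.g.:
# sort_pngs("/home/test/Downloads/audio-cats-and-dogs/cats_dogs",
#           os.path.expanduser("~/data/cats_dogs"))
```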
python tensorflow/examples/image_retraining/retrain.py \
--image_dir ~/data/cats_dogs/ \
--architecture mobilenet_v2_1.4_224
The tail of the retraining log looks like this:
INFO:tensorflow:2018-12-04 14:19:32.957046: Step 3980: Cross entropy = 0.074023
INFO:tensorflow:2018-12-04 14:19:33.017341: Step 3980: Validation accuracy = 88.0% (N=100)
INFO:tensorflow:2018-12-04 14:19:33.617590: Step 3990: Train accuracy = 99.0%
INFO:tensorflow:2018-12-04 14:19:33.617745: Step 3990: Cross entropy = 0.089500
INFO:tensorflow:2018-12-04 14:19:33.679496: Step 3990: Validation accuracy = 89.0% (N=100)
INFO:tensorflow:2018-12-04 14:19:34.222098: Step 3999: Train accuracy = 99.0%
INFO:tensorflow:2018-12-04 14:19:34.222252: Step 3999: Cross entropy = 0.083234
INFO:tensorflow:2018-12-04 14:19:34.283417: Step 3999: Validation accuracy = 87.0% (N=100)
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Restoring parameters from /tmp/_retrain_checkpoint
INFO:tensorflow:Final test accuracy = 80.8% (N=26)
INFO:tensorflow:Save final result to : /tmp/output_graph.pb
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Restoring parameters from /tmp/_retrain_checkpoint
INFO:tensorflow:Froze 378 variables.
INFO:tensorflow:Converted 378 variables to const ops.
Model conversion
To use the model on a phone, it is best to convert it to a tflite model.
toco \
  --graph_def_file=output_graph.pb \
  --output_file=/tmp/mobilenet_v2.tflite \
  --output_format=TFLITE \
  --input_arrays=Placeholder \
  --output_arrays=final_result \
  --input_shapes=1,224,224,3 \
  --inference_type=QUANTIZED_UINT8 \
  --inference_input_type=QUANTIZED_UINT8 \
  --mean_values=128 \
  --std_dev_values=128 \
  --default_ranges_min=0 \
  --default_ranges_max=6
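The `--mean_values=128` / `--std_dev_values=128` pair tells TOCO how a uint8 input byte q maps back to the float value the graph was trained on: real = (q - mean) / std. With both set to 128, the 0..255 pixel range lands in roughly [-1, 1], which is the input range MobileNet expects. A quick numeric sketch:

```python
import numpy as np

mean, std = 128.0, 128.0
q = np.array([0, 128, 255], dtype=np.uint8)

# Dequantization rule used for QUANTIZED_UINT8 inputs:
real = (q.astype(np.float32) - mean) / std
print(real)  # -> [-1.0, 0.0, 127/128]
```

The `--default_ranges_min=0 --default_ranges_max=6` flags supply a fallback activation range (matching relu6) for ops that carry no recorded min/max, since this graph was trained in float without quantization-aware training.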