这次我们使用CNN中最经典的Lenet网络在mnist数据集上进行训练和预测。
卷积NN
主要有两部分组成,一部分是对输入图片特征提取,一部分是全连接网络,主要组成操作包括卷积、池化、激活等。-
Lenet网络模型
Lenet是提出比较早,能有效解决手写数字图片识别的卷积模型,模型结构如下:
其中,padding=valid代表非全0填充,输出图片尺寸=(输入尺寸-卷积核尺寸+1)/步长;padding=same代表全0填充,输出尺寸=输入尺寸/步长;pooling不改变深度。
对Lenet进行调整使其使用于mnist数据集,结构如下:
实现还是分三模块:forward,backwa,test,主要改变是在forward:
定义获得权重、偏执,增加对卷积,池化的函数。
按上层结构前向传播,返回预测值。
backward和test跟上一篇中改动不大,主要是要注意输入的大小:
输入占位大小改变
喂入的barch_size大小改变
同理,在test文件中,测试数据的大小也相应改变。
新手学习,欢迎指教!