本节课我们讲解手写图片的分类识别。
训练数据
public DataSetIteratorimageDateSet(String dataLocalPath,int seed,int width,int height,int channels)throws IOException {
String [] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS;
Random randNumGen =new Random(seed);
File parentDir=new File(dataLocalPath);
FileSplit filesInDir =new FileSplit(parentDir, allowedExtensions, randNumGen);
ParentPathLabelGenerator labelMaker =new ParentPathLabelGenerator();
BalancedPathFilter pathFilter =new BalancedPathFilter(randNumGen, allowedExtensions, labelMaker);
InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter, 100);
System.out.println("---------"+filesInDirSplit.length);
InputSplit trainData = filesInDirSplit[0];
ImageRecordReader recordReader =new ImageRecordReader(height,width,channels,labelMaker);
ImageTransform transform =new MultiImageTransform(randNumGen,new ShowImageTransform("Display - before "));
recordReader.initialize(trainData,transform);
int outputNum = recordReader.numLabels();
int batchSize =10;
int labelIndex =1;
DataSetIterator dataIter =new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, outputNum);
return dataIter;
}
模型搭建
public MultiLayerNetworkmodel(){
try{
MultiLayerConfiguration.Builder builder =new NeuralNetConfiguration.Builder()
.seed(12345)
.weightInit(WeightInit.XAVIER)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(Updater.ADAM)
.list()
.layer(0, new ConvolutionLayer.Builder(5, 5)
.nIn(1)
.stride(1, 1)
.nOut(32)
.activation(Activation.LEAKYRELU)
.build())
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2,2)
.stride(2,2)
.build())
.layer(2, new ConvolutionLayer.Builder(5, 5)
.stride(1, 1)
.nOut(64)
.activation(Activation.LEAKYRELU)
.build())
.layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2,2)
.stride(2,2)
.build())
.layer(4, new DenseLayer.Builder().activation(Activation.LEAKYRELU)
.nOut(500).build())
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(10)
.activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutionalFlat(28, 28, 1));
MultiLayerConfiguration conf = builder.build();
MultiLayerNetwork model =new MultiLayerNetwork(conf);
return model;
}catch (Exception e){
e.printStackTrace();
}
return null;
}
开始训练
public static Booleantrain(MultiLayerNetwork mlp,DataSetIterator trainIter,DataSetIterator testIter){
for(int i =0; i <1; ++i ){
mlp.fit(trainIter); //训练模型
Evaluation trainEval = mlp.evaluate(trainIter); //在验证集上进行准确性测试
Evaluation testEval = mlp.evaluate(testIter);
trainIter.reset();
testIter.reset();
}
return Boolean.TRUE;
}
监控训练过程并保存模型
/**
 * Entry point: loads the training and testing image sets from {@code basePath}, builds the
 * network, attaches a training-UI stats listener, trains, and writes the model to
 * {@code basePath + "mlp.mod"}.
 *
 * @param arg command-line arguments (unused)
 * @throws IOException if a data set cannot be loaded, the state file cannot be created, or
 *                     the model cannot be written
 * @throws IllegalStateException if {@code model()} returns {@code null}
 */
public static void main(String[] arg) throws IOException {
    DataSetIterator trainData = lesson4.imageDateSet(basePath + "training/", 12345, 28, 28, 1);
    DataSetIterator testData = lesson4.imageDateSet(basePath + "testing/", 12345, 28, 28, 1);

    MultiLayerNetwork model = lesson4.model();
    // model() returns null when network construction fails; stop here with a clear
    // message instead of hitting a NullPointerException further down.
    if (model == null) {
        throw new IllegalStateException("model() returned null; the network could not be built");
    }

    // NOTE(review): this file is created but not referenced again in this method —
    // confirm whether it is still needed.
    File stateFile = new File(basePath + "state");
    stateFile.createNewFile();

    // Route training statistics to a file-backed store in the temp dir and show them in the UI.
    UIServer uiServer = UIServer.getInstance();
    StatsStorage statsStorage = new FileStatsStorage(new File(System.getProperty("java.io.tmpdir"), "ui-stats.dl4j"));
    int listenerFrequency = 1;
    model.setListeners(new StatsListener(statsStorage, listenerFrequency));
    uiServer.attach(statsStorage);

    train(model, trainData, testData);

    // The final 'true' is the third argument of ModelSerializer.writeModel (save-updater flag).
    ModelSerializer.writeModel(model, new File(basePath + "mlp.mod"), true);
}
关于训练数据的下载可见第二节课所讲内容。
下一节课讲解文本分类。
本人诚接各类商业AI模型训练工作,如果您是一家公司,想借助AI解决当前服务问题,可以联系我。微信号:CompanyAiHelper