DL4J实战之二:鸢尾花分类

欢迎访问我的GitHub

https://github.com/zq2599/blog_demos

内容:所有原创文章分类汇总及配套源码,涉及Java、Docker、Kubernetes、DevOPS等;

本篇概览

  • 本文是《DL4J》实战的第二篇,前面做好了准备工作,接下来进入正式实战,本篇内容是经典的入门例子:鸢尾花分类
  • 下图是一朵鸢尾花,我们可以测量到它的四个特征:花瓣(petal)的宽和高,花萼(sepal)的 宽和高:
在这里插入图片描述
  • 鸢尾花有三种:Setosa、Versicolor、Virginica
  • 今天的实战是用前馈神经网络Feed-Forward Neural Network (FFNN)就行鸢尾花分类的模型训练和评估,在拿到150条鸢尾花的特征和分类结果后,我们先训练出模型,再评估模型的效果:
在这里插入图片描述

源码下载

名称 链接 备注
项目主页 https://github.com/zq2599/blog_demos 该项目在GitHub上的主页
git仓库地址(https) https://github.com/zq2599/blog_demos.git 该项目源码的仓库地址,https协议
git仓库地址(ssh) git@github.com:zq2599/blog_demos.git 该项目源码的仓库地址,ssh协议
  • 这个git项目中有多个文件夹,《DL4J实战》系列的源码在<font color="blue">dl4j-tutorials</font>文件夹下,如下图红框所示:
在这里插入图片描述
  • <font color="blue">dl4j-tutorials</font>文件夹下有多个子工程,本次实战代码在<font color="blue">dl4j-tutorials</font>目录下,如下图红框:
在这里插入图片描述

编码

  • 在<font color="blue">dl4j-tutorials</font>工程下新建子工程<font color="red">classifier-iris</font>,其pom.xml如下:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>dlfj-tutorials</artifactId>
        <groupId>com.bolingcavalry</groupId>
        <version>1.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>classifier-iris</artifactId>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
    </properties>

    <dependencies>
        <dependency>
            <groupId>com.bolingcavalry</groupId>
            <artifactId>commons</artifactId>
            <version>${project.version}</version>
        </dependency>

        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>

        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
        </dependency>

        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
        </dependency>
    </dependencies>
</project>
  • 上述pom.xml有一处需要注意的地方,就是<font color="blue">${nd4j.backend}</font>参数的值,该值在决定了后端线性代数计算是用CPU还是GPU,本篇为了简化操作选择了CPU(因为个人的显卡不同,代码里无法统一),对应的配置就是<font color="red">nd4j-native</font>;

  • 源码全部在Iris.java文件中,并且代码中已添加详细注释,就不再赘述了:

package com.bolingcavalry.classifier;

import com.bolingcavalry.commons.utils.DownloaderUtility;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;

/**
 * @author will (zq2599@gmail.com)
 * @version 1.0
 * @description: 鸢尾花训练
 * @date 2021/6/13 17:30
 */
@SuppressWarnings("DuplicatedCode")
@Slf4j
public class Iris {

    public static void main(String[] args) throws  Exception {

        //第一阶段:准备

        // 跳过的行数,因为可能是表头
        int numLinesToSkip = 0;
        // 分隔符
        char delimiter = ',';

        // CSV读取工具
        RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);

        // 下载并解压后,得到文件的位置
        String dataPathLocal = DownloaderUtility.IRISDATA.Download();

        log.info("鸢尾花数据已下载并解压至 : {}", dataPathLocal);

        // 读取下载后的文件
        recordReader.initialize(new FileSplit(new File(dataPathLocal,"iris.txt")));

        // 每一行的内容大概是这样的:5.1,3.5,1.4,0.2,0
        // 一共五个字段,从零开始算的话,标签在第四个字段
        int labelIndex = 4;

        // 鸢尾花一共分为三类
        int numClasses = 3;

        // 一共150个样本
        int batchSize = 150;    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

        // 加载到数据集迭代器中
        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);

        DataSet allData = iterator.next();

        // 洗牌(打乱顺序)
        allData.shuffle();

        // 设定比例,150个样本中,百分之六十五用于训练
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training

        // 训练用的数据集
        DataSet trainingData = testAndTrain.getTrain();

        // 验证用的数据集
        DataSet testData = testAndTrain.getTest();

        // 指定归一化器:独立地将每个特征值(和可选的标签值)归一化为0平均值和1的标准差。
        DataNormalization normalizer = new NormalizerStandardize();

        // 先拟合
        normalizer.fit(trainingData);

        // 对训练集做归一化
        normalizer.transform(trainingData);

        // 对测试集做归一化
        normalizer.transform(testData);

        // 每个鸢尾花有四个特征
        final int numInputs = 4;

        // 共有三种鸢尾花
        int outputNum = 3;

        // 随机数种子
        long seed = 6;

        //第二阶段:训练
        log.info("开始配置...");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .activation(Activation.TANH)       // 激活函数选用标准的tanh(双曲正切)
            .weightInit(WeightInit.XAVIER)     // 权重初始化选用XAVIER:均值 0, 方差为 2.0/(fanIn + fanOut)的高斯分布
            .updater(new Sgd(0.1))  // 更新器,设置SGD学习速率调度器
            .l2(1e-4)                          // L2正则化配置
            .list()                            // 配置多层网络
            .layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)  // 隐藏层
                .build())
            .layer(new DenseLayer.Builder().nIn(3).nOut(3)          // 隐藏层
                .build())
            .layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)   // 损失函数:负对数似然
                .activation(Activation.SOFTMAX)                     // 输出层指定激活函数为:SOFTMAX
                .nIn(3).nOut(outputNum).build())
            .build();

        // 模型配置
        MultiLayerNetwork model = new MultiLayerNetwork(conf);

        // 初始化
        model.init();

        // 每一百次迭代打印一次分数(损失函数的值)
        model.setListeners(new ScoreIterationListener(100));

        long startTime = System.currentTimeMillis();

        log.info("开始训练");
        // 训练
        for(int i=0; i<1000; i++ ) {
            model.fit(trainingData);
        }
        log.info("训练完成,耗时[{}]ms", System.currentTimeMillis()-startTime);

        // 第三阶段:评估

        // 在测试集上评估模型
        Evaluation eval = new Evaluation(numClasses);
        INDArray output = model.output(testData.getFeatures());
        eval.eval(testData.getLabels(), output);

        log.info("评估结果如下\n" + eval.stats());
    }
}
  • 编码完成后,运行main方法,可见顺利完成训练并输出了评估结果,还有混淆矩阵用于辅助分析:
在这里插入图片描述
  • 至此,咱们的第一个实战就完成了,通过经典实例体验的DL4J训练和评估的常规步骤,对重要API也有了初步认识,接下来会继续实战,接触到更多的经典实例;

你不孤单,欣宸原创一路相伴

  1. Java系列
  2. Spring系列
  3. Docker系列
  4. kubernetes系列
  5. 数据库+中间件系列
  6. DevOps系列

欢迎关注公众号:程序员欣宸

微信搜索「程序员欣宸」,我是欣宸,期待与您一同畅游Java世界...
https://github.com/zq2599/blog_demos

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,445评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,889评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,047评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,760评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,745评论 5 367
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,638评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,011评论 3 398
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,669评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,923评论 1 299
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,655评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,740评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,406评论 4 320
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,995评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,961评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,197评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,023评论 2 350
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,483评论 2 342

推荐阅读更多精彩内容