Deeplearning4j文本分类——今日头条【原创】

本节课我们讲解一下文本分类,在文本分类中我们需要将文本预处理。

分词

public static void data(String source,String save) throws IOException {
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(  new FileInputStream(new File(source)), "UTF-8"));
        File saveFile = new File(save);
        if(!saveFile.exists()){
            saveFile.createNewFile();
        }
        OutputStreamWriter writerStream = new OutputStreamWriter( new FileOutputStream(saveFile), "UTF-8");
        BufferedWriter writer = new BufferedWriter(writerStream);
        String line = null;
        long startTime = System.currentTimeMillis();
        while ((line = bufferedReader.readLine()) != null) {
            String[] array = line.split("_!_");
            StringBuilder stringBuilder = new StringBuilder();
            for (Term term : HanLP.segment(array[3])) {
                if (stringBuilder.length() > 0) {
                    stringBuilder.append(" ");
                }
                stringBuilder.append(term.word.trim());
            }
            writer.write(Integer.parseInt(array[1].trim()) + "_!_" + stringBuilder.toString() + "\n");
        }
        writer.flush();
        writer.close();
        System.out.println(System.currentTimeMillis() - startTime);
        bufferedReader.close();
    }

使用分词工具将文本进行分词处理

分本向量处理工具

public static void dataSet(String filePath,String savePath) throws FileNotFoundException {
        SentenceIterator iter = new BasicLineIterator(filePath);
        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());
        VocabCache<VocabWord> cache = new AbstractCache<>();
        WeightLookupTable<VocabWord> table = new InMemoryLookupTable.Builder<VocabWord>().vectorLength(100)
                .useAdaGrad(false).cache(cache).build();

        Word2Vec vec = new Word2Vec.Builder()
                .elementsLearningAlgorithm("org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram")
                .minWordFrequency(0).iterations(1).epochs(20).layerSize(100).seed(42).windowSize(8).iterate(iter)
                .tokenizerFactory(t).lookupTable(table).vocabCache(cache).build();

        vec.fit();
        WordVectorSerializer.writeWord2VecModel(vec, savePath);
    }

构建训练数据

private static HashMap<String,DataSetIterator> dataSet() throws IOException {
        List<String> trainLabelList = new ArrayList<>();// 训练集label
        List<String> trainSentences = new ArrayList<>();// 训练集文本集合
        List<String> testLabelList = new ArrayList<>();// 测试集label
        List<String> testSentences = new ArrayList<>();//// 测试集文本集合
        Map<String, List<String>> map = new HashMap<>();


        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(
                new FileInputStream(new File(basePath+"toutiao_data_type_word.txt")), "UTF-8"));
        String line = null;
        int truncateReviewsToLength = 0;
        Random random = new Random(123);
        while ((line = bufferedReader.readLine()) != null) {
            String[] array = line.split("_!_");
            if (map.get(array[0]) == null) {
                map.put(array[0], new ArrayList<String>());
            }
            map.get(array[0]).add(array[1]);// 将样本中所有数据,按照类别归类
            int length = array[1].split(" ").length;
            if (length > truncateReviewsToLength) {
                truncateReviewsToLength = length;// 求样本中,句子的最大长度
            }
        }
        bufferedReader.close();
        for (Map.Entry<String, List<String>> entry : map.entrySet()) {
            for (String sentence : entry.getValue()) {
                if (random.nextInt() % 5 == 0) {// 每个类别抽取20%作为test集
                    testLabelList.add(entry.getKey());
                    testSentences.add(sentence);
                } else {
                    trainLabelList.add(entry.getKey());
                    trainSentences.add(sentence);
                }
            }

        }
        int batchSize = 64;
        Random rng = new Random(12345);
        Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel(basePath+"toutiao_cat_data_dataset.txt");
        System.out.println("Loading word vectors and creating DataSetIterators");
        DataSetIterator trainIter = getDataSetIterator(word2Vec, batchSize, truncateReviewsToLength, trainLabelList, trainSentences, rng);
        DataSetIterator testIter = getDataSetIterator(word2Vec, batchSize, truncateReviewsToLength, testLabelList, testSentences, rng);
        HashMap<String,DataSetIterator> data = new HashMap<>();
        data.put("trainIter",trainIter);
        data.put("testIter",testIter);
        return data;

    }
    private static DataSetIterator getDataSetIterator(WordVectors wordVectors, int minibatchSize, int maxSentenceLength,
                                                      List<String> lableList, List<String> sentences, Random rng) {

        LabeledSentenceProvider sentenceProvider = new CollectionLabeledSentenceProvider(sentences, lableList, rng);

        return new CnnSentenceDataSetIterator.Builder().sentenceProvider(sentenceProvider).wordVectors(wordVectors)
                .minibatchSize(minibatchSize).maxSentenceLength(maxSentenceLength).useNormalizedWordVectors(false)
                .build();
    }

模型搭建

public static ComputationGraph model(int truncateReviewsToLength){
        int vectorSize = 100;
        int cnnLayerFeatureMaps = 50;
        PoolingType globalPoolingType = PoolingType.MAX;
        ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().weightInit(WeightInit.RELU)
                .activation(Activation.LEAKYRELU).updater(new Nesterovs(0.01, 0.9))
                .convolutionMode(ConvolutionMode.Same).l2(0.0001).graphBuilder().addInputs("input")
                .addLayer("cnn3",
                        new ConvolutionLayer.Builder().kernelSize(3, vectorSize).stride(1, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addLayer("cnn4",
                        new ConvolutionLayer.Builder().kernelSize(4, vectorSize).stride(1, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addLayer("cnn5",
                        new ConvolutionLayer.Builder().kernelSize(5, vectorSize).stride(1, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addLayer("cnn6",
                        new ConvolutionLayer.Builder().kernelSize(6, vectorSize).stride(1, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addLayer("cnn3-stride2",
                        new ConvolutionLayer.Builder().kernelSize(3, vectorSize).stride(2, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addLayer("cnn4-stride2",
                        new ConvolutionLayer.Builder().kernelSize(4, vectorSize).stride(2, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addLayer("cnn5-stride2",
                        new ConvolutionLayer.Builder().kernelSize(5, vectorSize).stride(2, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addLayer("cnn6-stride2",
                        new ConvolutionLayer.Builder().kernelSize(6, vectorSize).stride(2, vectorSize)
                                .nOut(cnnLayerFeatureMaps).build(),
                        "input")
                .addVertex("merge1", new MergeVertex(), "cnn3", "cnn4", "cnn5", "cnn6")
                .addLayer("globalPool1", new GlobalPoolingLayer.Builder().poolingType(globalPoolingType).build(),
                        "merge1")
                .addVertex("merge2", new MergeVertex(), "cnn3-stride2", "cnn4-stride2", "cnn5-stride2", "cnn6-stride2")
                .addLayer("globalPool2", new GlobalPoolingLayer.Builder().poolingType(globalPoolingType).build(),
                        "merge2")
                .addLayer("fc",
                        new DenseLayer.Builder().nOut(200).dropOut(0.5).activation(Activation.LEAKYRELU).build(),
                        "globalPool1", "globalPool2")
                .addLayer("out",
                        new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                                .activation(Activation.SOFTMAX).nOut(15).build(),
                        "fc")
                .setOutputs("out").setInputTypes(InputType.convolutional(truncateReviewsToLength, vectorSize, 1))
                .build();

        ComputationGraph net = new ComputationGraph(config);
        return net;
    }

训练模型

 private static void train(ComputationGraph model,DataSetIterator trainIter,DataSetIterator testIter) throws IOException {
        UIServer uiServer = UIServer.getInstance();
        StatsStorage statsStorage = new InMemoryStatsStorage();
        uiServer.attach(statsStorage);
        model.setListeners(new ScoreIterationListener(100), new StatsListener(statsStorage, 20),
                new EvaluativeListener(testIter, 1, InvocationType.EPOCH_END));
        model.fit(trainIter, 10);
        ModelSerializer.writeModel(model, new File(basePath + "mlp.mod"), true);
    }
public static void main(String arg[]) throws IOException {
        //data(basePath+"toutiao_cat_data.txt",basePath+"toutiao_data_type_word.txt");
        //dataSet(basePath+"toutiao_cat_data.txt",basePath+"toutiao_cat_data_dataset.txt");

        ComputationGraph model = model(100);
        model.init();
        HashMap<String,DataSetIterator> data = dataSet();
        train(model,data.get("trainIter"),data.get("testIter"));
    }

下一节课讲解图片目标检测

本人诚接各类商业AI模型训练工作,如果您是一家公司,想借助AI解决当前服务问题,可以联系我。微信号:CompanyAiHelper

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 199,830评论 5 468
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 83,992评论 2 376
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 146,875评论 0 331
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 53,837评论 1 271
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 62,734评论 5 360
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,091评论 1 277
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,550评论 3 390
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,217评论 0 254
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,368评论 1 294
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,298评论 2 317
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,350评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,027评论 3 315
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,623评论 3 303
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,706评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,940评论 1 255
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,349评论 2 346
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 41,936评论 2 341

推荐阅读更多精彩内容