TextCNN--文本多分类实践

数据集和代码都在微信公众号里面:一路向AI,回复文本分类即可获取,后续会不定期更新文本数据和其它文本分类模型~

在上一篇文章中,描述了TextCNN用于文本分类内在逻辑。今天应用这个模型来实践一个文本多分类Demo。

一、数据集

先介绍下数据集,数据集是从网上找到,具体来源找不到了。数据集有女性、体育、文学、校园4个文件夹组成,每个文件下有几百个txt文件,每个txt文件包含一行文本。

数据.png

首先读取每个文件夹的所有数据作为我们的训练数据,而数据标签则为每个txt文件所对应的文件夹名称,即:女性、体育、文学、校园4个类别,这边便于演示Demo,使用的数据量较小:其中体育下299条数据、女性下992条数据,文学下797条数据、校园下265条数据,总共2353条数据。这显示出文本数据类别不均衡,后续会对其进行一定的处理。

数据获取完之后,对数据进行按以下步骤进行处理:
1. 数据分词:使用jieba对文本进行分词。
2. 文本过滤:首先过滤掉非中文字符,例如19 30 或者www url等。其次使用停用词过滤掉一些无意义的中文字或词。
3. 数据填充:由于每行文本序列不一致,为了便于建模,需要把所有序列填充到相同的长度,这里初略选取序列最大长度为25,对长度小于25的序列后端补齐'0',对长度大于25的序列进行截断处理。

 def text_process(self, stopwords):
        for i in range(len(self.text)):
            # 使用正则表达式过滤非中文字符或数字
            pattern = re.compile(r'[^\u4e00-\u9fa5]')
            self.text[i] = re.sub(pattern, '', self.text[i])
            # jieba 分词
            cut_result = list(jieba.cut(self.text[i]))
            # 过滤停用词
            for j in range(len(cut_result)):
                if cut_result[j] in stopwords:
                    cut_result[j] = ''
                else:
                    # 把所有单词存到集合里
                    if cut_result[j] not in self.words:
                        self.words.append(cut_result[j])

            # 数据填充
            tmp = self.data_padding([x.strip() for x in list(cut_result) if x != '' and x != ' '])
            self.text[i] = ' '.join(tmp)

    def data_padding(self, sequence):
        # 序列小于最大长度填充'0'
        if len(sequence) <= self.max_len:
            sequence.extend(['0'] * (self.max_len - len(sequence)))
        else:
            # 序列大于最大长度进行截断
            sequence = sequence[:self.max_len]
        return sequence

4. 数据编码:对文本编码:可以在上述过程中,统计出分词后所有单词的个数,并把其映射为单词所对引的索引,然后把文本中的单词转换为其对应的索引;对于标签编码,可以把标签映射为{'体育':0,'女性':1,'文学':2,'校园':3}处理,也可以直接进行onehot编码: {'体育' : [1 0 0 0], '女性' :[0 1 0 0], '文学':[0 0 1 0], '校园' :[0 0 0 1]}。

def data_encoding(self, texts, labels):
        with open('../data/word2index.txt') as fp:
            word2index = json.load(fp)

        # 文本编码 -- 找到每个词对应的索引
        data = []
        for text in texts:
            text = text.split(' ')
            tmp = []
            for i in range(len(text)):
                text[i] = word2index.get(text[i], 0)
                tmp.append(text[i])
            data.extend(tmp)

        # 标签编码 
        label2ind = {}
        unique_label = list(set(labels))
        for index, label in enumerate(unique_label):
            label2ind[label] = index
        for i in range(len(labels)):
            labels[i] = label2ind[labels[i]]
        
        # one hot 编码
        # labels = to_categorical(labels, len(set(labels)), dtype=int)
        return np.array(data).reshape(-1, self.max_len), np.array(labels), word2index

5. 划分数据集:把文本转换成向量后,把数据集充分打乱之后,可以分为训练集和测试集。其中参数stratify = label 可以使划分的训练集和测试集各类比例与原始数据集分布一致,等同于各类等比例抽样。

 def split_data(self, data, label):
        # shuffle data
        data, label = shuffle(data, label, random_state=2020)
        X_train, X_text, y_train, y_test = train_test_split(data, label, test_size=0.1, random_state=2020,
                                                            stratify=label)
        return X_train, X_text, y_train, y_test

二、TextCNN模型

TextCNN的核心思想是抓取文本的局部特征:通过不同的卷积核尺寸(确切的说是卷积核高度)来提取文本的N-gram信息,然后通过最大池化操作来突出各个卷积操作提取的最关键信息(颇有一番Attention的味道),拼接后通过全连接层对特征进行组合,最后通过多分类损失函数来训练模型。

textCNN.jpg

在本模型中TextCNN代码如下:

def textcnn(wordsize, label, embedding_matrix=None):
    input = Input(shape=(data_process.max_len,))
    if embedding_matrix is None:
        embedding = Embedding(input_dim=wordsize,
                              output_dim=32,
                              input_length=data_process.max_len,
                              trainable=True)(input)
    else:  # 使用预训练矩阵初始化Embedding
        embedding = Embedding(input_dim=wordsize,
                              output_dim=32,
                              weights=[embedding_matrix],
                              input_length=data_process.max_len,
                              trainable=False)(input)

    convs = []
    for kernel_size in [2, 3, 4]:
        conv = Conv1D(64, kernel_size, activation='relu')(embedding)
        pool = MaxPooling1D(pool_size=data_process.max_len - kernel_size + 1)(conv)
        convs.append(pool)
        print(pool)
    concat = Concatenate()(convs)
    flattern = Flatten()(concat)
    dropout = Dropout(0.3)(flattern)
    output = Dense(len(set(label)), activation='softmax')(dropout)
    model = Model(input, output)
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    print(model.summary())
    return model

模型结构如下:


model.png

开头提到训练数据不均衡,对待数据不均衡通常采用的方式为过采样、降采样、数据加权等。前两种方式比较简单,不作过多介绍,这里介绍下数据加权,类别数量分布为{0: 893, 1: 238, 2: 269, 3: 717},通过样本总数除以每个类别总数来得到每个类别的样本权重,经过处理后得到:{0: 2.37, 1: 8.89, 2: 7.87, 3: 2.95},可以看到样本数目越多,样本权重就越小。

   def class_weight(self, y_train):
        count_res = dict(Counter(y_train))
        print(count_res)
        for key in count_res.keys():
            count_res[key] = round(len(y_train) / count_res[key], 2)
        return count_res

样本得到权重后,怎么使用呢?可以在模型训练的时候通过class_weight参数赋予给损失函数。

history = model.fit(X_train, y_train, validation_split=0.05, batch_size=32, epochs=20, class_weight=class_weight,
                        verbose=2)

三、评估结果

模型训练基本没有调参,在测试集上的准确率达到93%左右,其它一些评估指标结果如下:混淆矩阵结果行代表真实标签,列代表预测标签,可以看出把模型的第3类样本预测为第2类样本的数目最多为3个,可以挑选出这些Badcase分析下是什么原因造成的。


混淆矩阵结果: 
[[29  1  0  0]
 [ 1 75  2  2]
 [ 2  3 92  2]
 [ 0  1  1 25]]
 
 分类报告结果:
    precision    recall  f1-score   support
0       0.91      0.97      0.94        30
1       0.94      0.94      0.94        80
2       0.97      0.93      0.95        99
3       0.86      0.93      0.89        27

模型后续可改进的空间还有很多,比如说网格搜索+交叉验证,模型不均衡数据集的处理,预训练Embedding等等,后续有时间会逐渐完善。

由于时间比较仓促,文章写的有点乱,数据集和代码在公众号回复文本分类即可获取,后续会不断更新该系列文章,有兴趣的可以关注一波。

微信公众号.jpg
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,214评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,307评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,543评论 0 341
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,221评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,224评论 5 371
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,007评论 1 284
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,313评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,956评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,441评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,925评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,018评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,685评论 4 322
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,234评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,240评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,464评论 1 261
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,467评论 2 352
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,762评论 2 345

推荐阅读更多精彩内容