第56章 字符卷积

卷积神经网络在图像处理领域获得了巨大的成功,其结合特征提取和目标训练为一体的模型,能够很好地利用已有的信息度和结果进行反馈训练。

对于文本识别的卷积神经网络来说,同样也是充分利用特征提取时提取的问本特征来计算文本特征权值大小。,归一化处理需要处理的数据。这样使得原来的文本信息抽象成一个向量化的样本集,之后将样本集和训练好的模板输入卷积神经网络进行处理。

本章将在上一章的基础上使用卷积神经网络实现文本分类问题,这里将采用基于字符和基于词嵌入的两种词卷积神经网络处理方法。实际上无论是基于字符还是基于词嵌入形式的处理方式都是可以互相转换的。本章只介绍基本的模型使用方法,更深入的应用请自行研究学习。

字符(非单词)文本的处理

本节介绍基于字符的CNN处理方法,基于单词的卷积处理将在下一章介绍。由于单词都是由字母组成的,因此可以简单地将单词拆分成字母的表示形式,


hello -> [‘h’, ‘e’, ‘l’, ‘l’, ‘o’]

这样可以看到,一个单词hello被人为拆成了’h’, ‘e’, ‘l’, ‘l’, ‘o’这5个字母。对于hello的处理有两种方法,

  • 独热编码。
  • 字符嵌入。

处理的结果,单词“hello”将被转成一个[5, n]的矩阵。本例将采用独热编码的方法处理。

使用卷积神经网络计算字符矩阵时,对于每个单词拆分后的数据,根据不同的长度对其进行卷积处理,提取出高层抽象概念。这样做的好处是不需要使用预训练好的词向量和语法句法结构信息。除此之外,还有一个好处是可以很容易推广到所有语言。使用CNN处理字符文本分类的原理如下图所示,


图1 使用CNN处理字符文本分类的原理
标题文本读取和转化

对于AG News数据集来说,每条新闻都有对应的分类,也有标题和正文。对于正文的抽取在前几章中已经介绍。这里直接对新闻标题进行处理,如下所示,

3 Wall St. Bears Claw Back Into the Black (Reuters)
3 Wall St. Bears Claw Back Into the Black (Reuters)
3 Carlyle Looks Toward Commercial Aerospace (Reuters)
3 Oil and Economy Cloud Stocks' Outlook (Reuters)
3 Iraq Halts Oil Exports from Main Southern Pipeline (Reuters)
3 Oil prices soar to all-time record, posing new menace to US economy (AFP)
3 Stocks End Up, But Near Year Lows (Reuters)
3 Money Funds Fell in Latest Week (AP)
3 Fed minutes show dissent over inflation (USATODAY.com)
3 Safety Net (Forbes.com)
3 Wall St. Bears Claw Back Into the Black

由于只对文本标题进行处理,因此在进行数据清洗时不用处理时不用处理停用词和进行词根还原。对于空格,由于是字符计算,因此不需要保留,直接删除即可。 修改原来代码如下,


def stop_words():
    
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context
    
    nltk.data.path.append("/tmp/")
    
    nltk.download("stopwords", download_dir = "/tmp/");
    
    stops = nltk.corpus.stopwords.words("English")
    
    print(stops)
    
    return stops

def purify(string: str, pattern: str = r"[^a-z]", replacement: str = " "):
    
    string = string.lower()
    
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace the consucutive spaces with single space
    string = re.sub(pattern = r" +",  repl = replacement, string = string)
    
    # Trim the string
    string = string.strip()
    string = string + "eos"
    
    return string

def purify_stops(string: str, pattern: str = r"[^a-z0-9]", replacement: str = " ", stops = stop_words()):
    
    string = string.lower()
    
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace the consucutive spaces with single space
    string = re.sub(pattern = r" +",  repl = replacement, string = string)
    
    # Trim the string
    string = string.strip()
    
    # Seperate the string with space, an array will be yielded
    strings = string.split(" ")
    
    strings = [word for word in strings if word not in stops]
    strings = [nltk.PorterStemmer().stem(word) for word in strings]
    
    strings.append("eos")
    strings = ["bos"] + strings
    
    return strings

def setup():
    
    with open("../../Shares/ag_news_csv/train.csv", "r") as handler:
        
        labels = []
        titles = []
        descriptions = []
        
        trains = csv.reader(handler)
        
        for line in trains:
            
            labels.append(jax.numpy.int32(line[0]))
            titles.append(purify(line[1]))
            descriptions.append(purify_stops(line[2]))
            
        return labels, titles, descriptions
文本独热编码处理

下面将生成的字符串进行独热编码处理,处理方式很简单,首先建立一个由26个字母组成的字符表,


def one_hot(strings):
    
    alphabet = "abcdefghijklmnopqrstuvwxyz"

将不同的字符获取字符表对应位置进行提取,根据提取的位置将对应的字符位置设置成1,其它为0。例如字符“c”在字符表中排行第3,那么获取的字符矩阵为,


[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

其它字符类似,代码如下,


alphabet = "abcdefghijklmnopqrstuvwxyz"

def one_hot(characters):
    
    array = numpy.array(characters)
    length = len(alphabet) + 1
    # jax.numpy.eye(N, M = None, K = 0, dtype) to create a 2-dimension array th>
    # the elements in diagonal will be filled out with 1s, others are 0s.
    eyes = numpy.eye(length)[array]
    
    return eyes

下一步就是将字符串按字母表中的顺序转换成数字序列,代码如下,


def indexes_of(characters):
    
    indexes = []
    
    for character in characters:
        
        index = alphabet.index(character)
        
        indexes.append(index)
        
    return indexes

def train():
    
    string = "hello"
    indexes = indexes_of(string)
    
    print("string =", string, ", indexes =", indexes)

if __name__ == "__main__":
    
    train()

这样生成结果如下,


string = hello , indexes = [7, 4, 11, 11, 14]

将代码整合到一起,如下,


import numpy

def one_hot(characters, alphabet = None):
    
    alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet == None else alphabet)
    
    array = numpy.array(characters)
    length = len(alphabet)
    # jax.numpy.eye(N, M = None, K = 0, dtype) to create a 2-dimension array that
    # the elements in diagonal will be filled out with 1s, others are 0s.
    eyes = numpy.eye(length)[array]
    
    return eyes

def indexes_of(characters, alphabet = None):
    
    alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet == None else alphabet)
    
    indexes = []
    
    for character in characters:
        
        index = alphabet.index(character)
        
        indexes.append(index)
        
    return indexes

def indexes_matrix(string):
    
    indexes = indexes_of(string)
    matrix = one_hot(indexes)
    
    return matrix

def train():
    
    #labels, titles, descriptions = AgNewsCsvReader.setup()
    #print(labels[: 5], titles[: 5], titles[: 5])
    
    string = "hello"
    indexes = indexes_matrix(string)
    
    print("string =", string, ", indexes =", indexes)

if __name__ == "__main__":
    
    train()

运行结果打印输出如下,


string = hello , indexes = [[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]]

可以看到,单词“hello”被转换成一个[5, 26]大小的矩阵,供下一步处理。有了上面定义的方法,下一步就是对新闻标题进行独热编码处理。代码如下,


import numpy
import sys
sys.path.append("../52/")
import AgNewsCsvReader

def one_hot(characters, alphabet = None):
    
    alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet == None else alphabet)
    
    array = numpy.array(characters)
    length = len(alphabet)
    # jax.numpy.eye(N, M = None, K = 0, dtype) to create a 2-dimension array that
    # the elements in diagonal will be filled out with 1s, others are 0s.
    eyes = numpy.eye(length)[array]
    
    return eyes

def indexes_of(characters, alphabet = None):
    
    alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet == None else alphabet)
    
    indexes = []
    
    for character in characters:
        
        index = alphabet.index(character)
        
        indexes.append(index)
        
    return indexes

def indexes_matrix(string):
    
    indexes = indexes_of(string)
    matrix = one_hot(indexes)
    
    return matrix

def train():
    
    labels, titles, descriptions = AgNewsCsvReader.setup()
    #print(labels[: 5], titles[: 5], titles[: 5])
    
    for title in titles[: 10]:
    
        indexes = indexes_matrix(title)
        
        print("string =", title, ", indexes.shape =", indexes.shape)

if __name__ == "__main__":
    
    train()

运行结果打印输出如下,


string = wallstbearsclawbackintotheblackreuterseos , indexes.shape = (41, 27)
string = carlylelookstowardcommercialaerospacereuterseos , indexes.shape = (47, 27)
string = oilandeconomycloudstocksoutlookreuterseos , indexes.shape = (41, 27)
string = iraqhaltsoilexportsfrommainsouthernpipelinereuterseos , indexes.shape = (53, 27)
string = oilpricessoartoalltimerecordposingnewmenacetouseconomyafpeos , indexes.shape = (60, 27)
string = stocksendupbutnearyearlowsreuterseos , indexes.shape = (36, 27)
string = moneyfundsfellinlatestweekapeos , indexes.shape = (31, 27)
string = fedminutesshowdissentoverinflationusatodaycomeos , indexes.shape = (48, 27)
string = safetynetforbescomeos , indexes.shape = (21, 27)
string = wallstbearsclawbackintotheblackeos , indexes.shape = (34, 27)

不过,这里出现了一个新问题,对云不同长度的单词,矩阵的行长度不同。虽然卷积神经网络可以处理不同长度的字符串,但是在本例中还是希望以相同大小的矩阵作为输入进行计算。

生成文本矩阵时矩阵补全

对于不同长度的矩阵处理,简单的思路就是将其进行规范化处理:长的截断,短的补长。代码如下,


def align_string_matrix(string, maximum_length = 64, alphabet = "abcdefghijklmnopqrstuvwxyz "):
    
    length = len(string)
    
    if length > maximum_length:
        
        string = string[: maximum_length]
        matrix = indexes_matrix(string)
        
        return matrix
    
    else:
        
        matrix = indexes_matrix(string)
        length = maximum_length - length
        matrix_padded = numpy.zeros([length, len(alphabet)])
        
        matrix = numpy.concatenate([matrix, matrix_padded], axis = 0)
        
        return matrix

def train():
    
    '''

    string = "hello"
    indexes = indexes_matrix(string)
    
    print("string =", string, ", indexes =", indexes)
        
    '''
    
    labels, titles, descriptions = AgNewsCsvReader.setup()
    #print(labels[: 5], titles[: 5], titles[: 5])
    
    for title in titles[: 10]:
    
        indexes = align_string_matrix(title)
        
        print("string =", title, ", indexes.shape =", indexes.shape)
        
if __name__ == "__main__":
    
    train()

代码中,首先对不同长度的字符串进行处理,

  • 大于maximum_length(默认64,可根据需要自行设置该值)的字符串,截取前部分进行矩阵转换。
  • 长度小于maximum_length的,先生成由0构成的补全矩阵,再与原矩阵进行串接(numpy.concatenate)。

运行结果打印输出如下,


string = wall st bears claw back into the black reuterseos , indexes.shape = (64, 27)
string = carlyle looks toward commercial aerospace reuterseos , indexes.shape = (64, 27)
string = oil and economy cloud stocks outlook reuterseos , indexes.shape = (64, 27)
string = iraq halts oil exports from main southern pipeline reuterseos , indexes.shape = (64, 27)
string = oil prices soar to all time record posing new menace to us economy afpeos , indexes.shape = (64, 27)
string = stocks end up but near year lows reuterseos , indexes.shape = (64, 27)
string = money funds fell in latest week apeos , indexes.shape = (64, 27)
string = fed minutes show dissent over inflation usatoday comeos , indexes.shape = (64, 27)
string = safety net forbes comeos , indexes.shape = (64, 27)
string = wall st bears claw back into the blackeos , indexes.shape = (64, 27)

构建分类标签独热编码矩阵

对于分类的表示,同样可以使用独热编码进行处理。代码如下,


def one_hot_numbers(numbers):
    
    array = numpy.array(numbers)
    maximum = numpy.max(array) + 1
    
    eyes = numpy.eye(maximum)[array]
    
    return eyes

def train():
    
    '''

    string = "hello"
    indexes = indexes_matrix(string)
    
    print("string =", string, ", indexes =", indexes)
        
    '''
    
    labels, titles, descriptions = AgNewsCsvReader.setup()
    #print(labels[: 5], titles[: 5], titles[: 5])
    
    for title in titles[: 10]:
    
        indexes = align_string_matrix(title)
        
        print("string =", title, ", indexes.shape =", indexes.shape)
        
    one_hoted_labels = one_hot_numbers(labels)
    
    print("one_hoted_labels.shape = ", one_hoted_labels.shape)
        
if __name__ == "__main__":
    
    train()

截止到目前,全部代码如下,

../52/AgNewsCsvReader.py


import csv
import re
import jax
import ssl
import nltk
    
def stop_words():
    
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context
    
    nltk.data.path.append("/tmp/")
    
    nltk.download("stopwords", download_dir = "/tmp/");
    
    stops = nltk.corpus.stopwords.words("English")
    
    print(stops)
    
    return stops

def purify(string: str, pattern: str = r"[^a-z]", replacement: str = " "):
    
    string = string.lower()
    
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace the consucutive spaces with single space
    string = re.sub(pattern = r" +",  repl = replacement, string = string)
    # string = re.sub(pattern = " ", repl = "", string = string)
    
    # Trim the string
    string = string.strip()
    string = string + "eos"
    
    return string

def purify_stops(string: str, pattern: str = r"[^a-z0-9]", replacement: str = " ", stops = stop_words()):
    
    string = string.lower()
    
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace the consucutive spaces with single space
    string = re.sub(pattern = r" +",  repl = replacement, string = string)
    
    # Trim the string
    string = string.strip()
    
    # Seperate the string with space, an array will be yielded
    strings = string.split(" ")
    
    strings = [word for word in strings if word not in stops]
    strings = [nltk.PorterStemmer().stem(word) for word in strings]
    
    strings.append("eos")
    strings = ["bos"] + strings
    
    return strings

def setup():
    
    with open("../../Shares/ag_news_csv/train.csv", "r") as handler:
        
        labels = []
        titles = []
        descriptions = []
        
        trains = csv.reader(handler)
        trains = list(trains)
        
        for i in range(len(trains)):
            
            line = trains[I]
            
            labels.append(jax.numpy.int32(line[0]))
            titles.append(purify(line[1]))
            descriptions.append(purify_stops(line[2]))
            
        return labels, titles, descriptions

CharactersConvolutionalNeuralNetwork.py

import numpy
import sys
sys.path.append("../52/")
import AgNewsCsvReader

def one_hot(characters, alphabet = None):
    
    alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet == None else alphabet)
    
    array = numpy.array(characters)
    length = len(alphabet)
    # jax.numpy.eye(N, M = None, K = 0, dtype) to create a 2-dimension array that
    # the elements in diagonal will be filled out with 1s, others are 0s.
    eyes = numpy.eye(length)[array]
    
    return eyes

def one_hot_numbers(numbers):
    
    array = numpy.array(numbers)
    maximum = numpy.max(array) + 1
    
    eyes = numpy.eye(maximum)[array]
    
    return eyes

def indexes_of(characters, alphabet = None):
    
    alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet == None else alphabet)
    
    indexes = []
    
    for character in characters:
        
        index = alphabet.index(character)
        
        indexes.append(index)
        
    return indexes

def indexes_matrix(string, alphabet = "abcdefghijklmnopqrstuvwxyz "):
    
    indexes = indexes_of(string, alphabet)
    matrix = one_hot(indexes, alphabet)
    
    return matrix
def align_string_matrix(string, maximum_length = 64, alphabet = "abcdefghijklmnopqrstuvwxyz "):
    
    length = len(string)
    
    if length > maximum_length:
        
        string = string[: maximum_length]
        matrix = indexes_matrix(string)
        
        return matrix
    
    else:
        
        matrix = indexes_matrix(string)
        length = maximum_length - length
        matrix_padded = numpy.zeros([length, len(alphabet)])
        
        matrix = numpy.concatenate([matrix, matrix_padded], axis = 0)
        
        return matrix

def train():
    
    '''

    string = "hello"
    indexes = indexes_matrix(string)
    
    print("string =", string, ", indexes =", indexes)
        
    '''
    
    labels, titles, descriptions = AgNewsCsvReader.setup()
    #print(labels[: 5], titles[: 5], titles[: 5])
    trains = []
    
    for title in titles[: 10]:
    
        matrix = align_string_matrix(title)
        
        trains.append(matrix)
        
    trains = numpy.expand_dims(trains, axis = -1)
    labels = one_hot_numbers(labels)
    
    print("trains.shape =", trains.shape, ", labels.shape =", labels.shape)
        
if __name__ == "__main__":
    
    train()

代码中,首先通过csv库获取全文本数据,之后逐行将文本和标签读入,分别将其转化成独热编码矩阵后,再使用NumPy库将其对应的列表转换成NumPy格式。运行结果打印输出如下,


trains.shape = (120000, 64, 27, 1) , labels.shape = (120000, 5)

这里分别生成了训练集和标签数据的独热编码矩阵列表,

  • 训练集的维度为[120000, 64, 27, 1],第一个数字代表样本总数,第二个和第三个数字为生成的矩阵维度,最后一个1代表这里只使用1个通道。
  • 标签数据为[120000, 5],是一个二维矩阵,120000是样本的总数,5是类别。注意,one-hot是从0开始的,而标签的分类是从1开始的,因此会自动添加一个0的标签。

至此,文本数据处理结束。

一维卷积神经网络conv1d模型实现文本分类

在完成了文本的处理后,下面进入基于卷积神经网络的分类模型设计,如本章开始时提到了卷积处理字符文本分类的架构图所示,模型的设计有多种多样,根据类似的模型设计了一个由5层神经网络构成的文本分类模型,

层级 名称
1 Conv 3 x 3, 1 x 1
2 Conv 5 x 5, 1 x 1
3 Conv 3 x 3, 1 x 1
4 Fully Connected 256
5 Fully Connected 5

前3层是基于一维的卷积神经网络,后2层适用于分类任务的全连接层。代码如下,


def cnn(number_classes):
    
    return jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.Conv(1, (3, 3)),
        jax.example_libraries.stax.Relu,
        
        jax.example_libraries.stax.Conv(1, (5, 5)),
        jax.example_libraries.stax.Relu,
        
        jax.example_libraries.stax.Flatten,
        
        jax.example_libraries.stax.Dense(32),
        jax.example_libraries.stax.Relu,
        
        jax.example_libraries.stax.Dense(number_classes),
        
        jax.example_libraries.stax.LogSoftmax
        
        )

完整训练代码如下所示,

../52/AgNewsCsvReader.py


import csv
import re
import jax
import ssl
import nltk
    
def stop_words():
    
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context
    
    nltk.data.path.append("/tmp/")
    
    nltk.download("stopwords", download_dir = "/tmp/");
    
    stops = nltk.corpus.stopwords.words("English")
    
    print(stops)
    
    return stops

def purify(string: str, pattern: str = r"[^a-z]", replacement: str = " "):
    
    string = string.lower()
    
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace the consucutive spaces with single space
    string = re.sub(pattern = r" +",  repl = replacement, string = string)
    # string = re.sub(pattern = " ", repl = "", string = string)
    
    # Trim the string
    string = string.strip()
    string = string + " eos"
    
    return string

def purify_stops(string: str, pattern: str = r"[^a-z0-9]", replacement: str = " ", stops = stop_words()):
    
    string = string.lower()
    
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace the consucutive spaces with single space
    string = re.sub(pattern = r" +",  repl = replacement, string = string)
    
    # Trim the string
    string = string.strip()
    
    # Seperate the string with space, an array will be yielded
    strings = string.split(" ")
    
    strings = [word for word in strings if word not in stops]
    strings = [nltk.PorterStemmer().stem(word) for word in strings]
    
    strings.append("eos")
    strings = ["bos"] + strings
    
    return strings

def setup():
    
    with open("../../Shares/ag_news_csv/train.csv", "r") as handler:
        
        train_labels = []
        train_titles = []
        train_descriptions = []
        
        trains = csv.reader(handler)
        trains = list(trains)
        
        for i in range(len(trains)):
            
            line = trains[I]
            
            train_labels.append(jax.numpy.int32(line[0]))
            train_titles.append(purify(line[1]))
            train_descriptions.append(purify_stops(line[2]))
            
    with open("../../Shares/ag_news_csv/test.csv", "r") as handler:
        
        test_labels = []
        test_titles = []
        test_descriptions = []
        
        tests = csv.reader(handler)
        tests = list(tests)
        
        for i in range(len(tests)):
            
            line = tests[I]
            
            test_labels.append(jax.numpy.int32(line[0]))
            test_titles.append(purify(line[1]))
            test_descriptions.append(purify_stops(line[2]))
            
        return (train_labels, train_titles, train_descriptions), (test_labels, test_titles, test_descriptions)

    
def main():
    
    (train_labels, train_titles, train_descriptions), (test_labels, test_titles, test_descriptions) = setup()
    
    print((train_labels.shape, train_titles.shape, train_descriptions.shape), (test_labels.shape, test_titles.shape, test_descriptions.shape))
        
if __name__ == "__main__":
    
    main()

CharactersConvolutionalNeuralNetwork.py

import numpy
import jax
import jax.example_libraries.stax
import jax.example_libraries.optimizers
import sys

sys.path.append("../52/")

import AgNewsCsvReader


def one_hot(characters, alphabet):
        
    array = numpy.array(characters)
    length = len(alphabet)
    # jax.numpy.eye(N, M = None, K = 0, dtype) to create a 2-dimension array that
    # the elements in diagonal will be filled out with 1s, others are 0s.
    eyes = numpy.eye(length)[array]
    
    return eyes

def one_hot_numbers(numbers):
    
    array = numpy.array(numbers)
    maximum = numpy.max(array) + 1
    
    eyes = numpy.eye(maximum)[array]
    
    return eyes

def indexes_of(characters, alphabet):
        
    indexes = []
    
    for character in characters:
        
        index = alphabet.index(character)
        
        indexes.append(index)
        
    return indexes

def indexes_matrix(string, alphabet):
    
    indexes = indexes_of(string, alphabet)
    matrix = one_hot(indexes, alphabet)
    
    return matrix

def align_string_matrix(string, maximum_length = 64, alphabet = "abcdefghijklmnopqrstuvwxyz "):
    
    length = len(string)
    
    if length > maximum_length:
        
        string = string[: maximum_length]
        matrix = indexes_matrix(string, alphabet)
        
        return matrix
    
    else:
        
        matrix = indexes_matrix(string, alphabet)
        length = maximum_length - length
        matrix_padded = numpy.zeros([length, len(alphabet)])
        
        matrix = numpy.concatenate([matrix, matrix_padded], axis = 0)
        
        return matrix
    
def cnn(number_classes):
    
    return jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.Conv(1, (3, 3)),
        jax.example_libraries.stax.Relu,
        
        jax.example_libraries.stax.Conv(1, (5, 5)),
        jax.example_libraries.stax.Relu,
        
        jax.example_libraries.stax.Flatten,
        
        jax.example_libraries.stax.Dense(32),
        jax.example_libraries.stax.Relu,
        
        jax.example_libraries.stax.Dense(number_classes),
        
        jax.example_libraries.stax.LogSoftmax
        
        )

def setup():
    
    prng = jax.random.PRNGKey(15)
    
    (train_labels, train_titles, train_descriptions), (test_labels, test_titles, test_descriptions) = AgNewsCsvReader.setup()
    
    train_texts = []
    
    for title in train_titles:
    
        matrix = align_string_matrix(title)
        
        train_texts.append(matrix)
        
    train_texts = numpy.expand_dims(train_texts, axis = -1)
    train_labels = one_hot_numbers(train_labels)
    
    test_texts = []
    
    for title in test_titles:
    
        matrix = align_string_matrix(title)
        
        test_texts.append(matrix)
        
    test_texts = numpy.expand_dims(test_texts, axis = -1)
    test_labels = one_hot_numbers(test_labels)
    
    number_classes = 5
    input_shape = [-1, 64, 28, 1]
    batch_size = 100
    epochs = 5
    
    init_random_params, predict = cnn(number_classes)
    
    optimizer_init_function, optimizer_update_function, get_params_function = jax.example_libraries.optimizers.adam(step_size = 2.17e-4)
    _, init_params = init_random_params(prng, input_shape = input_shape)
    optimizer_state = optimizer_init_function(init_params)
    
    return (prng, number_classes, batch_size, epochs, init_params, optimizer_state), (init_random_params, optimizer_init_function, predict, optimizer_update_function, get_params_function), ((train_texts, train_labels), (test_texts, tes>
    
def verify_accuracy(params, batch, predict_function):
    
    inputs, targets = batch
    predictions = predict_function(params, inputs)
    class_ = jax.numpy.argmax(predictions, axis = 1)
    targets = jax.numpy.argmax(targets, axis = 1)
    
    return jax.numpy.sum(predictions == targets)

def loss_function(params, batch, predict_function):
    
    inputs, targets = batch
    
    predictions = predict_function(params, inputs)
    
    losses = -targets * predictions
    losses = jax.numpy.sum(losses, axis = 1)
    losses = jax.numpy.mean(losses)
    
    return losses

def update_function(i, optimizer_state, batch, get_params_function, optimizer_update_function, predict_function):
    
    params = get_params_function(optimizer_state)
    
    loss_function_grad = jax.grad(loss_function)
    gradients = loss_function_grad(params, batch, predict_function)
    
    return optimizer_update_function(i, gradients, optimizer_state)
    
def train():
    
    (prng, number_classes, batch_size, epochs, init_params, optimizer_state), (init_random_params, optimizer_init_function, predict, optimizer_update_function, get_params_function), ((train_texts, train_labels), (test_texts, test_label>
    
    print("train_texts.shape =", train_texts.shape, ", train_labels.shape =", train_labels.shape, ", test_texts.shape =", test_texts.shape, ", test_labels.shape =", test_labels.shape)
    
    train_batch_number = int(len(train_texts) / batch_size)
    test_batch_number = int(len(test_texts) / batch_size)
    
    for i in range(epochs):
        
        print(f"Epoch {i} started")
    
        for j in range(train_batch_number):
            
            start = j * batch_size
            end = (j + 1) * batch_size
            
            batch = (train_texts[start: end], train_labels[start: end])
            
            optimizer_state = update_function(i, optimizer_state, batch, get_params_function, optimizer_update_function, predict)
            
            if (j + 1) % 10 == 0:
                
                params = get_params_function(optimizer_state)
                losses = loss_function(params, batch)
                
                print("Losses now is =", losses)
                
        params = get_params_function(optimizer_state)
        
        print(f"Epoch {i} compeleted")
        
        accuracies = []
        predictions = 0.0
        
        for j in range(test_batch_number):
            
            start = j * batch_size
            end = (j + 1) * batch_size
            
            batch = (test_texts[start: end], test_labels[start: end])
            
            predictions += verify_accuracy(params, batch)
            
        accuracies.append(predictions / float(len(train_texts)))
        
        print(f"Training accuracies =", accuracies)
                 
if __name__ == "__main__":
    
    train()

首先获取训练集和测试集,接下来定义预损失函数、优化器,与ResNet类似,不再赘述。

结论

本章基于AG News新闻标题和分类标签,使用一层卷积和全连接层建构了一个文本分类模型。注意,这个示例知识为了说明问题,效果并不一定好。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,098评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,213评论 2 380
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 149,960评论 0 336
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,519评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,512评论 5 364
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,533评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,914评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,574评论 0 256
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,804评论 1 296
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,563评论 2 319
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,644评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,350评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,933评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,908评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,146评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,847评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,361评论 2 342

推荐阅读更多精彩内容