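Convolutional neural networks have achieved great success in image processing. They combine feature extraction and objective-driven training in a single model, so they can make good use of the available information and feed the results back into training.
For text classification, a convolutional neural network likewise uses the text features obtained during feature extraction to compute feature weights, and the data to be processed is normalized. This turns the original text into a vectorized sample set, which is then fed into the convolutional neural network for training.
Building on the previous chapter, this chapter uses convolutional neural networks for text classification with two approaches: character-based and word-embedding-based. In practice the character-based and word-embedding-based representations can be converted into each other. This chapter only covers basic model usage; more advanced applications are left for further study.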
Processing Character-Level (Non-Word) Text
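This section introduces the character-based CNN approach; word-based convolution is covered in the next chapter. Since words are made up of letters, a word can simply be split into its letters: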
hello -> ['h', 'e', 'l', 'l', 'o']
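Here the word hello has been split into the five letters 'h', 'e', 'l', 'l', 'o'. These letters can be encoded in two ways:
- One-hot encoding.
- Character embedding.
Either way, the word "hello" becomes a [5, n] matrix: one row per letter, where n is the length of each character vector. This example uses one-hot encoding.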
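To make the two options concrete, here is a minimal JAX sketch that encodes "hello" both ways. It assumes a 26-letter alphabet; encode_indexes is a hypothetical helper, and the 8-dimensional embedding table is random purely for illustration (in a real model it would be a trainable parameter). The last line shows a useful identity: an embedding lookup gives the same result as multiplying the one-hot matrix by the embedding table.

```python
import jax
import jax.numpy as jnp

ALPHABET = "abcdefghijklmnopqrstuvwxyz"

def encode_indexes(word):
    # Hypothetical helper: position of each character in ALPHABET, counting from 0
    return jnp.array([ALPHABET.index(c) for c in word])

indexes = encode_indexes("hello")  # [7, 4, 11, 11, 14]

# Option 1: one-hot encoding, n equals the alphabet size -> [5, 26]
one_hot_matrix = jax.nn.one_hot(indexes, len(ALPHABET))

# Option 2: character embedding, n is a chosen embedding size -> [5, 8]
# (random table for illustration only; normally learned during training)
embedding_table = jax.random.normal(jax.random.PRNGKey(0), (len(ALPHABET), 8))
embedded_matrix = embedding_table[indexes]

print(one_hot_matrix.shape, embedded_matrix.shape)  # (5, 26) (5, 8)
# Lookup == one-hot matrix times embedding table
print(bool(jnp.allclose(one_hot_matrix @ embedding_table, embedded_matrix, atol=1e-5)))
```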
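When a convolutional neural network processes the character matrix, it convolves the character rows of each input over windows of different lengths to extract higher-level abstract features. One advantage of this approach is that it needs neither pretrained word vectors nor grammatical or syntactic structure information; another is that it extends easily to any language. The principle of character-level text classification with a CNN is illustrated in the accompanying figure.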
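The core operation described in this paragraph, sliding windows of different lengths along the characters and keeping the strongest response, can be sketched in NumPy. This is a minimal illustration under assumptions: the window widths (2, 3, 4), the 4 filters per width, the random weights and the max-over-positions pooling follow the common character-CNN pattern and are not the model built later in this chapter; conv1d_valid and multi_width_features are hypothetical helpers.

```python
import numpy

def conv1d_valid(x, w):
    # x: [length, channels] character matrix; w: [width, channels, filters]
    # Slide a window of `width` characters along the text ("valid" positions only)
    width = w.shape[0]
    steps = x.shape[0] - width + 1
    return numpy.stack([numpy.einsum("wc,wcf->f", x[t:t + width], w) for t in range(steps)])

def multi_width_features(x, widths=(2, 3, 4), filters=4, seed=0):
    rng = numpy.random.default_rng(seed)
    features = []
    for width in widths:
        # Random weights for illustration only; a real model learns them
        w = rng.normal(size=(width, x.shape[1], filters))
        # ReLU activations, shape [length - width + 1, filters]
        activations = numpy.maximum(conv1d_valid(x, w), 0.0)
        # Max over positions -> one value per filter
        features.append(activations.max(axis=0))
    return numpy.concatenate(features)

alphabet = "abcdefghijklmnopqrstuvwxyz"
hello = numpy.eye(len(alphabet))[[alphabet.index(c) for c in "hello"]]  # [5, 26]
print(conv1d_valid(hello, numpy.ones((3, 26, 4))).shape)  # (3, 4)
print(multi_width_features(hello).shape)                  # (12,)
```

Whatever the input length, the pooled feature vector has a fixed size (number of widths times filters), which is one way a CNN can cope with texts of different lengths.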
Reading and Converting the Title Text
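In the AG News dataset, every news item has a class label, a title and a body text. Extracting the body text was covered in earlier chapters; here the news titles are processed directly, for example: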
Class | Title |
---|---|
3 | Wall St. Bears Claw Back Into the Black (Reuters) |
3 | Carlyle Looks Toward Commercial Aerospace (Reuters) |
3 | Oil and Economy Cloud Stocks' Outlook (Reuters) |
3 | Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) |
3 | Oil prices soar to all-time record, posing new menace to US economy (AFP) |
3 | Stocks End Up, But Near Year Lows (Reuters) |
3 | Money Funds Fell in Latest Week (AP) |
3 | Fed minutes show dissent over inflation (USATODAY.com) |
3 | Safety Net (Forbes.com) |
3 | Wall St. Bears Claw Back Into the Black |
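Since only the titles are processed, the cleaning step does not need to remove stop words or perform stemming. Spaces are not needed for character-level computation, so they are simply deleted. The original code is modified as follows: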
import csv
import re
import ssl
import nltk
def stop_words():
    # Work around SSL certificate errors when downloading NLTK data
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context
    nltk.data.path.append("/tmp/")
    nltk.download("stopwords", download_dir = "/tmp/")
    # NLTK's stop-word file ids are lowercase ("english")
    stops = nltk.corpus.stopwords.words("english")
    print(stops)
    return stops
def purify(string: str, pattern: str = r"[^a-z]", replacement: str = " "):
    string = string.lower()
    # Replace every non-letter with a space
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace consecutive spaces with a single space
    string = re.sub(pattern = r" +", repl = replacement, string = string)
    # Trim the string
    string = string.strip()
    # Character-level processing: spaces are not kept
    string = re.sub(pattern = " ", repl = "", string = string)
    # Append the end-of-sequence marker
    string = string + "eos"
    return string
def purify_stops(string: str, pattern: str = r"[^a-z0-9]", replacement: str = " ", stops = stop_words()):
    string = string.lower()
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace consecutive spaces with a single space
    string = re.sub(pattern = r" +", repl = replacement, string = string)
    # Trim the string
    string = string.strip()
    # Separate the string by spaces, yielding a list of words
    strings = string.split(" ")
    # Descriptions: drop stop words and stem the remaining words
    strings = [word for word in strings if word not in stops]
    strings = [nltk.PorterStemmer().stem(word) for word in strings]
    strings.append("eos")
    strings = ["bos"] + strings
    return strings
def setup():
    with open("../../Shares/ag_news_csv/train.csv", "r") as handler:
        labels = []
        titles = []
        descriptions = []
        trains = csv.reader(handler)
        # Each row: class index, title, description
        for line in trains:
            labels.append(int(line[0]))
            titles.append(purify(line[1]))
            descriptions.append(purify_stops(line[2]))
    return labels, titles, descriptions
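As a quick standalone check of the title-cleaning rules (lowercase, letters only, collapse and then delete spaces, append "eos"), the same steps can be applied to the first title in the table without NLTK or the CSV file. clean_title is a hypothetical helper written for illustration, and its keep_spaces flag is an assumption added to show the variant used later in this chapter, where spaces are kept and "eos" becomes a separate token.

```python
import re

def clean_title(title, keep_spaces=False):
    # Hypothetical helper mirroring purify(): lowercase, replace non-letters with spaces
    string = re.sub(r"[^a-z]", " ", title.lower())
    # Collapse runs of spaces and trim
    string = re.sub(r" +", " ", string).strip()
    if keep_spaces:
        # Later variant: keep spaces and add "eos" as a separate token
        return string + " eos"
    # Character-level variant: delete spaces, then append "eos"
    return string.replace(" ", "") + "eos"

title = "Wall St. Bears Claw Back Into the Black (Reuters)"
print(clean_title(title))                    # wallstbearsclawbackintotheblackreuterseos
print(len(clean_title(title)))               # 41
print(clean_title(title, keep_spaces=True))  # wall st bears claw back into the black reuters eos
```

The 41 characters here correspond to the 41 rows reported for this title further below.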
One-Hot Encoding the Text
Next, the cleaned strings are one-hot encoded. The approach is simple: first build an alphabet consisting of the 26 letters:
def one_hot(strings):
    alphabet = "abcdefghijklmnopqrstuvwxyz"
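Each character is looked up in the alphabet to find its position; the entry at that position is set to 1 and all other entries to 0. For example, "c" is the 3rd letter of the alphabet (index 2, counting from 0), so its vector is: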
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Other characters are handled in the same way. The code is as follows:
import numpy
alphabet = "abcdefghijklmnopqrstuvwxyz"
def one_hot(characters):
    array = numpy.array(characters)
    # One extra column beyond the 26 letters
    length = len(alphabet) + 1
    # numpy.eye(N) creates an N x N identity matrix (1s on the diagonal, 0s elsewhere);
    # selecting its rows by the character indexes gives one one-hot row per character
    eyes = numpy.eye(length)[array]
    return eyes
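JAX also provides a helper that performs the same position-to-vector mapping, jax.nn.one_hot(x, num_classes). A minimal sketch reproducing the vector for "c" shown above, assuming the plain 26-letter alphabet (no extra column):

```python
import jax
import jax.numpy as jnp

alphabet = "abcdefghijklmnopqrstuvwxyz"
# "c" sits at index 2 (counting from 0), i.e. 3rd place in the alphabet
vector = jax.nn.one_hot(jnp.array(alphabet.index("c")), len(alphabet))
print(vector.shape)                       # (26,)
print(int(jnp.argmax(vector)))            # 2
print(vector.astype(jnp.int32).tolist())  # [0, 0, 1, 0, ..., 0]
```

jax.nn.one_hot also accepts an array of indexes, returning one row per index, which is what the NumPy row-selection trick in one_hot does.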
The next step converts a string into a sequence of numbers according to each letter's position in the alphabet:
def indexes_of(characters):
    indexes = []
    for character in characters:
        index = alphabet.index(character)
        indexes.append(index)
    return indexes
def train():
    string = "hello"
    indexes = indexes_of(string)
    print("string =", string, ", indexes =", indexes)
if __name__ == "__main__":
    train()
This produces the following result:
string = hello , indexes = [7, 4, 11, 11, 14]
Putting the code together:
import numpy
def one_hot(characters, alphabet = None):
    alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet is None else alphabet)
    array = numpy.array(characters)
    # One extra column beyond the alphabet, as in the earlier snippet
    length = len(alphabet) + 1
    # numpy.eye(N) creates an N x N identity matrix (1s on the diagonal, 0s elsewhere);
    # selecting its rows by the character indexes gives one one-hot row per character
    eyes = numpy.eye(length)[array]
    return eyes
def indexes_of(characters, alphabet = None):
    alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet is None else alphabet)
    indexes = []
    for character in characters:
        index = alphabet.index(character)
        indexes.append(index)
    return indexes
def indexes_matrix(string):
    indexes = indexes_of(string)
    matrix = one_hot(indexes)
    return matrix
def train():
    #labels, titles, descriptions = AgNewsCsvReader.setup()
    #print(labels[: 5], titles[: 5], descriptions[: 5])
    string = "hello"
    indexes = indexes_matrix(string)
    print("string =", string, ", indexes =", indexes)
if __name__ == "__main__":
    train()
The printed output is:
string = hello , indexes = [[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0.]]
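As shown, the word "hello" is converted into a [5, 27] matrix (26 letter columns plus the one extra column added in one_hot), ready for the next step. With these functions defined, the next step is to one-hot encode the news titles: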
import numpy
import sys
sys.path.append("../52/")
import AgNewsCsvReader
def one_hot(characters, alphabet = None):
    alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet is None else alphabet)
    array = numpy.array(characters)
    # One extra column beyond the alphabet, as in the earlier snippet
    length = len(alphabet) + 1
    # numpy.eye(N) creates an N x N identity matrix (1s on the diagonal, 0s elsewhere);
    # selecting its rows by the character indexes gives one one-hot row per character
    eyes = numpy.eye(length)[array]
    return eyes
def indexes_of(characters, alphabet = None):
    alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet is None else alphabet)
    indexes = []
    for character in characters:
        index = alphabet.index(character)
        indexes.append(index)
    return indexes
def indexes_matrix(string):
    indexes = indexes_of(string)
    matrix = one_hot(indexes)
    return matrix
def train():
    labels, titles, descriptions = AgNewsCsvReader.setup()
    #print(labels[: 5], titles[: 5], descriptions[: 5])
    for title in titles[: 10]:
        indexes = indexes_matrix(title)
        print("string =", title, ", indexes.shape =", indexes.shape)
if __name__ == "__main__":
    train()
The printed output is:
string = wallstbearsclawbackintotheblackreuterseos , indexes.shape = (41, 27)
string = carlylelookstowardcommercialaerospacereuterseos , indexes.shape = (47, 27)
string = oilandeconomycloudstocksoutlookreuterseos , indexes.shape = (41, 27)
string = iraqhaltsoilexportsfrommainsouthernpipelinereuterseos , indexes.shape = (53, 27)
string = oilpricessoartoalltimerecordposingnewmenacetouseconomyafpeos , indexes.shape = (60, 27)
string = stocksendupbutnearyearlowsreuterseos , indexes.shape = (36, 27)
string = moneyfundsfellinlatestweekapeos , indexes.shape = (31, 27)
string = fedminutesshowdissentoverinflationusatodaycomeos , indexes.shape = (48, 27)
string = safetynetforbescomeos , indexes.shape = (21, 27)
string = wallstbearsclawbackintotheblackeos , indexes.shape = (34, 27)
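However, a new problem appears: titles of different lengths produce matrices with different numbers of rows. Although a convolutional neural network can handle strings of varying length, in this example we want every input to be a matrix of the same size.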
Padding the Text Matrix to a Fixed Size
The simple way to handle matrices of different lengths is to normalize them: truncate long ones and pad short ones. The code is as follows:
def align_string_matrix(string, maximum_length = 64, alphabet = "abcdefghijklmnopqrstuvwxyz "):
    # The alphabet includes the space character; indexes_matrix(string, alphabet) is the
    # alphabet-aware version from the complete listing later in this section
    length = len(string)
    if length > maximum_length:
        # Long string: keep only the first maximum_length characters
        string = string[: maximum_length]
        matrix = indexes_matrix(string, alphabet)
        return matrix
    else:
        matrix = indexes_matrix(string, alphabet)
        # Short string: append all-zero rows up to maximum_length
        length = maximum_length - length
        matrix_padded = numpy.zeros([length, len(alphabet)])
        matrix = numpy.concatenate([matrix, matrix_padded], axis = 0)
        return matrix
def train():
    '''
    string = "hello"
    indexes = indexes_matrix(string)
    print("string =", string, ", indexes =", indexes)
    '''
    labels, titles, descriptions = AgNewsCsvReader.setup()
    #print(labels[: 5], titles[: 5], descriptions[: 5])
    for title in titles[: 10]:
        indexes = align_string_matrix(title)
        print("string =", title, ", indexes.shape =", indexes.shape)
if __name__ == "__main__":
    train()
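The code first distinguishes strings by length:
- Strings longer than maximum_length (64 by default; set it as needed) are truncated, and only the leading part is converted into a matrix.
- For strings shorter than maximum_length, an all-zero padding matrix is created first and then concatenated below the original matrix (numpy.concatenate).
Note that from this point on the titles keep their spaces, which is why the alphabet here includes the space character and each row has 27 columns.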
The printed output is:
string = wall st bears claw back into the black reuterseos , indexes.shape = (64, 27)
string = carlyle looks toward commercial aerospace reuterseos , indexes.shape = (64, 27)
string = oil and economy cloud stocks outlook reuterseos , indexes.shape = (64, 27)
string = iraq halts oil exports from main southern pipeline reuterseos , indexes.shape = (64, 27)
string = oil prices soar to all time record posing new menace to us economy afpeos , indexes.shape = (64, 27)
string = stocks end up but near year lows reuterseos , indexes.shape = (64, 27)
string = money funds fell in latest week apeos , indexes.shape = (64, 27)
string = fed minutes show dissent over inflation usatoday comeos , indexes.shape = (64, 27)
string = safety net forbes comeos , indexes.shape = (64, 27)
string = wall st bears claw back into the blackeos , indexes.shape = (64, 27)
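The truncate-or-pad logic can also be expressed for a whole batch without building separate zero matrices. This sketch is an alternative technique, not the chapter's code: every padding position gets a sentinel index equal to len(alphabet), and the identity matrix receives one extra all-zero row so that the sentinel maps to a zero vector; a dictionary lookup replaces the repeated alphabet.index scans. encode_batch is a hypothetical name, and float32 is an assumption chosen because 120000 x 64 x 27 values take roughly 1.66 GB in NumPy's default float64 but about 0.83 GB in float32.

```python
import numpy

def encode_batch(strings, maximum_length=64, alphabet="abcdefghijklmnopqrstuvwxyz "):
    # Hypothetical batch encoder: returns [len(strings), maximum_length, len(alphabet)]
    lookup = {character: index for index, character in enumerate(alphabet)}
    sentinel = len(alphabet)  # index used for padding positions
    indexes = numpy.full((len(strings), maximum_length), sentinel, dtype=numpy.int64)
    for row, string in enumerate(strings):
        truncated = string[:maximum_length]  # long strings are cut
        indexes[row, :len(truncated)] = [lookup[c] for c in truncated]
    # Identity rows for real characters plus one all-zero row for the sentinel
    table = numpy.vstack([numpy.eye(len(alphabet), dtype=numpy.float32),
                          numpy.zeros((1, len(alphabet)), dtype=numpy.float32)])
    return table[indexes]

batch = encode_batch(["hello world eos", "x" * 100])
print(batch.shape)     # (2, 64, 27)
print(batch[0].sum())  # 15.0: one 1 per real character, padding rows are zero
print(batch[1].sum())  # 64.0: truncated to 64 characters
```

Each row matches what align_string_matrix produces for the same string, so either function can feed the later steps.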
One-Hot Encoding the Class Labels
Class labels can likewise be one-hot encoded. The code is as follows:
def one_hot_numbers(numbers):
    array = numpy.array(numbers)
    # Number of columns = largest label + 1 (one-hot indexes start at 0)
    maximum = numpy.max(array) + 1
    eyes = numpy.eye(maximum)[array]
    return eyes
def train():
    '''
    string = "hello"
    indexes = indexes_matrix(string)
    print("string =", string, ", indexes =", indexes)
    '''
    labels, titles, descriptions = AgNewsCsvReader.setup()
    #print(labels[: 5], titles[: 5], descriptions[: 5])
    for title in titles[: 10]:
        indexes = align_string_matrix(title)
        print("string =", title, ", indexes.shape =", indexes.shape)
    one_hoted_labels = one_hot_numbers(labels)
    print("one_hoted_labels.shape = ", one_hoted_labels.shape)
if __name__ == "__main__":
    train()
The complete code so far is as follows:
../52/AgNewsCsvReader.py
import csv
import re
import ssl
import nltk
def stop_words():
    # Work around SSL certificate errors when downloading NLTK data
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context
    nltk.data.path.append("/tmp/")
    nltk.download("stopwords", download_dir = "/tmp/")
    # NLTK's stop-word file ids are lowercase ("english")
    stops = nltk.corpus.stopwords.words("english")
    print(stops)
    return stops
def purify(string: str, pattern: str = r"[^a-z]", replacement: str = " "):
    string = string.lower()
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace consecutive spaces with a single space
    string = re.sub(pattern = r" +", repl = replacement, string = string)
    # Spaces are now kept (the deletion below is disabled)
    # string = re.sub(pattern = " ", repl = "", string = string)
    # Trim the string
    string = string.strip()
    string = string + "eos"
    return string
def purify_stops(string: str, pattern: str = r"[^a-z0-9]", replacement: str = " ", stops = stop_words()):
    string = string.lower()
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace consecutive spaces with a single space
    string = re.sub(pattern = r" +", repl = replacement, string = string)
    # Trim the string
    string = string.strip()
    # Separate the string by spaces, yielding a list of words
    strings = string.split(" ")
    strings = [word for word in strings if word not in stops]
    strings = [nltk.PorterStemmer().stem(word) for word in strings]
    strings.append("eos")
    strings = ["bos"] + strings
    return strings
def setup():
    with open("../../Shares/ag_news_csv/train.csv", "r") as handler:
        labels = []
        titles = []
        descriptions = []
        trains = csv.reader(handler)
        trains = list(trains)
        for i in range(len(trains)):
            line = trains[i]
            # Convert the class index from string to int
            labels.append(int(line[0]))
            titles.append(purify(line[1]))
            descriptions.append(purify_stops(line[2]))
    return labels, titles, descriptions
CharactersConvolutionalNeuralNetwork.py
import numpy
import sys
sys.path.append("../52/")
import AgNewsCsvReader
def one_hot(characters, alphabet = None):
    alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet is None else alphabet)
    array = numpy.array(characters)
    length = len(alphabet)
    # numpy.eye(N) creates an N x N identity matrix (1s on the diagonal, 0s elsewhere);
    # selecting its rows by the character indexes gives one one-hot row per character
    eyes = numpy.eye(length)[array]
    return eyes
def one_hot_numbers(numbers):
    array = numpy.array(numbers)
    maximum = numpy.max(array) + 1
    eyes = numpy.eye(maximum)[array]
    return eyes
def indexes_of(characters, alphabet = None):
    alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet is None else alphabet)
    indexes = []
    for character in characters:
        index = alphabet.index(character)
        indexes.append(index)
    return indexes
def indexes_matrix(string, alphabet = "abcdefghijklmnopqrstuvwxyz "):
    indexes = indexes_of(string, alphabet)
    matrix = one_hot(indexes, alphabet)
    return matrix
def align_string_matrix(string, maximum_length = 64, alphabet = "abcdefghijklmnopqrstuvwxyz "):
    length = len(string)
    if length > maximum_length:
        string = string[: maximum_length]
        matrix = indexes_matrix(string, alphabet)
        return matrix
    else:
        matrix = indexes_matrix(string, alphabet)
        length = maximum_length - length
        matrix_padded = numpy.zeros([length, len(alphabet)])
        matrix = numpy.concatenate([matrix, matrix_padded], axis = 0)
        return matrix
def train():
    '''
    string = "hello"
    indexes = indexes_matrix(string)
    print("string =", string, ", indexes =", indexes)
    '''
    labels, titles, descriptions = AgNewsCsvReader.setup()
    #print(labels[: 5], titles[: 5], descriptions[: 5])
    trains = []
    # Encode every title so that the shapes cover the whole training set
    for title in titles:
        matrix = align_string_matrix(title)
        trains.append(matrix)
    # Add a channel axis: [samples, 64, 27] -> [samples, 64, 27, 1]
    trains = numpy.expand_dims(trains, axis = -1)
    labels = one_hot_numbers(labels)
    print("trains.shape =", trains.shape, ", labels.shape =", labels.shape)
if __name__ == "__main__":
    train()
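The code first reads all the text data with the csv module, then reads the title and label of each row, converts both into one-hot matrices, and finally uses NumPy to turn the resulting lists into NumPy arrays. The printed output is: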
trains.shape = (120000, 64, 27, 1) , labels.shape = (120000, 5)
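This yields the one-hot matrices for the training texts and for the labels:
- The training set has shape [120000, 64, 27, 1]: the first number is the total number of samples, the second and third are the dimensions of each text matrix, and the final 1 means that only one channel is used.
- The labels have shape [120000, 5], a two-dimensional matrix in which 120000 is the number of samples and 5 is the number of columns. Note that one-hot indexes start at 0 while the AG News class labels start at 1 (the four classes are labeled 1 to 4), so an extra column for a never-used label 0 appears automatically.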
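If the unused column is unwanted, the labels can be shifted to start at 0 before encoding. A minimal sketch, assuming the four AG News classes are labeled 1 to 4; one_hot_labels is a hypothetical helper, and the model's number_classes would then have to be 4 instead of 5:

```python
import numpy

def one_hot_labels(labels, number_classes=4):
    # Shift 1-based class labels (1..4) to 0-based indexes (0..3)
    indexes = numpy.asarray(labels) - 1
    return numpy.eye(number_classes, dtype=numpy.float32)[indexes]

encoded = one_hot_labels([3, 3, 1, 4, 2])
print(encoded.shape)               # (5, 4)
print(encoded.argmax(axis=1) + 1)  # [3 3 1 4 2]
```

Keeping the extra column, as the chapter does, is also harmless: column 0 is never a target, so the model simply learns not to predict it.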
This completes the text preprocessing.
Text Classification with a One-Dimensional CNN (conv1d)
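With the text processed, we now design the CNN-based classification model. As the character-level convolution architecture diagram at the beginning of this chapter suggests, many model designs are possible; following similar designs, a text classification model with 5 layers is defined: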
Layer | Name |
---|---|
1 | Conv 3 x 3, 1 x 1 |
2 | Conv 5 x 5, 1 x 1 |
3 | Conv 3 x 3, 1 x 1 |
4 | Fully Connected 256 |
5 | Fully Connected 5 |
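The first 3 layers are one-dimensional convolutional layers and the last 2 are fully connected layers for classification. The code below is a simplified version of this design: it uses two stax.Conv layers (3 x 3 and 5 x 5, one output channel each) and a 32-unit hidden dense layer before the output layer: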
import jax
import jax.example_libraries.stax
def cnn(number_classes):
    # Returns (init_fun, apply_fun); the output is log-probabilities per class
    return jax.example_libraries.stax.serial(
        jax.example_libraries.stax.Conv(1, (3, 3)),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Conv(1, (5, 5)),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Flatten,
        jax.example_libraries.stax.Dense(32),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Dense(number_classes),
        jax.example_libraries.stax.LogSoftmax
    )
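For comparison, here is a sketch of a stax model that follows the layer table more closely and is one-dimensional in the sense of the section title: the first kernel spans the full encoding width (27 columns), so it and every later convolution slide only along the character axis. The filter count of 64, the kernel shapes (3, 27), (5, 1) and (3, 1), and the name char_cnn_1d are assumptions for illustration, not values from the original code.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

def char_cnn_1d(number_classes, alphabet_size=27, filters=64):
    # Input layout is NHWC: [batch, 64 characters, alphabet_size, 1 channel]
    return stax.serial(
        stax.Conv(filters, (3, alphabet_size)),  # -> [batch, 62, 1, filters]
        stax.Relu,
        stax.Conv(filters, (5, 1)),              # -> [batch, 58, 1, filters]
        stax.Relu,
        stax.Conv(filters, (3, 1)),              # -> [batch, 56, 1, filters]
        stax.Relu,
        stax.Flatten,
        stax.Dense(256),                         # Fully Connected 256
        stax.Relu,
        stax.Dense(number_classes),              # Fully Connected 5
        stax.LogSoftmax,
    )

init_fun, apply_fun = char_cnn_1d(number_classes=5)
output_shape, params = init_fun(jax.random.PRNGKey(0), (-1, 64, 27, 1))
print(output_shape)  # (-1, 5)
log_probs = apply_fun(params, jnp.zeros((2, 64, 27, 1)))
print(log_probs.shape)  # (2, 5)
```

Because the function returns the same (init_fun, apply_fun) pair as cnn(), it could be swapped into the training code below without other changes.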
The complete training code is as follows:
../52/AgNewsCsvReader.py
import csv
import re
import ssl
import nltk
def stop_words():
    # Work around SSL certificate errors when downloading NLTK data
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context
    nltk.data.path.append("/tmp/")
    nltk.download("stopwords", download_dir = "/tmp/")
    # NLTK's stop-word file ids are lowercase ("english")
    stops = nltk.corpus.stopwords.words("english")
    print(stops)
    return stops
def purify(string: str, pattern: str = r"[^a-z]", replacement: str = " "):
    string = string.lower()
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace consecutive spaces with a single space
    string = re.sub(pattern = r" +", repl = replacement, string = string)
    # string = re.sub(pattern = " ", repl = "", string = string)
    # Trim the string
    string = string.strip()
    # Append "eos" as a separate token (note the leading space)
    string = string + " eos"
    return string
def purify_stops(string: str, pattern: str = r"[^a-z0-9]", replacement: str = " ", stops = stop_words()):
    string = string.lower()
    string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace consecutive spaces with a single space
    string = re.sub(pattern = r" +", repl = replacement, string = string)
    # Trim the string
    string = string.strip()
    # Separate the string by spaces, yielding a list of words
    strings = string.split(" ")
    strings = [word for word in strings if word not in stops]
    strings = [nltk.PorterStemmer().stem(word) for word in strings]
    strings.append("eos")
    strings = ["bos"] + strings
    return strings
def setup():
    with open("../../Shares/ag_news_csv/train.csv", "r") as handler:
        train_labels = []
        train_titles = []
        train_descriptions = []
        trains = csv.reader(handler)
        trains = list(trains)
        for i in range(len(trains)):
            line = trains[i]
            # Convert the class index from string to int
            train_labels.append(int(line[0]))
            train_titles.append(purify(line[1]))
            train_descriptions.append(purify_stops(line[2]))
    with open("../../Shares/ag_news_csv/test.csv", "r") as handler:
        test_labels = []
        test_titles = []
        test_descriptions = []
        tests = csv.reader(handler)
        tests = list(tests)
        for i in range(len(tests)):
            line = tests[i]
            test_labels.append(int(line[0]))
            test_titles.append(purify(line[1]))
            test_descriptions.append(purify_stops(line[2]))
    return (train_labels, train_titles, train_descriptions), (test_labels, test_titles, test_descriptions)
def main():
    (train_labels, train_titles, train_descriptions), (test_labels, test_titles, test_descriptions) = setup()
    # The results are Python lists, which have no .shape; print their lengths instead
    print((len(train_labels), len(train_titles), len(train_descriptions)), (len(test_labels), len(test_titles), len(test_descriptions)))
if __name__ == "__main__":
    main()
CharactersConvolutionalNeuralNetwork.py
import numpy
import jax
import jax.example_libraries.stax
import jax.example_libraries.optimizers
import sys
sys.path.append("../52/")
import AgNewsCsvReader
def one_hot(characters, alphabet):
    array = numpy.array(characters)
    length = len(alphabet)
    # numpy.eye(N) creates an N x N identity matrix (1s on the diagonal, 0s elsewhere);
    # selecting its rows by the character indexes gives one one-hot row per character.
    # float32 halves the memory of NumPy's default float64 and matches JAX's default precision
    eyes = numpy.eye(length, dtype = numpy.float32)[array]
    return eyes
def one_hot_numbers(numbers):
    array = numpy.array(numbers)
    maximum = numpy.max(array) + 1
    eyes = numpy.eye(maximum, dtype = numpy.float32)[array]
    return eyes
def indexes_of(characters, alphabet):
    indexes = []
    for character in characters:
        index = alphabet.index(character)
        indexes.append(index)
    return indexes
def indexes_matrix(string, alphabet):
    indexes = indexes_of(string, alphabet)
    matrix = one_hot(indexes, alphabet)
    return matrix
def align_string_matrix(string, maximum_length = 64, alphabet = "abcdefghijklmnopqrstuvwxyz "):
    length = len(string)
    if length > maximum_length:
        # Truncate long strings
        string = string[: maximum_length]
        matrix = indexes_matrix(string, alphabet)
        return matrix
    else:
        # Pad short strings with all-zero rows
        matrix = indexes_matrix(string, alphabet)
        length = maximum_length - length
        matrix_padded = numpy.zeros([length, len(alphabet)], dtype = numpy.float32)
        matrix = numpy.concatenate([matrix, matrix_padded], axis = 0)
        return matrix
def cnn(number_classes):
    return jax.example_libraries.stax.serial(
        jax.example_libraries.stax.Conv(1, (3, 3)),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Conv(1, (5, 5)),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Flatten,
        jax.example_libraries.stax.Dense(32),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Dense(number_classes),
        jax.example_libraries.stax.LogSoftmax
    )
def setup():
    prng = jax.random.PRNGKey(15)
    (train_labels, train_titles, train_descriptions), (test_labels, test_titles, test_descriptions) = AgNewsCsvReader.setup()
    train_texts = []
    for title in train_titles:
        matrix = align_string_matrix(title)
        train_texts.append(matrix)
    train_texts = numpy.expand_dims(train_texts, axis = -1)
    train_labels = one_hot_numbers(train_labels)
    test_texts = []
    for title in test_titles:
        matrix = align_string_matrix(title)
        test_texts.append(matrix)
    test_texts = numpy.expand_dims(test_texts, axis = -1)
    test_labels = one_hot_numbers(test_labels)
    number_classes = 5
    # [batch, maximum_length, alphabet size (26 letters + space), channels];
    # the width must match the data (27), otherwise the Dense layer gets the wrong input size
    input_shape = [-1, 64, 27, 1]
    batch_size = 100
    epochs = 5
    init_random_params, predict = cnn(number_classes)
    optimizer_init_function, optimizer_update_function, get_params_function = jax.example_libraries.optimizers.adam(step_size = 2.17e-4)
    _, init_params = init_random_params(prng, input_shape = input_shape)
    optimizer_state = optimizer_init_function(init_params)
    return (prng, number_classes, batch_size, epochs, init_params, optimizer_state), (init_random_params, optimizer_init_function, predict, optimizer_update_function, get_params_function), ((train_texts, train_labels), (test_texts, test_labels))
def verify_accuracy(params, batch, predict_function):
    inputs, targets = batch
    predictions = predict_function(params, inputs)
    # Count the samples whose predicted class equals the true class
    class_ = jax.numpy.argmax(predictions, axis = 1)
    targets = jax.numpy.argmax(targets, axis = 1)
    return jax.numpy.sum(class_ == targets)
def loss_function(params, batch, predict_function):
    inputs, targets = batch
    # The model ends with LogSoftmax, so predictions are log-probabilities
    predictions = predict_function(params, inputs)
    # Cross-entropy: -sum(one-hot target * log-probability), averaged over the batch
    losses = -targets * predictions
    losses = jax.numpy.sum(losses, axis = 1)
    losses = jax.numpy.mean(losses)
    return losses
def update_function(step, optimizer_state, batch, get_params_function, optimizer_update_function, predict_function):
    params = get_params_function(optimizer_state)
    loss_function_grad = jax.grad(loss_function)
    gradients = loss_function_grad(params, batch, predict_function)
    return optimizer_update_function(step, gradients, optimizer_state)
def train():
    (prng, number_classes, batch_size, epochs, init_params, optimizer_state), (init_random_params, optimizer_init_function, predict, optimizer_update_function, get_params_function), ((train_texts, train_labels), (test_texts, test_labels)) = setup()
    print("train_texts.shape =", train_texts.shape, ", train_labels.shape =", train_labels.shape, ", test_texts.shape =", test_texts.shape, ", test_labels.shape =", test_labels.shape)
    train_batch_number = int(len(train_texts) / batch_size)
    test_batch_number = int(len(test_texts) / batch_size)
    step = 0
    accuracies = []
    for i in range(epochs):
        print(f"Epoch {i} started")
        for j in range(train_batch_number):
            start = j * batch_size
            end = (j + 1) * batch_size
            batch = (train_texts[start: end], train_labels[start: end])
            # Adam expects the global step count, not the epoch index
            optimizer_state = update_function(step, optimizer_state, batch, get_params_function, optimizer_update_function, predict)
            step += 1
            if (j + 1) % 10 == 0:
                params = get_params_function(optimizer_state)
                losses = loss_function(params, batch, predict)
                print("Losses now is =", losses)
        params = get_params_function(optimizer_state)
        print(f"Epoch {i} completed")
        predictions = 0.0
        for j in range(test_batch_number):
            start = j * batch_size
            end = (j + 1) * batch_size
            batch = (test_texts[start: end], test_labels[start: end])
            predictions += float(verify_accuracy(params, batch, predict))
        # Accuracy over the evaluated test samples (not the training-set size)
        accuracies.append(predictions / float(test_batch_number * batch_size))
        print("Test accuracies =", accuracies)
if __name__ == "__main__":
    train()
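The code first loads the training and test sets and then defines the loss function and the optimizer; this part is similar to the earlier ResNet example and is not explained again here.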
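Because the model ends with LogSoftmax, loss_function computes the cross-entropy, L = -(1/B) Σ_b Σ_k y_bk · log p_bk, and verify_accuracy is meant to count argmax matches between prediction and target. A tiny sanity check with hand-made log-probabilities (illustrative values, not model output) shows what both quantities return; cross_entropy and count_correct are hypothetical stand-ins with the same arithmetic:

```python
import jax
import jax.numpy as jnp

def cross_entropy(log_probs, targets):
    # Mean over the batch of -sum_k target_k * log p_k (same arithmetic as loss_function)
    return jnp.mean(jnp.sum(-targets * log_probs, axis=1))

def count_correct(log_probs, targets):
    # Number of samples whose predicted class equals the true class
    return jnp.sum(jnp.argmax(log_probs, axis=1) == jnp.argmax(targets, axis=1))

targets = jax.nn.one_hot(jnp.array([3, 1]), 5)

# Uniform predictions over 5 classes: the loss equals log(5), about 1.609
uniform = jnp.log(jnp.full((2, 5), 0.2))
print(cross_entropy(uniform, targets))

# Confident predictions: first sample right (class 3), second wrong (predicts 2, true 1)
confident = jax.nn.log_softmax(jnp.array([[0., 0., 0., 5., 0.], [0., 0., 5., 0., 0.]]))
print(count_correct(confident, targets))  # 1
```

A freshly initialized 5-class model should therefore start with a loss near 1.609; a value far from that usually points to a shape or label-encoding problem.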
Conclusion
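This chapter used AG News titles and class labels to build a text classification model from convolutional layers followed by fully connected layers. Note that this example only illustrates the approach; its classification accuracy is not necessarily good.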