The previous chapters covered JAX's transformation functions, such as jax.jit, jax.grad, and jax.vmap, and showed how composing them produces concise, efficient code. This chapter explains how to customize function transformations by writing a custom Jaxpr interpreter.
The Jaxpr Tracer
JAX provides a NumPy-like API for numerical computing, and jax.numpy can be used almost exactly like NumPy, but JAX's real power comes from its composable function transformations. Take jax.jit as an example: it accepts a function and returns a semantically identical one, which is then compiled for accelerators with XLA.
import jax

def function(x):
    return 2 * x ** 2 + 3 * x

def test():
    function_jit = jax.jit(function)
    result = function_jit(10)
    print("result = ", result)

def main():
    test()

if __name__ == "__main__":
    main()
In the example above, when function_jit is called, JAX traces the function and builds an XLA computation graph, then JIT-compiles and executes that graph. Other function transformations work similarly: first trace the function, then process the resulting trace in some way.
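The tracing step can be made visible with a Python-side print, which runs only while JAX traces the function (a quick illustrative sketch):

```python
import jax

def function(x):
    # A Python side effect: this runs during tracing, where x is an
    # abstract tracer rather than a concrete number
    print("tracing with x =", x)
    return 2 * x ** 2 + 3 * x

function_jit = jax.jit(function)
print(function_jit(10.0))  # first call traces and compiles, then prints 230.0
print(function_jit(11.0))  # same shape/dtype: no retracing, prints 275.0
```

The second call reuses the compiled computation, so the tracing print does not fire again.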
A particularly important tracer in JAX records operations into a jaxpr (JAX expression). A jaxpr is a data structure that can be evaluated like a program in a functional programming language, which makes it a useful intermediate representation for function transformations.
A function can be converted with jax.make_jaxpr: given example arguments, it produces the function's jaxpr expression. Although the generated jaxpr usually cannot be used directly, it is very handy for debugging and for inspecting JAX functions.
The following code snippets illustrate how jaxpr works.
import jax

def function(x):
    return 2 * x ** 2 + 3 * x

def test():
    expr = jax.make_jaxpr(function)
    result = expr(2.0)
    print(result)

def main():
    test()

if __name__ == "__main__":
    main()
Running it prints the following:
{ lambda ; a:f32[]. let
b:f32[] = integer_pow[y=2] a
c:f32[] = mul 2.0 b
d:f32[] = mul 3.0 a
e:f32[] = add c d
in (e,) }
The make_jaxpr result can be examined in more detail with a helper function:
import jax

def function(x):
    return 2 * x ** 2 + 3 * x

def print_jaxpr(closed_expr):
    jaxpr = closed_expr.jaxpr
    print("invars: ", jaxpr.invars)
    print("outvars: ", jaxpr.outvars)
    print("constvars: ", jaxpr.constvars)
    for equation in jaxpr.eqns:
        print("Equation: ", equation.invars, equation.primitive, equation.outvars, equation.params)
    print("jaxpr: ", jaxpr)

def test():
    expr = jax.make_jaxpr(function)
    result = expr(2.0)
    print(result)
    print("--------------------------")
    print_jaxpr(result)

def main():
    test()

if __name__ == "__main__":
    main()
Running it prints the following:
{ lambda ; a:f32[]. let
b:f32[] = integer_pow[y=2] a
c:f32[] = mul 2.0 b
d:f32[] = mul 3.0 a
e:f32[] = add c d
in (e,) }
--------------------------
invars: [a]
outvars: [e]
constvars: []
Equation: [a] integer_pow [b] {'y': 2}
Equation: [2.0, b] mul [c] {}
Equation: [3.0, a] mul [d] {}
Equation: [c, d] add [e] {}
jaxpr: { lambda ; a:f32[]. let
b:f32[] = integer_pow[y=2] a
c:f32[] = mul 2.0 b
d:f32[] = mul 3.0 a
e:f32[] = add c d
in (e,) }
Before walking through the output, here is what each field means:
- jaxpr.invars: the list of the jaxpr's input variables, analogous to a function's formal parameters.
- jaxpr.outvars: the list of output (return) variables.
- jaxpr.constvars: a list of variables that are also inputs to the jaxpr, but correspond to constants captured during tracing.
- jaxpr.eqns: a list of equations (internal computations); each equation has inputs and outputs and computes part of the function's final result.
With these fields in mind, the output above can be parsed as follows:
invars: [a]
outvars: [e]
constvars: []
Equation: [a] integer_pow [b] {'y': 2}
Equation: [2.0, b] mul [c] {}
Equation: [3.0, a] mul [d] {}
Equation: [c, d] add [e] {}
- invars: [a]: the input list contains the single variable a.
- outvars: [e]: the output (return) list contains the single variable e.
- constvars: []: no traced constants were captured.
- Equation: [a] integer_pow [b] {'y': 2}: an equation raising the input a to the power 2; [a] is the input list, [b] the output list, and {'y': 2} the parameters.
- Equation: [2.0, b] mul [c] {}: multiplies b (the previous equation's output) by the constant 2.0, producing c.
- Equation: [3.0, a] mul [d] {}: multiplies the input a by the constant 3.0, producing d.
- Equation: [c, d] add [e] {}: adds c and d, producing e.
{ lambda ; a:f32[]. let
b:f32[] = integer_pow[y=2] a
c:f32[] = mul 2.0 b
d:f32[] = mul 3.0 a
e:f32[] = add c d
in (e,) }
- { lambda ; a:f32[]. let: begins a lambda with a float32 scalar input a (f32[] denotes an empty shape) and opens the let body.
- b:f32[] = integer_pow[y=2] a: defines the float32 scalar b as a raised to the power 2.
- c:f32[] = mul 2.0 b: defines c as the constant 2.0 multiplied by b.
- d:f32[] = mul 3.0 a: defines d as the constant 3.0 multiplied by the input a.
- e:f32[] = add c d: defines e as the sum of c and d.
- in (e,) }: returns the tuple (e,).
As this walkthrough shows, a jaxpr is a simple, easy-to-transform program representation, much like the intermediate language of a compiler. Because JAX can produce jaxprs directly from Python functions, it provides a systematic way to transform Python numerical code.
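To make this concrete, a jaxpr can be replayed with only a few lines of Python. The forward evaluator below is a minimal sketch (not the interpreter JAX uses internally): it walks the equations in order, using a read/write environment keyed by jaxpr variables.

```python
import jax

def eval_jaxpr(jaxpr, consts, *args):
    # Environment mapping jaxpr variables to concrete values
    environment = {}

    def read(var):
        # Literals (inline constants such as 2.0) carry their own value
        if type(var) is jax.core.Literal:
            return var.val
        return environment[var]

    def write(var, value):
        environment[var] = value

    jax.util.safe_map(write, jaxpr.invars, args)
    jax.util.safe_map(write, jaxpr.constvars, consts)
    for equation in jaxpr.eqns:
        in_values = jax.util.safe_map(read, equation.invars)
        # bind() applies the primitive with its recorded parameters
        outputs = equation.primitive.bind(*in_values, **equation.params)
        if not equation.primitive.multiple_results:
            outputs = [outputs]
        jax.util.safe_map(write, equation.outvars, outputs)
    return jax.util.safe_map(read, jaxpr.outvars)

def function(x):
    return 2 * x ** 2 + 3 * x

closed_jaxpr = jax.make_jaxpr(function)(2.0)
# The single output equals 14.0 (= 2 * 2**2 + 3 * 2)
print(eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, 2.0))
```

Evaluating the jaxpr this way reproduces the original function's result, which is exactly the property that custom interpreters build on.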
Tracing a function that captures constants is slightly more involved: make_jaxpr alone is not enough, because the constants created during tracing must be extracted and passed into the jaxpr. A make_jaxpr-like helper can be written as follows:
def print_literals():
    function_jaxpr = jax.make_jaxpr(function)
    closed_jaxpr = function_jaxpr(2.0)
    print(closed_jaxpr)
    print("-----------------------------------------")
    print(closed_jaxpr.literals)
Running it prints the following:
{ lambda ; a:f32[]. let
b:f32[] = integer_pow[y=2] a
c:f32[] = mul 2.0 b
d:f32[] = mul 3.0 a
e:f32[] = add c d
in (e,) }
-----------------------------------------
[]
The output is the jaxpr obtained by tracing the function, followed by its captured constants (literals) as a sequence; the list is empty here because this function captures no constants.
Defining Functions Traceable by a Jaxpr Interpreter
To use a custom interpreter, the primitives it handles must first be registered, following the rules for JAX primitives. The following example wraps a function so that a jaxpr interpreter computes its inverse:
import jax

def inverse_iterate_jaxpr(inverse_registry, jaxpr, consts, *args):
    configurations = {}

    def read(var):
        if type(var) is jax.core.Literal:
            return var.val
        return configurations[var]

    def write(var, value):
        configurations[var] = value

    # The outputs of the original jaxpr are the inputs of the inverse
    jax.util.safe_map(write, jaxpr.outvars, args)
    jax.util.safe_map(write, jaxpr.constvars, consts)
    # Backwards iteration
    for equation in jaxpr.eqns[::-1]:
        in_values = jax.util.safe_map(read, equation.outvars)
        if equation.primitive not in inverse_registry:
            raise NotImplementedError("{} does not have a registered inverse.".format(equation.primitive))
        out_values = inverse_registry[equation.primitive](*in_values)
        jax.util.safe_map(write, equation.invars, [out_values])
    return jax.util.safe_map(read, jaxpr.invars)

def inverse(functionPointer, inverse_registry):
    @jax.util.wraps(functionPointer)
    def wrapped_function(*args, **kwargs):
        function_jaxpr = jax.make_jaxpr(functionPointer)
        closed_jaxpr = function_jaxpr(*args, **kwargs)
        output = inverse_iterate_jaxpr(inverse_registry, closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
        return output[0]
    return wrapped_function

def function(x):
    tan = jax.numpy.tanh(x)
    exp = jax.numpy.exp(tan)
    return exp

def test():
    function_jaxpr = jax.make_jaxpr(function)
    jaxpr = function_jaxpr(2.)
    print("jaxpr = ", jaxpr)
    print("---------------------------")
    # Register an inverse for each primitive that function uses
    inverse_registry = {
        jax.lax.exp_p: jax.numpy.log,
        jax.lax.tanh_p: jax.numpy.arctanh,
    }
    function_inverse = inverse(function, inverse_registry)
    result = jax.make_jaxpr(function_inverse)(function(2.))
    print("result = ", result)
    print("---------------------------")

def main():
    test()

if __name__ == "__main__":
    main()
Running it prints the following:
jaxpr = { lambda ; a:f32[]. let b:f32[] = tanh a; c:f32[] = exp b in (c,) }
---------------------------
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
result = { lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) }
---------------------------
The output shows the custom function transformed both forwards and backwards. Two points are worth noting:
- XLA is the compiler JAX relies on; it is what lets JAX target TPUs (and, increasingly, every device), so it is worth studying. However, working with XLA computations directly through the raw C++ interface is not easy. JAX exposes the underlying XLA computation-builder API through Python wrappers and makes interaction with the XLA computation model accessible, so that operations can be fused.
- An XLA computation is generated as a computation graph at compile time and then lowered to a specific device, such as a CPU, GPU, or TPU.
Named Dimensions
The matrix computations so far, notably the VGG training example, did not use dimension names; dimensions such as batch_size, channels, height, and width were matched by positional convention. JAX, however, can attach names to dimensions. Named axes help document a function's expected inputs and make matrix manipulation more intuitive.
To illustrate, consider the fully connected network for MNIST classification from an earlier chapter:
import jax

def forward(weight1, weight2, images):
    dot = jax.numpy.dot(images, weight1)
    hidden1 = jax.nn.relu(dot)
    hidden2 = jax.numpy.dot(hidden1, weight2)
    logits = jax.nn.softmax(hidden2)
    return logits

def loss_function(weight1, weight2, images, labels):
    predictions = forward(weight1 = weight1, weight2 = weight2, images = images)
    targets = jax.nn.one_hot(labels, predictions.shape[-1])
    losses = jax.numpy.sum(targets * predictions, axis = 1)
    return -jax.numpy.mean(losses, axis = 0)

def train():
    weight1 = jax.numpy.zeros((784, 512))
    weight2 = jax.numpy.zeros((512, 10))
    images = jax.numpy.zeros((128, 784))
    labels = jax.numpy.zeros(128, dtype = jax.numpy.int32)
    losses = loss_function(weight1, weight2, images, labels)
    print("losses = ", losses)

def main():
    train()

if __name__ == "__main__":
    main()
The code above implements only the forward pass and the loss computation. It can be rewritten with named axes. First, declare a name for every dimension of every input:
axes = [
    ["inputs", "hidden"],
    ["hidden", "classes"],
    ["batch", "inputs"],
    ["batch", ...]
]
Each dimension of each input now carries a meaningful name. The names are used as follows:
import jax
import numpy
from jax.experimental import maps

def predict(weight1, weight2, images):
    dots = jax.numpy.dot(images, weight1)
    hiddens = jax.nn.relu(dots)
    logits = jax.numpy.dot(hiddens, weight2)
    return logits - jax.nn.logsumexp(logits, axis = 1, keepdims = True)

def loss_function(weight1, weight2, images, labels):
    predictions = predict(weight1 = weight1, weight2 = weight2, images = images)
    targets = jax.nn.one_hot(labels, predictions.shape[-1])
    losses = jax.numpy.sum(targets * predictions, axis = 1)
    return -jax.numpy.mean(losses, axis = 0)

# Named dimensions will be used to compute the data
def named_predict(weight1, weight2, images):
    pdot = jax.lax.pdot(images, weight1, "inputs")
    hidden = jax.nn.relu(pdot)
    logits = jax.lax.pdot(hidden, weight2, "hidden")
    return logits - jax.nn.logsumexp(logits, "classes")

def named_loss_function(weight1, weight2, images, labels):
    predictions = named_predict(weight1, weight2, images)
    # jax.lax.psum(): compute an all-reduce sum over the named axis
    number_classes = jax.lax.psum(1, "classes")
    targets = jax.nn.one_hot(labels, number_classes, axis = "classes")
    losses = jax.lax.psum(targets * predictions, "classes")
    return -jax.lax.pmean(losses, "batch")

def train():
    weight1 = jax.numpy.zeros((784, 512))
    weight2 = jax.numpy.zeros((512, 10))
    images = jax.numpy.zeros((128, 784))
    labels = jax.numpy.zeros(128, dtype = jax.numpy.int32)
    losses = loss_function(weight1, weight2, images, labels)
    print("losses = ", losses)
    in_axes = [
        ["inputs", "hidden"],
        ["hidden", "classes"],
        ["batch", "inputs"],
        ["batch", ...]
    ]
    # Register the names for the dimensions
    loss_function_xmap = maps.xmap(named_loss_function, in_axes = in_axes, out_axes = [...], axis_resources = {"batch": "x"})
    devices = numpy.array(jax.local_devices())
    with jax.sharding.Mesh(devices, ("x",)):
        losses = loss_function_xmap(weight1, weight2, images, labels)
        print("losses = ", losses)

def main():
    train()

if __name__ == "__main__":
    main()
Running it prints the following:
losses = 2.3025854
losses = 2.3025854
Naming dimensions pins down exactly how a network's axes are laid out, so a mixed-up dimension cannot silently corrupt the computation during training. A meaningful name is self-explanatory and clearly better than identifying a dimension only by its numeric position.
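Named axes also appear outside xmap: jax.vmap accepts an axis_name, and collectives such as jax.lax.psum can then refer to the mapped dimension by name. A minimal sketch:

```python
import jax
import jax.numpy as jnp

def normalize(x):
    # psum over the named "batch" axis sums across every element that
    # vmap maps over, even though each call sees only a scalar x
    return x / jax.lax.psum(x, axis_name = "batch")

result = jax.vmap(normalize, axis_name = "batch")(jnp.array([1.0, 3.0]))
print(result)  # [0.25 0.75]
```

Each mapped invocation divides its own element by the sum over the whole named axis, the same pattern named_loss_function uses with "classes" and "batch".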
Customizing Tensor Types in JAX
In plain NumPy (as opposed to jax.numpy), the programming model is based on N-dimensional arrays, and every N-dimensional array is characterized by two things:
- the data type of its elements, and
- its shape.
JAX unifies these two into a single type, dtype[shape_tuple]. For example, a float32 array of shape [3, 17, 21] is written f32[(3, 17, 21)]. The following small example demonstrates how shapes propagate through a simple NumPy program.
import numpy
import etils.array_types  # loads the submodule so the etils.array_types.f32 annotations resolve
from typing import Any

class ArrayType:
    def __getitem__(self, idx):
        return Any

f32 = ArrayType()

def test():
    array = numpy.ones(shape = (3, 17, 21))
    print(array.shape)
    array = numpy.arange(1071).reshape(3, 17, 21)
    print(array.shape)
    x: etils.array_types.f32[(2, 3)] = numpy.ones(shape = (2, 3), dtype = numpy.float32)
    y: etils.array_types.f32[(3, 5)] = numpy.ones(shape = (3, 5), dtype = numpy.float32)
    z: etils.array_types.f32[(2, 5)] = x.dot(y)
    w: etils.array_types.f32[(7, 1, 5)] = numpy.ones((7, 1, 5), dtype = numpy.float32)
    q: etils.array_types.f32[(7, 2, 5)] = z + w
    print(f"x.shape = {x.shape}, y.shape = {y.shape}, z.shape = {z.shape}, w.shape = {w.shape}, q.shape = {q.shape}")
    x: f32[(2, 3)] = numpy.ones(shape = (2, 3), dtype = numpy.float32)
    y: f32[(3, 5)] = numpy.ones(shape = (3, 5), dtype = numpy.float32)
    z: f32[(2, 5)] = x.dot(y)
    w: f32[(7, 1, 5)] = numpy.ones((7, 1, 5), dtype = numpy.float32)
    q: f32[(7, 2, 5)] = z + w
    print(f"x.shape = {x.shape}, y.shape = {y.shape}, z.shape = {z.shape}, w.shape = {w.shape}, q.shape = {q.shape}")

def main():
    test()

if __name__ == "__main__":
    main()
Running it prints the following:
(3, 17, 21)
(3, 17, 21)
x.shape = (2, 3), y.shape = (3, 5), z.shape = (2, 5), w.shape = (7, 1, 5), q.shape = (7, 2, 5)
x.shape = (2, 3), y.shape = (3, 5), z.shape = (2, 5), w.shape = (7, 1, 5), q.shape = (7, 2, 5)
As the code shows, f32 comes from two sources: etils.array_types.f32 and the custom class:
class ArrayType:
    def __getitem__(self, idx):
        return Any

f32 = ArrayType()
In the custom class, f32 is defined to accept any subscript and return Any, so it can annotate data of any type and shape. Arrays annotated this way print like ordinary arrays and still report their shape as usual.
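JAX itself exposes the same dtype[shape] view of an array through jax.eval_shape, which propagates shapes and dtypes abstractly without performing any computation:

```python
import jax
import jax.numpy as jnp

def function(x, w):
    return jnp.dot(x, w)

# ShapeDtypeStruct describes an array as dtype[shape] without allocating it
x = jax.ShapeDtypeStruct((2, 3), jnp.float32)
w = jax.ShapeDtypeStruct((3, 5), jnp.float32)
out = jax.eval_shape(function, x, w)
print(out.shape, out.dtype)  # (2, 5) float32
```

This is the f32[(2, 5)] idea in executable form: only the abstract type of the result is computed, which is exactly how tracing reasons about arrays.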
Conclusion
This chapter explored jaxpr interpreters at a fairly low level: composing function transformations, the tracer, and defining functions that a jaxpr interpreter can process. From an engineering standpoint, it also showed how named dimensions improve the management of matrices in deep learning.
There is a lot of material here; take it at whatever pace works for you.