第46章 Jaxpr解释器和维度命名

前面几章一直在讲JAX转换函数,如jax.jit、jax.grad、jax.vmap等以及它们的组合使用可以编写简洁、执行高效的代码。本章介绍如何通过自定义Jaxpr解释器来自定义函数转换。

Jaxpr Tracer跟踪器

JAX为数值计算提供了一套类似于NumPy的API,几乎可以按NumPy原样使用jax.numpy,当JAX真正的功能来自于可组合的函数转换。下面以jax.jit函数转换为例,该函数接受一个函数并返回一个语义相同的函数,之后再用XLA加速器编译函数。


import jax

def function(x):
    
    return 2 * x ** 2 + 3 * x

def test():
    
    function_jit = jax.jit(function)
    result = function_jit(10)
    
    print("result = ", result)
    
def main():
    
    test()
    
if __name__ == "__main__":
    
    main()

上面例子里,当调用funciton_jit时,JAX讲跟踪函数并构造XLA计算图,然后对图形进行JIT编译和执行。其他函数转换方式类似,即首先跟踪函数并以某种方式处理输出跟踪。

JAX中一个特别重要的跟踪器就是Jaxpr,它讲OP记录到的Jaxpr(JAX表达式)中。Jaxpr是一种数据结构,可以像函数式编程语言那样进行计算,因此Jaxpr是函数转换中有用的中间表示形式。

可以使用make_jaxpr对函数进行jaxpr转换,它将一个函数转换成给定的示例参数,生成计Jaxpr算表达式。虽然通常不能直接使用它生成的jaxpr语句,但是这对于调试和查看JAX函数很有用。

下面通过几段代码来理解jaxpr的运行机制。


import jax

def function(x):
    
    return 2 * x ** 2 + 3 * x

def test():
    
    expr = jax.make_jaxpr(function)
    result = expr(2.0)
    
    print(result)
    
def main():
    
    test()
    
if __name__ == "__main__":
    
    main()

运行结果打印输出如下,


{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=2] a
    c:f32[] = mul 2.0 b
    d:f32[] = mul 3.0 a
    e:f32[] = add c d
  in (e,) }

更详细的函数来来对make_jaxpr进行解析,代码如下,


import jax

def function(x):
    
    return 2 * x ** 2 + 3 * x

def print_jaxpr(closed_expr):
    
    jaxpr = closed_expr.jaxpr
    
    print("invars: ", jaxpr.invars)
    print("outvars: ", jaxpr.outvars)
    print("constvars: ", jaxpr.constvars)
    
    for equation in jaxpr.eqns:
        print("Equation: ", equation.invars, equation.primitive, equation.outvars, equation.params)
        
    print("jaxpr: ", jaxpr)

def test():
    
    expr = jax.make_jaxpr(function)
    result = expr(2.0)
    
    print(result)
    print("--------------------------")
    
    print_jaxpr(result)
    
def main():
    
    test()
    
if __name__ == "__main__":
    
    main()

运行结果打印输出如下,


{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=2] a
    c:f32[] = mul 2.0 b
    d:f32[] = mul 3.0 a
    e:f32[] = add c d
  in (e,) }
--------------------------
invars:  [a]
outvars:  [e]
constvars:  []
Equation:  [a] integer_pow [b] {'y': 2}
Equation:  [2.0, b] mul [c] {}
Equation:  [3.0, a] mul [d] {}
Equation:  [c, d] add [e] {}
jaxpr:  { lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=2] a
    c:f32[] = mul 2.0 b
    d:f32[] = mul 3.0 a
    e:f32[] = add c d
  in (e,) }

详细解析前,先了解一下相关参数的意义,

  • jaxpr.invars,输入变量列表,类似于函数的形参。
  • jaxpr.outvars,输出(返回)变量列表。
  • Jaxpr.constvars,变量列表,也是jaxpr的输入变量,但对应追踪中的常量。
  • Jaxpr.eqns,一系列内部计算的等式(或函数)列表,这个列表中的每一个等式(或函数)都有一个输入和输出,用于计算这个函数产生的输出结果。

根据参数说明,可以尝试去解析一下上面运行结果,


invars:  [a]
outvars:  [e]
constvars:  []
Equation:  [a] integer_pow [b] {'y': 2}
Equation:  [2.0, b] mul [c] {}
Equation:  [3.0, a] mul [d] {}
Equation:  [c, d] add [e] {}

  • invars: [a],输入参数变量为唯一元素a组成的数组或列表。
  • outvars: [e],输出或返回值变量为唯一元素e组成的数组或列表。
  • constvars: [],输入可追踪常数“变量”无。
  • Equation: [a] integer_pow [b] {'y': 2},等式(或函数),计算输入参数变量a([a表示参数输入值,[b]表示等式或函数的输出或返回值])的2次幂的等式。
  • Equation: [2.0, b] mul [c] {},等式(或函数),计算输入参数变量b(上面等式的输出值)和常数2.0相乘、返回值或者输出值为c的等式。
  • Equation: [3.0, a] mul [d] {},等式(或函数),计算输入参数变量a和常数3.0相乘、返回值或者输出值为d的等式。
  • Equation: [c, d] add [e] {},等式(或函数),计算输入参数变量c和d相加、返回值或者输出值为e的等式。

{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=2] a
    c:f32[] = mul 2.0 b
    d:f32[] = mul 3.0 a
    e:f32[] = add c d
  in (e,) }

  • { lambda ; a:f32[]. let,定义lambda表达式,float32类型的数组输入参数a,let函数体开始。
  • b:f32[] = integer_pow[y=2] a,定义float32类型的数组变量b,用于接受由2次指数函数integer_pow和输入参数a计算后的结果。
  • c:f32[] = mul 2.0 b,定义float32类型的数组变量c,用于接受由常数2.0与上面结果b相乘后的结果。
  • d:f32[] = mul 3.0 a,定义float32类型的数组变量d,用于接受由常数3.0与输入参数a计算后的结果。
  • e:f32[] = add c d,定义float32类型的数组变量e,用于接受由上面结果c和结果d相加的后的结果。
  • in (e,) },定义返回值为由e组成的元组。

由上面解析过程来看,Jaxpr表达式是易于转换的简单程序表示形式,类似于某些语言的中间语言。由于JAX允许从Python函数中直接转译Jaxpr,所以,它提供了一套为Python数值计算函数进行转换的方法。

对于函数的追踪则有些复杂,不能直接使用make_jaxpr,因为需要提取在追踪过程中创建的常量以传递到jaxpr。但是,可以编写一个类似于make_jaxpr的函数,代码如下,


def print_literals():
    
    function_jaxpr = jax.make_jaxpr(function)
    closed_jaxpr = function_jaxpr(2.0)
    
    print(closed_jaxpr)
    print("-----------------------------------------")
    
    print(closed_jaxpr.literals)

运行结果打印输出如下,


{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=2] a
    c:f32[] = mul 2.0 b
    d:f32[] = mul 3.0 a
    e:f32[] = add c d
  in (e,) }
-----------------------------------------
[]

此时输出结果就是以序列的方式对函数内部参数进行追踪的Jaxpr代码。

定义可被Jaxpr追踪函数

对于解释器的使用,需要先将其注册之后再遵循JAX原语的规则来使用。下面例子演示使用Jaxpr进行包装的函数。代码如下所示,


import jax

def inverse_iterate_jaxpr(inverse_registry, jaxpr, consts, *args):
    
    configurations = {}
    
    def read(var):
        
        if type(var) is jax.core.Literal:
            
            return var.val
        
        return configurations[var]
    
    def write(var, value):
        
        configurations[var] = value
    
    jax.util.safe_map(write, jaxpr.outvars, args)
    jax.util.safe_map(write, jaxpr.constvars, consts)
    
    # Backwards iteration
    for equation in jaxpr.eqns[:: -1]:
        
        in_values = jax.util.safe_map(read, equation.outvars)
        
        if equation.primitive not in inverse_registry:
            
            raise NotImplementedError("{} does not registered inverse.".format(equation.primitive))
        
        out_values = inverse_registry[equation.primitive](*in_values)
        
        jax.util.safe_map(write, equation.invars, [out_values])
        
    return jax.util.safe_map(read, jaxpr.invars)

def inverse(functionPointer, inverse_registry):
    
    @jax.util.wraps(functionPointer)
    def wrapped_function(*args, **kwargs):
        
        function_jaxpr = jax.make_jaxpr(functionPointer)
        closed_jaxpr = function_jaxpr(*args, **kwargs)
        
        output = inverse_iterate_jaxpr(inverse_registry, closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
        
        return output[0]
    
    return wrapped_function

def function(x):
    
    tan = jax.numpy.tanh(x)
    exp = jax.numpy.exp(tan)
    
    return exp
          
def test():
    
    function_jaxpr = jax.make_jaxpr(function)
    jaxpr = function_jaxpr(2.)

运行结果打印输出如下,


jaxpr =  { lambda ; a:f32[]. let b:f32[] = tanh a; c:f32[] = exp b in (c,) }
---------------------------
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
result =  { lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) }
---------------------------

可以看到,自定义函数被前向和后向转换后的结果,

  • XLA是JAX使用的编译器,它使得JAX可以用于TPU,并迅速应用于所有设备的编译器,因此值得研究。但是,直接使用原始C++接口处理 XLA计算并不容易。JAX通过Python包装器公开底层的XLA计算生成器API,并使得与XLA计算模型的交互访问,以便进行融合。
  • XLA计算在被编译以计算图的形式生成,然后降低到特定设备中,比如CPU、GPU和TPU。

维度命名

之前在进行矩阵计算时,特别是在VGG训练时,没有使用维度名称,而是根据位置约定来匹配batch_size、channels、height、width等维度。而JAX一个特性是可以给维度进行命名。对维度命名很有用,能够帮助编程者如何使用命名轴来编写文档化函数,以更加直观的方式来操控矩阵运算。

以前面章节实现的全连接层完成MNIST数据集分类任务为例,说明维度命名。代码如下所示,


import jax

def forward(weight1, weight2, images):
    
    dot = jax.numpy.dot(images, weight1)
    
    hidden1 = jax.nn.relu(dot)
    hidden2 = jax.numpy.dot(hidden1, weight2)
    
    logtis = jax.nn.softmax(hidden2)
    
    return logtis

def loss_function(weight1, weight2, images, labels):
    
    predictions = forward(weight1 = weight1, weight2 = weight2, images = images)
    targets = jax.nn.one_hot(labels, predictions.shape[-1])
    losses = jax.numpy.sum(targets * predictions, axis = 1)
    
    return -jax.numpy.mean(losses, axis = 0)

def train():
    
    weight1 = jax.numpy.zeros((784, 512))
    weight2 = jax.numpy.zeros((512, 10))
    
    images = jax.numpy.zeros((128, 784))
    labels = jax.numpy.zeros(128, dtype = jax.numpy.int32)
    
    losses = loss_function(weight1, weight2, images, labels)
    
    print("losses = ", losses)
    
def main():
    
    train()
    
if __name__ == "__main__":
    
    main()

上述代码仅仅是简单地实现了前向预测部分与损失函数的计算。下面通过使用命名空间对这部分代码进行改写,代码如下,


axes = [
        ["inputs", "hidden"],
        ["hidden", "classes"],
        ["batch", "inputs"],
        ["batch", ...]
        ]

这里根据输入的数据建立了对应的维度名称,其中每个维度都被人为设定了特定的名称。通过以下方式使用,


import jax
import numpy
from jax.experimental import maps

def predict(weight1, weight2, images):
    
    dots = jax.numpy.dot(images, weight1)
    hiddens = jax.nn.relu(dots)
    logtis = jax.numpy.dot(hiddens, weight2)
    
    return logtis - jax.nn.logsumexp(logtis, axis = 1, keepdims = True)

def loss_function(weight1, weight2, images, labels):
    
    predictions = predict(weight1 = weight1, weight2 = weight2, images = images)
    targets = jax.nn.one_hot(labels, predictions.shape[-1])
    losses = jax.numpy.sum(targets * predictions, axis = 1)
    
    return -jax.numpy.mean(losses, axis = 0)

# Named dimensions will be used to compute the data
def named_predict(weight1, weight2, images):
    
    pdot = jax.lax.pdot(images, weight1, "inputs")
    hidden = jax.nn.relu(pdot)
    logtis = jax.lax.pdot(hidden, weight2, "hidden")
    
    return logtis - jax.nn.logsumexp(logtis, "classes")

def named_loss_function(weight1, weight2, images, labels):
    
    predictions = named_predict(weight1, weight2, images)
    
    # jax.lax.psum(): Compute an all-reduce sum on x over the pmapped axis axis_name
    number_classes = jax.lax.psum(1, "classes")
    targets = jax.nn.one_hot(labels, number_classes, axis = "classes")
    losses = jax.lax.psum(targets * predictions, "classes")
    
    return -jax.lax.pmean(losses, "batch")
    
def train():
    
    weight1 = jax.numpy.zeros((784, 512))
    weight2 = jax.numpy.zeros((512, 10))
    
    images = jax.numpy.zeros((128, 784))
    labels = jax.numpy.zeros(128, dtype = jax.numpy.int32)
    
    losses = loss_function(weight1, weight2, images, labels)
    
    print("losses = ", losses)
    
    in_axes = [
        ["inputs", "hidden"],
        ["hidden", "classes"],
        ["batch", "inputs"],
        ["batch", ...]
        ]
    
    # Register the names for the dimensions
    loss_function_xmap = maps.xmap(named_loss_function, in_axes = in_axes, out_axes = [...], axis_resources = {"batch": "x"})
    
    devices = numpy.array(jax.local_devices())
    
    with jax.sharding.Mesh(devices, ("x",)):
        
        losses = loss_function_xmap(weight1, weight2, images, labels)
        
        print("losses = ", losses)
    
def main():
    
    train()
    
if __name__ == "__main__":
    
    main()

运行结果打印输出如下,


losses =  2.3025854
losses =  2.3025854

通过给维度命名,可以很好地对神经网络的维度进行设定,而不至于在训练时因弄错维度而造成计算错误。毕竟一个有意义的名称,让让人望文生义,明显好于单纯以数字标识的维度位置。

自定义JAX中的向量Tensor

Python本身的NumPy(不是jax.numpy)中的编程模型是基于N维数组,而不是每一个N维数组数值包含2个部分,

  • 数组中的数据类型。
  • 数组的维度。

在JAX中,这两个维度被同一成一个类型——dtype[shape_tuple]。举例来说,一个float32的维度大小为[3, 17, 21]的数据被定义成f32[(3, 17, 21)]。下面通过一个小示例来掩饰形状如何通过简单的NumPy程序进行传播。


import numpy
import etils

class ArrayType:
    
    def __getitem__(self, idx):
        
        return Any
    
f32 = ArrayType()

def test():
    
    array = numpy.ones(shape = (3, 17, 21))
    
    print(array.shape)
    
    array = numpy.arange(1071).reshape(3, 17, 21)
    
    print(array.shape)
    
    x: etils.array_types.f32[(2, 3)] = numpy.ones(shape = (2, 3), dtype = numpy.float32)
    y: etils.array_types.f32[(3, 5)] = numpy.ones(shape = (3, 5), dtype = numpy.float32)
    
    z: etils.array_types.f32[(2, 5)] = x.dot(y)
    
    w: etils.array_types.f32[(7, 1, 5)] = numpy.ones((7, 1, 5), dtype = numpy.float32)
    
    q: etils.array_types.f32[(7, 2, 5)] = z + w
    
    print(f"x.shape = {x.shape}, y.shape = {y.shape}, z.shape = {z.shape}, w.shape = {w.shape}, q.shape = {q.shape}")
    
    x: f32[(2, 3)] = numpy.ones(shape = (2, 3), dtype = numpy.float32)
    y: f32[(3, 5)] = numpy.ones(shape = (3, 5), dtype = numpy.float32)
    
    z: f32[(2, 5)] = x.dot(y)
    
    w: f32[(7, 1, 5)] = numpy.ones((7, 1, 5), dtype = numpy.float32)
    
    q: f32[(7, 2, 5)] = z + w
    
    print(f"x.shape = {x.shape}, y.shape = {y.shape}, z.shape = {z.shape}, w.shape = {w.shape}, q.shape = {q.shape}")
    
def main():
    
    test()
    
if __name__ == "__main__":
    
    main()

运行结果打印输出如下,


(3, 17, 21)
(3, 17, 21)
x.shape = (2, 3), y.shape = (3, 5), z.shape = (2, 5), w.shape = (7, 1, 5), q.shape = (7, 2, 5)
x.shape = (2, 3), y.shape = (3, 5), z.shape = (2, 5), w.shape = (7, 1, 5), q.shape = (7, 2, 5)

关于f32,从过代码可知,有两种来源etils.array_types.f32和自定义的类,


class ArrayType:
    
    def __getitem__(self, idx):
        
        return Any
    
f32 = ArrayType()

实际上,在自定义类里,f32是定义的能够接受和返回任何数据类型的自定义类。此时这样被自定义的类可以和正常的数组一样被打印,并提供了一个对应的shape大小。

结论

本章探讨了jaxpr解释器,从组合函数的转换、追踪器,以及自定义可被jaxpr追踪的函数,较为底层。同时,也从工程实践角度通过命名维度来改善深度学习里对矩阵的管理。

内容较多,量力而行。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,098评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,213评论 2 380
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 149,960评论 0 336
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,519评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,512评论 5 364
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,533评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,914评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,574评论 0 256
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,804评论 1 296
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,563评论 2 319
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,644评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,350评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,933评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,908评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,146评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,847评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,361评论 2 342

推荐阅读更多精彩内容