Python从零实现计算图和自动求导

计算图是现代深度学习框架如 Tensorflow、PyTorch 等的核心概念,其中涉及的所有计算几乎都依赖于计算图提供的自动求导功能,因此研究计算图对深入理解反向传播等深度学习的底层算法大有帮助。

手工求导

求导在数学上非常容易实现,例如以下函数:

f(x) = \sin(e^{x^2})

我们能够轻易地求得其导函数为:

f'(x) = \cos(e^{x^2}) \cdot e^{x^2} \cdot (2x)

那么能否通过编程语言实现该函数及其导函数?答案是可以,而且非常容易,只需要把式子逐项翻译即可:

import numpy as np

def f(x):
    return np.sin(np.exp(np.power(x, 2)))

def f_prime(x):
    t = np.exp(np.power(x, 2))
    return np.cos(t) * t * (2*x)

到目前为止,求导对于编程语言来说似乎没什么难的,但是不要忘记,我们这里研究的函数是一个具体的实例。在实际应用中,我们用到的函数将会非常丰富,它们的组合方式更是千变万化,光是上面这个简单的例子就会有无数种变形!例如:

\begin{gathered} f_2(x) = \cos(e^{x^2}) \\\\ f_3(x) = \sin(2^{x^2}) \\\\ f_4(x) = \sin(e^{x^3}) \\\\ \cdots \end{gathered}

更不要提各式各样的其他复杂函数:

\begin{gathered} g(x) = -y \ln(\frac{1}{1 + e^{-wx}}) - (1-y)\ln(1 - \frac{1}{1 + e^{-wx}}) \\\\ h(x) = (w_n\cdot(\cdots(\mathrm{relu}(w_2\cdot(\mathrm{relu}(w_1\cdot x))))) - y) ^ 2 \\\\ \cdots \end{gathered}

如果坚持手工求导的话,我们不仅需要无数次地推导公式,而且对于某些复杂的函数,求导公式并不简单,显然不可能完成。

链式法则

因此,我们需要一套抽象的求导规则,使得无论函数的具体形式如何,都能自动对其求导。也就是实现如下抽象函数的求导法则:

f(x) = g(h(k(\cdots(x)))

尽管这个问题听上去要比具体函数的求导困难得多,但它依然有章可循。回想我们求导的一般过程,不过是运用了以下两点技术而已:

  1. 基本函数的求导法则。包括三角函数、指数函数、幂函数等。
  2. 链式法则

链式法则使得我们可以对复合函数进行求导。针对上面的例子,为了显式地调用链式法则,我们可以引入如下中间变量:

\begin{aligned} u &= x^2 \\\\ v &= e^u \\\\ w &= sin(v) \end{aligned}

使用链式法则描述的求导过程如下:

\frac{\mathrm{d}y}{\mathrm{d}x} = \frac{\mathrm{d}y}{\mathrm{d}w} \cdot \frac{\mathrm{d}w}{\mathrm{d}v} \cdot \frac{\mathrm{d}v}{\mathrm{d}u} \cdot \frac{\mathrm{d}u}{\mathrm{d}x}

有了链式法则,我们就能够“机械”地搬运任意基本函数的导函数,从而对非常复杂的复合函数求导。

计算图

由上述分析可知,一旦我们实现了(1)基本函数的求导法则以及(2)链式法则,就能够让程序模仿我们手工求导的过程,从而做到“以不变应万变”。计算图非常适合用来描述这两个法则。

计算图在数据结构上属于有向图(Directed Graph),图的每个节点对应一个“基本函数”,而节点之间的有向边则可用于描述链式法则。

上面的例子使用计算图描述如下:

x \to (\cdot)^2 \to e^{(\cdot)} \to \sin(\cdot) \to y

计算图能够非常清晰地展现数据的流动过程。从输入 x 开始,中间依次经过平方、自然指数、正弦函数三个基本运算依次作用,最终得到输出 y

注意:这个例子并非典型的计算图,因为其中所涉及的运算都是一元运算,导致图结构是线性的,没有分支,更像是链表

这种线性结构的计算图无法描述加法、乘法等多元运算,例如 x + sin(x)x\sin(x)。但它的好处是非常简单,便于理解和实现,因此我们将继续使用这种线性结构完成演示。

计算图的每一个节点都包含一个基本函数,并且其导函数是已知的。节点在进行一次“前向计算”时,除了要根据输入值计算输出值之外,还要调用导函数计算梯度值,并缓存在节点中。最终,我们将所有节点的梯度值相乘(链式法则)即可得到整个计算流程的总梯度。

代码实现

在实现代码之前,我们首先要明确接口的设计,即假想用户将会如何调用计算图,这是一个非常重要的工程原则。

我们期望用户以如下方式调用计算图:

>>> import compute_graph as cg
>>> inp = cg.Input()
>>> out = cg.power(inp, 2)
>>> out = cg.exp(out)
>>> out = cg.sin(out)
>>> graph = ComputeGraph(inp, out)
>>>
>>> import numpy as np
>>> x = np.linspace(0, 1, 5)
>>> graph.forward(x)
array([0.84147098, 0.87454388, 0.95916224, 0.98307241, 0.41078129])
>>> graph.grad
array([ 0.        ,  0.25811137,  0.36319491, -0.48233501, -4.95669947])

这种 API 风格与 Keras 比较接近,符合一般用户的使用习惯。

下面我们开始着手实施我们的想法。我们计划为计算图、计算图节点分别设计一个类。

图节点类

首先定义所有图节点的基类,代表节点的通用结构。

class Node(object):
    """Node of compute graph"""
    def __init__(self, x, *args, **kw):
        # 我们已经假定计算图为线性结构,因此只需要连接图中的前一个节点 x。
        # 如果考虑一般的图结构,则需要更为复杂的设计,此处不予讨论。
        if not isinstance(x, Node):
            raise ValueError('the input should be a compute graph Node object')
        x.next = self
        self.next = None

        # 使用一个变量缓存计算的梯度值
        self.grad = None

        # 其余非通用参数单独初始化
        self.init(*args, **kw)

    def init(self, *args, **kw):
        pass

    def fun(self, x):
        """节点中保存的基本函数"""
        raise NotImplementedError()

    def fun_grad(self, x, out):
        """基本函数的导函数,用于计算梯度。
        x, out 分别是 self.fun 的输入和输出。
        理论上只需要 x 即可计算出梯度,但很多函数的导函数会引用自身,例如指数函数。
        引入 out 作为参数可避免计算梯度时重复计算自身。
        """
        raise NotImplementedError()

    def forward(self, x):
        """计算输出,同时缓存梯度"""
        out = self.fun(x)
        self.grad = self.fun_grad(x, out)
        return out

    def __str__(self):
        return self.__class__.__name__

    def __repr__(self):
        return '<"{}" node of compute graph>'.format(str(self))

一般的计算节点只需要继承节点基类,并实现 funfun_grad 两个方法即可。

正弦函数节点

class sin(Node):
    """Node of sin function"""
    def fun(self, x):
        return np.sin(x)

    def fun_grad(self, x, out):
        return np.cos(x)

指数函数节点

class exp(Node):
    """Node of exp function"""
    def fun(self, x):
        return np.exp(x)

    def fun_grad(self, x, out):
        return out

幂函数节点

注意,幂函数需要在初始化时传入额外的参数,即幂指数。

class power(Node):
    """Node of power function"""
    def init(self, p):
        # 从参数中接收幂指数
        self.p = p

    def fun(self, x):
        return np.power(x, self.p)

    def fun_grad(self, x, out):
        return self.p * np.power(x, self.p - 1)

    def __str__(self):
        return '{}(., {})'.format(self.__class__.__name__, self.p)

输入节点

与普通节点不同,输入节点没有前驱节点,也不需要对数据进行加工和求导,因此需要单独进行定义。

class Input(Node):
    """Input Node"""
    def __init__(self):
        self.next = None

    def fun(self, x):
        return x

    def fun_grad(self, x, out):
        return 1

计算图类

我们已经把主要的计算过程定义在了图节点类中,因此计算图类的任务就非常轻松了,只需要整合图节点的计算结果即可。

class ComputeGraph(object):
    """Compute Graph"""
    def __init__(self, inp, out):
        self.head = inp
        self.tail = out
        self.grad = None

    def forward(self, x):
        if self.head is None:
            raise ValueError('the graph is empty')
        out = x
        grad = 1.0
        node = self.head
        while node:
            out = node.forward(out)
            grad *= node.grad
            node = node.next
        self.grad = grad
        return out

    def __str__(self):
        node = self.head
        desc = []
        while node:
            desc.append(str(node))
            node = node.next
        return ' --> '.join(desc)

到此为止,我们的代码已经全部完成了,是不是简单地出乎意料?

验证代码

在进行接口设计时,我们给出了一段样板代码,现在我们可以用它来验证我们的程序。

>>> import compute_graph as cg
>>> inp = cg.Input()
>>> out = cg.power(inp, 2)
>>> out
<"power(., 2)" node of compute graph>
>>> out = cg.exp(out)
>>> out = cg.sin(out)
>>> graph = ComputeGraph(inp, out)
>>> print(graph)
Input --> power(., 2) --> exp --> sin
>>>
>>> import numpy as np
>>> x = np.linspace(0, 1, 5)
>>> graph.forward(x)
array([0.84147098, 0.87454388, 0.95916224, 0.98307241, 0.41078129])
>>> graph.grad
array([ 0.        ,  0.25811137,  0.36319491, -0.48233501, -4.95669947])

代码无误,且输出完全符合预期。

但我们还未考察计算结果是否正确无误,毕竟这才是最重要的。我们可以通过之前手动推导的公式对计算结果加以验证,函数的定义如下:

def f(x):
    return np.sin(np.exp(np.power(x, 2)))

def f_prime(x):
    t = np.exp(np.power(x, 2))
    return 2 * x * np.cos(t) * t

我们进行如下验证:

>>> f(x)
array([0.84147098, 0.87454388, 0.95916224, 0.98307241, 0.41078129])
>>> f_prime(x)
array([ 0.        ,  0.25811137,  0.36319491, -0.48233501, -4.95669947])
>>> np.all(f(x) == graph.forward(x))
True
>>> np.all(f_prime(x) == graph.grad)
True

说明计算图的计算结果和梯度值均准确无误。

上述代码在一些细节问题上可能有所欠缺,但足以从宏观上理解计算图的实现原理。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,839评论 6 482
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,543评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 153,116评论 0 344
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,371评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,384评论 5 374
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,111评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,416评论 3 400
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,053评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,558评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,007评论 2 325
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,117评论 1 334
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,756评论 4 324
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,324评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,315评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,539评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,578评论 2 355
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,877评论 2 345