【架构分析】Tensorflow 自动微分源码分析

目录

概要

源码分析

Python 梯度API gradients

核心函数_GradientsHelper

反向传播梯度计算过程示意


概要

深度学习框架比如Tensorflow能够对深度学习领域起到巨大推动作用的一个重要Feature就是自动反向传播更新梯度,文章 链接 对几种不同类型的微分做了很详细的图文说明,值得一看。本文基于Google 官Tensorflow_r1.15 版本对求解梯度(微分)的核心函数代码做如下分析

源码分析

Python 梯度API gradients

@tf_export(v1=["gradients"])
def gradients(ys,xs,grad_ys=None,name="gradients",colocate_gradients_with_ops=False,gate_gradients=False,aggregation_method=None,stop_gradients=None,unconnected_gradients=UnconnectedGradients.NONE):
'''Args:ys: A `Tensor` or list of tensors to be differentiated.xs: A `Tensor` or list of tensors to be used for differentiation.grad_ys: Optional. A `Tensor` or list of tensors the same size as`ys` and holding the gradients computed for each y in `ys`....Returns:A list of `sum(dy/dx)` for each x in `xs`.
'''...with ops.get_default_graph()._mutation_lock():return gradients_util._GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,gate_gradients, aggregation_method, stop_gradients,unconnected_gradients)

tf.gradients 计算 tensor list ys 对 tensor list xs 的梯度,计算梯度的len和xs一样,计算结果是对于每个 x_{i} \epsilon x_{s}  求出 \sum \frac{\partial y{_{j}}}{\partial x_{i}}  (其中 y_{j} \epsilon y_{s}) ,换而言之即遍历xs,每取出1个xi,把ys中每个tensor对xi求梯度然后相加作为ys对xi的梯度结果,最后计算出ys对所有的xi的梯度,组成ys对xs的梯度结果

核心函数_GradientsHelper

计算梯度的核心函数_GradientsHelper 分析如下

def _GradientsHelper(ys,xs,grad_ys=None,name="gradients",colocate_gradients_with_ops=False,gate_gradients=False,aggregation_method=None,stop_gradients=None,unconnected_gradients=UnconnectedGradients.NONE,src_graph=None):"""Implementation of gradients()."""
...# The approach we take here is as follows: Create a list of all ops in the# subgraph between the ys and xs.  Visit these ops in reverse order of ids# to ensure that when we visit an op the gradients w.r.t its outputs have# been collected.  Then aggregate these gradients if needed, call the op's# gradient function, and add the generated gradients to the gradients for# its input.# Initialize the pending count for ops in the connected subgraph from ys# to the xs.to_ops = [t.op for t in ys]from_ops = [t.op for t in xs]stop_gradient_ops = [t.op for t in stop_gradients]reachable_to_ops, pending_count, loop_state = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)# Iterate over the collected ops.## grads: op => list of gradients received on each output endpoint of the# op.  The gradients for each endpoint are initially collected as a list.# When it is time to call the op's gradient function, for each endpoint we# aggregate the list of received gradients into a Add() Operation if there# is more than one.grads = {}# Add the initial gradients for the ys.for y, grad_y in zip(ys, grad_ys):_SetGrad(grads, y, grad_y)# Initialize queue with to_ops.queue = collections.deque()# Add the ops in 'to_ops' into the queue.to_ops_set = set()for op in to_ops:# 'ready' handles the case where one output gradient relies on# another output's gradient.ready = (pending_count[op] == 0)if ready and op not in to_ops_set and op in reachable_to_ops:to_ops_set.add(op)queue.append(op)
...while queue:# generate gradient subgraph for op.op = queue.popleft()
...if has_out_grads and (op not in stop_ops):try:grad_fn = ops.get_gradient_function(op)
...with ops.name_scope(op.name + "_grad"):# pylint: disable=protected-accesswith src_graph._original_op(op):# pylint: enable=protected-accessif grad_fn:# If grad_fn was found, do not use SymbolicGradient even for# functions.in_grads = _MaybeCompile(grad_scope, op, func_call,lambda: grad_fn(op, *out_grads))else:# For function call ops, we add a 'SymbolicGradient'# node to the graph to compute gradients.in_grads = _MaybeCompile(grad_scope, op, func_call,lambda: _SymGrad(op, out_grads))
..._LogOpGradients(op, out_grads, in_grads)
...# Update pending count for the inputs of op and enqueue ready ops._UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,xs_set)...return [_GetGrad(grads, x, unconnected_gradients) for x in xs]

要理解上面的代码,需要先理解其中的几个核心数据和函数的概念:

  • 队列queue、 字段grads、pending_count计数、OP在正向和反向中如何计算的概念说明如下

  • _PendingCount函数对ys和xs之间子图的每个OP的input tensor创建该input 对应OP的pending_count计算,换而言之即统计了ys和xs之间子图的每个OP在反向传播过程计算梯度的过程中输入的个数有几个
def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,xs_set):
...  
# Initialize pending count for between ops.pending_count = collections.defaultdict(int)for op in between_op_list:for x in _NonEagerInputs(op, xs_set):if x.op in between_ops:pending_count[x.op] += 1
  • pending_count的字面解释比较难以理解,用下图说明,即xs和ys之间的子图有4个OP,那么从最后一个OP4开始遍历,它的input tensor对应的OP 即OP2和OP3的pending_count[OP2]为1,pending_count[OP3]为1,而OP2和OP3的input tensor都对应OP1,所以pending_count[OP1]为2; 这样就可以知道,在反向创建计算梯度的过程中,OP2 OP3的梯度输入个数为1,OP1的梯度输入个数为2 (OP在正向传播的输出,即反向传播的输入)

  • 两个重要变量开始在代码逻辑中起作用
    • 一个队列queue,队列里存放计算图里所有出度为0的OP,其实就是该OP的所有输出相对ys的梯度已经都计算出来的OP,循环遍历该queue,取出OP来计算该OP的输入相对于ys的梯度
    • 一个字典grads,字典的键是操作符本身,值是该操作符每个输出端收到的梯度列表
    • queue首先放入的是ys的OP,比如上图的OP4,它一定是ready的(它的梯度输入比如是 d(ys)/d(ys) = 1 一定是已经计算出来了)
  • 通过找出OP对应的梯度计算函数 grad_fn,计算OP的输入相对ys的梯度,保存到grads
  • ​注意每个函数都使用了装饰器RegisterGradient包装,对有m个输入,n个输出的操作符,相应的梯度函数需要传入两个参数
    • 操作符本身
    • n个张量对象,代表OP每个输出相对ys的梯度
    • 返回m个张量对象,代表OP每个输入相对ys的梯度
    • 大部分操作符的梯度计算方式已经由框架给出,但是也可以自定义操作和对应的梯度计算函数
@ops.RegisterGradient("Log")def _LogGrad(op, grad):"""Returns grad * (1/x)."""x = op.inputs[0]with ops.control_dependencies([grad]):x = math_ops.conj(x)return grad * math_ops.reciprocal(x)

  • to_ops、from_ops、函数_PendingCount 构建从输入Xs到输出Ys构建反向传播计算子图的过程说明如下

 _PendingCount 将计算图中ys和xs之间的子图构建一个“可以到达”的op list reachable_to_ops,反向遍历这个list(反向传播),确保每个被遍历的OP,它的所有输出梯度已经计算出来了;注意:这里的梯度其实是指该OP的outputs 相对ys的梯度,如果ys是整个模型的Loss,那么即是d(Loss)/d(OP'outputs) 

  • _LogOpGradients(op, out_grads, in_grads) 可以通过python log 打印OP的输出相对ys的梯度以及经过grad_fn计算出来的OP的输入相对ys的梯度,使用方法可以参考我的另外一篇文章 Tensorflow Debug Tips 链接
  • _UpdatePendingAndEnqueueReady 函数更新OP的pending_count,如果pending_count[op]==0,即表示一个OP的所有输出相对ys的梯度已经都计算出来了,那么这个OP就可以加入queue中,去计算该OP的输入相对ys的梯度,即反向传播过程的计算

反向传播梯度计算过程示意


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部