PyTorch leaf node

起因

今天被PyTroch tensor的requires_grad搞了一把。具体情况是创建一个tensor和在后续的使用过程中,对requires_grad的取值会影响到python存储的变量是否为leaf node。说起来很抽象,直接上代码。

(有关leaf node,请参考我的另外一篇博客,https://blog.csdn.net/huyaoyu/article/details/81059315)

测试代码

以下代码测试在PyTorch 1.3.1上。

import torchif __name__ == "__main__":a = torch.tensor([1.0], requires_grad=False)print("a.is_leaf = {}. ".format( a.is_leaf ))b = torch.tensor([1.0], requires_grad=True)print("b.is_leaf = {}. ".format( b.is_leaf ))c = torch.tensor([1.0], requires_grad=False).clone()print("c.is_leaf = {}. ".format( c.is_leaf ))d = torch.tensor([1.0], requires_grad=False).detach()print("d.is_leaf = {}. ".format( d.is_leaf ))e = torch.tensor([1.0], requires_grad=False).cuda()print("e.is_leaf = {}. ".format( e.is_leaf ))f = torch.tensor([1.0], requires_grad=True).clone()print("f.is_leaf = {}. ".format( f.is_leaf ))g = torch.tensor([1.0], requires_grad=True).detach()print("g.is_leaf = {}. ".format( g.is_leaf ))h = torch.tensor([1.0], requires_grad=True).cuda()print("h.is_leaf = {}. ".format( h.is_leaf ))i = torch.tensor([1.0], requires_grad=True).clone().detach()print("i.is_leaf = {}. ".format( i.is_leaf ))j = torch.tensor([1.0], requires_grad=True).detach().clone()print("j.is_leaf = {}. ".format( j.is_leaf ))k = torch.tensor([1.0], requires_grad=True).cuda().detach()print("k.is_leaf = {}. ".format( k.is_leaf ))

各位猜一下输出都是什么?

输出是这样的(PyTorch 1.3.1):

a.is_leaf = True. 
b.is_leaf = True. 
c.is_leaf = True. 
d.is_leaf = True. 
e.is_leaf = True. 
f.is_leaf = False. 
g.is_leaf = True. 
h.is_leaf = False. 
i.is_leaf = True. 
j.is_leaf = True. 
k.is_leaf = True. 

其中fh的输出显示对应的python变量不再是leaf node了。其原因在于torch.tensor([1.0], requires_grad=True)将返回一个设置了requires_grad = True的tensor,这个tensor的所有后续的.clone().cuda()操作都是“可微”的,也就是说.clone().cuda()操作都将返回一个非leaf node。于是如果我们想确保得到的python变量是一个leaf node,最保险的做法是在使用类似于torch的tensor()zeros()函数时,不指定requires_grad,此时可以对得到的tensor随意操作.clone().cuda()并赋值给其他python变量。在得到最终python变量后,通过显式对requires_grad成员变量赋值从而设自动梯度运算请求。

参考文献

https://discuss.pytorch.org/t/how-to-define-a-leaf-tensor-in-pytorch-0-4-1/28461/5


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部