PyTorch leaf node
Background
Today I got tripped up by the requires_grad attribute of PyTorch tensors. Specifically, the value of requires_grad when a tensor is created, and during its later use, determines whether the Python variable storing the result is a leaf node. This sounds abstract, so let's go straight to the code.
(For background on leaf nodes, see my other blog post: https://blog.csdn.net/huyaoyu/article/details/81059315)
Test code
The following code was tested on PyTorch 1.3.1.
import torch

if __name__ == "__main__":
    a = torch.tensor([1.0], requires_grad=False)
    print("a.is_leaf = {}. ".format(a.is_leaf))

    b = torch.tensor([1.0], requires_grad=True)
    print("b.is_leaf = {}. ".format(b.is_leaf))

    c = torch.tensor([1.0], requires_grad=False).clone()
    print("c.is_leaf = {}. ".format(c.is_leaf))

    d = torch.tensor([1.0], requires_grad=False).detach()
    print("d.is_leaf = {}. ".format(d.is_leaf))

    e = torch.tensor([1.0], requires_grad=False).cuda()
    print("e.is_leaf = {}. ".format(e.is_leaf))

    f = torch.tensor([1.0], requires_grad=True).clone()
    print("f.is_leaf = {}. ".format(f.is_leaf))

    g = torch.tensor([1.0], requires_grad=True).detach()
    print("g.is_leaf = {}. ".format(g.is_leaf))

    h = torch.tensor([1.0], requires_grad=True).cuda()
    print("h.is_leaf = {}. ".format(h.is_leaf))

    i = torch.tensor([1.0], requires_grad=True).clone().detach()
    print("i.is_leaf = {}. ".format(i.is_leaf))

    j = torch.tensor([1.0], requires_grad=True).detach().clone()
    print("j.is_leaf = {}. ".format(j.is_leaf))

    k = torch.tensor([1.0], requires_grad=True).cuda().detach()
    print("k.is_leaf = {}. ".format(k.is_leaf))
Can you guess what each line prints?
Here is the output (PyTorch 1.3.1):
a.is_leaf = True.
b.is_leaf = True.
c.is_leaf = True.
d.is_leaf = True.
e.is_leaf = True.
f.is_leaf = False.
g.is_leaf = True.
h.is_leaf = False.
i.is_leaf = True.
j.is_leaf = True.
k.is_leaf = True.
The outputs for f and h show that those Python variables are no longer leaf nodes. The reason is that torch.tensor([1.0], requires_grad=True) returns a tensor with requires_grad = True, and every subsequent .clone() or .cuda() applied to it is a differentiable operation tracked by autograd, so each returns a non-leaf tensor. (By convention, any tensor with requires_grad = False is a leaf, which explains a, c, d, and e; and .detach() always returns a leaf, which is why g, i, j, and k end up as leaves again.) Therefore, if we want to guarantee that a Python variable is a leaf node, the safest approach is to not specify requires_grad when calling functions like torch.tensor() or torch.zeros(). The resulting tensor can then be freely passed through .clone() and .cuda() and assigned to other Python variables. Once we have the final variable, we explicitly set its requires_grad attribute to request automatic gradient computation, as in the sketch below.
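A minimal sketch of this pattern (the shape, the variable name w, and the .cuda() call are just illustrative assumptions; on a CPU-only machine, substitute .clone() or similar):

import torch

# Create the tensor without specifying requires_grad (it defaults to False).
w = torch.zeros((3, 3))

# Moving or copying a requires_grad=False tensor still yields a leaf node.
w = w.cuda()
print(w.is_leaf)        # True

# Only after all transformations, request gradient tracking explicitly.
w.requires_grad = True  # equivalently: w.requires_grad_(True)
print(w.is_leaf)        # still True; autograd will now accumulate w.grad

Note that assigning to requires_grad like this only works on leaf tensors in the first place, which is another reason to finish all .clone()/.cuda() calls before turning it on.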
References
https://discuss.pytorch.org/t/how-to-define-a-leaf-tensor-in-pytorch-0-4-1/28461/5
