Pytorch梯度检查 torch.autograd.gradcheck
在编写好自己的 autograd function 后,可以利用
gradcheck中提供的gradcheck和gradgradcheck接口,对数值算得的梯度和求导算得的梯度进行比较,以检查backward是否编写正确。 即用它可以check自己写的反向传播函数是否正确这个函数可以自己计算通过数值法求得的梯度,然后和我们写的backward的结果比较
在下面的例子中,我们自己实现了
Sigmoid函数,并利用gradcheck来检查backward的编写是否正确。import torch from torch.autograd import Function import torchclass Sigmoid(Function):@staticmethoddef forward(ctx, x): output = 1 / (1 + torch.exp(-x))ctx.save_for_backward(output)return output@staticmethoddef backward(ctx, grad_output): output, = ctx.saved_tensorsgrad_x = output * (1 - output) * grad_outputreturn grad_xtest_input = torch.randn(4, requires_grad=True) # tensor([-0.4646, -0.4403, 1.2525, -0.5953], requires_grad=True) print(torch.autograd.gradcheck(Sigmoid.apply, (test_input,), eps=1e-3)) # pass print(torch.autograd.gradcheck(torch.sigmoid, (test_input,), eps=1e-3)) # pass返回True就代表是正确的
参考
PyTorch 源码解读之 torch.autograd:梯度计算详解 - 知乎
本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

