Notes on calling backward() in a loop
Many people run into the following error during training:
trying to backward through the graph a second time (or directly access saved tensors after they have already been freed)
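The simplest way to reproduce this message is to call backward() twice on the same graph. A minimal standalone sketch, just to show the mechanism:

import torch

x = torch.randn(3, requires_grad=True)
y = (x * x).sum()   # the multiplication saves x for the backward pass
y.backward()        # the first backward frees those saved tensors (retain_graph=False by default)
y.backward()        # RuntimeError: Trying to backward through the graph a second time ...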
In training code it is usually triggered by the loss.backward() call inside the loop, once per epoch or iteration. When this happens, do not simply follow the suggestion printed in the error message and pass retain_graph=True: in a loop this makes memory grow on every iteration and eventually causes an out-of-memory error. Instead, look for the real problem in your own code and detach() the offending tensor. See: python - RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed) - Stack Overflow
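To see concretely why detach() is the right fix, here is a minimal sketch of the failing pattern (a toy nn.Linear stands in for the real model; all names here are illustrative): a tensor computed with gradient tracking before the loop is reused in the loss on every iteration, so the first backward() frees its part of the graph and the second iteration crashes.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
target_image = torch.randn(1, 4)
target_feature = model(target_image)          # carries a grad_fn from the pre-loop graph

w = torch.zeros(1, 4, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

for step in range(3):
    loss = ((w - target_feature) ** 2).sum()  # pulls the pre-loop graph into every iteration's loss
    optimizer.zero_grad()
    loss.backward()   # step 0 frees the saved tensors of model(target_image)'s graph;
                      # step 1 raises "Trying to backward through the graph a second time"
    optimizer.step()

# Fix: target_feature = target_feature.detach(), or compute it under torch.no_grad(),
# so each iteration's graph stops at target_feature instead of reaching back into the model.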
My problem was a little different from the one in that Stack Overflow question, and simpler to understand, but the root cause is the same. I hit the error while running a CW attack; the offending code looked roughly like this:
target_feature = model(target_image)   # computed once, outside the attack loop

def forward(self, images, labels):
    images = images.clone().detach().to(self.device)
    w = images.clone().detach()
    w.requires_grad = True
    MSELoss = nn.MSELoss(reduction='none')
    Flatten = nn.Flatten()
    optimizer = optim.Adam([w], lr=self.lr)
    for step in range(self.steps):
        adv_img = self.f1(w)
        adv_img_feature = self.model(F.interpolate(adv_img * 2 - 1, (112, 112), mode='bilinear'))
        current_L2_2 = MSELoss(Flatten(adv_img_feature), Flatten(target_feature)).sum(dim=1)
        L2_loss_2 = current_L2_2.sum()
        optimizer.zero_grad()
        L2_loss_2.backward()   # fails on the second iteration
        optimizer.step()
Because target_feature is produced by the very first line, it still belongs to the previous computation graph: it carries a grad_fn (it is not a leaf tensor). The first backward() inside the loop walks through that graph and frees its saved tensors, so the second iteration fails. The fix, as in the official torchattacks code, is to detach() every tensor that enters the loss but is computed outside the loop:
target_feature = target_feature.clone().detach().to(self.device)
Alternatively, handle it at the point where target_feature is created, so that it never ends up in a computation graph in the first place:
with torch.no_grad():
    target_feature = model(target_image)
