PyTorch Optimizer Source Code Analysis

Optimizers are imported through torch.optim. Every concrete optimizer inherits from the Optimizer class, so the analysis starts with Optimizer and then moves on to SGD and Adam.
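For orientation, here is a minimal usage sketch (the toy nn.Linear model and the hyperparameter values are placeholders chosen for this example, not taken from the article):

import torch
import torch.nn as nn
import torch.optim as optim

# A toy model; its parameters() iterator is what gets handed to the optimizer.
model = nn.Linear(10, 2)

# Both optimizers discussed below are constructed the same way: an iterable of
# Tensors (or dicts) plus optimizer-specific hyperparameters.
sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
adam = optim.Adam(model.parameters(), lr=1e-3)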

class Optimizer:

def __init__(self, params, defaults):
    torch._C._log_api_usage_once("python.optimizer")
    self.defaults = defaults

    if isinstance(params, torch.Tensor):
        raise TypeError("params argument given to the optimizer should be "
                        "an iterable of Tensors or dicts, but got " +
                        torch.typename(params))

    self.state = defaultdict(dict)
    self.param_groups = []

    param_groups = list(params)
    if len(param_groups) == 0:
        raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
        param_groups = [{'params': param_groups}]

    for param_group in param_groups:
        self.add_param_group(param_group)

params is the set of model parameters passed to the optimizer, i.e. model.parameters(); defaults collects the remaining hyperparameters. In SGD, for example:

defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                weight_decay=weight_decay, nesterov=nesterov)
super(SGD, self).__init__(params, defaults)

The super() call then invokes the parent Optimizer's __init__. Inside Optimizer, the model parameters params together with the optimizer hyperparameters are wrapped into parameter groups and registered through self.add_param_group().
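To make the param_groups structure concrete, here is a small hand-written example (not from the article) that passes dicts instead of a flat parameter list, which is exactly the case the list(params) / add_param_group logic above handles:

import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 2))

# Passing dicts creates one param_group per dict; hyperparameters missing
# from a dict fall back to the values stored in `defaults`.
opt = optim.SGD(
    [
        {'params': model[0].parameters()},             # uses the default lr
        {'params': model[1].parameters(), 'lr': 0.1},  # overrides lr for this group
    ],
    lr=0.01, momentum=0.9,
)

print(len(opt.param_groups))      # 2
print(opt.param_groups[1]['lr'])  # 0.1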

opt.zero_grad: for every parameter p it calls p.grad.detach_() and p.grad.zero_() to clear the accumulated gradients.

def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
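In practice zero_grad() is called once per iteration before the backward pass. A minimal training step, with a made-up model, loss, and random data chosen only for illustration:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
opt = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

x, y = torch.randn(4, 10), torch.randn(4, 1)

opt.zero_grad()             # clear gradients left over from the previous step
loss = loss_fn(model(x), y)
loss.backward()             # populate p.grad for every parameter
opt.step()                  # apply the update rule (plain SGD here)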

Saving and restoring the Optimizer state:

def __getstate__(self):
    return {
        'defaults': self.defaults,
        'state': self.state,
        'param_groups': self.param_groups,
    }

def __setstate__(self, state):
    self.__dict__.update(state)

def state_dict(self):
    def pack_group(group):
        packed = {k: v for k, v in group.items() if k != 'params'}
        packed['params'] = [id(p) for p in group['params']]
        return packed
    param_groups = [pack_group(g) for g in self.param_groups]
    # Remap state to use ids as keys
    packed_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                    for k, v in self.state.items()}
    return {
        'state': packed_state,
        'param_groups': param_groups,
    }

def load_state_dict(self, state_dict):
    state_dict = deepcopy(state_dict)  # deep-copy so the caller's dict is not modified
    # ... (the remaining validation and id-to-parameter remapping is not shown here)
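A typical use of these methods is checkpointing. Below is a hedged sketch (the file name and dictionary keys are my own choices) that pairs the optimizer state with the model state so that training can resume where it left off:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
opt = optim.Adam(model.parameters(), lr=1e-3)

# Save: state_dict() replaces parameter Tensors with ids, so the checkpoint
# holds hyperparameters plus per-parameter state (e.g. Adam's moment buffers).
torch.save({'model': model.state_dict(), 'optimizer': opt.state_dict()},
           'checkpoint.pt')

# Restore: load_state_dict() maps the packed state back onto the current
# parameters of the rebuilt model and optimizer.
ckpt = torch.load('checkpoint.pt')
model.load_state_dict(ckpt['model'])
opt.load_state_dict(ckpt['optimizer'])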

