Pytorch中optimizer类初始化传入参数分析(分析源码)

今天在跟随沐神的课看见了以前没见过SGD参数传入方式(才学没多久,见识浅陋):

trainer = torch.optim.SGD([{'params': params_1x}, {'params': net.fc.parameters(), 'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001) 

传入了一个列表中包含了两个字典,作用是为了做到在最后一层中的学习率与其他层不一样。这是怎么做到的呢?

于是我看了看SGD类__init__函数源码:

    def __init__(self, params, lr=required, momentum=0, dampening=0,weight_decay=0, nesterov=False, *, maximize=False, foreach: Optional[bool] = None):if lr is not required and lr < 0.0:raise ValueError("Invalid learning rate: {}".format(lr))if momentum < 0.0:raise ValueError("Invalid momentum value: {}".format(momentum))if weight_decay < 0.0:raise ValueError("Invalid weight_decay value: {}".format(weight_decay))defaults = dict(lr=lr, momentum=momentum, dampening=dampening,weight_decay=weight_decay, nesterov=nesterov,maximize=maximize, foreach=foreach)if nesterov and (momentum <= 0 or dampening != 0):raise ValueError("Nesterov momentum requires a momentum and zero dampening")super(SGD, self).__init__(params, defaults)

我们可以看到在这个函数params只在最后的父类初始化函数中用到了,我继续查看SGD父类Optimizer类的__init__函数源码:

    def __init__(self, params, defaults):torch._C._log_api_usage_once("python.optimizer")self.defaults = defaultsself._hook_for_profile()if isinstance(params, torch.Tensor):raise TypeError("params argument given to the optimizer should be ""an iterable of Tensors or dicts, but got " +torch.typename(params))self.state = defaultdict(dict)self.param_groups = []param_groups = list(params)if len(param_groups) == 0:raise ValueError("optimizer got an empty parameter list")if not isinstance(param_groups[0], dict):param_groups = [{'params': param_groups}]for param_group in param_groups:self.add_param_group(param_group)self._warned_capturable_if_run_uncaptured = True

在第14行的if语句中它判断了传入的param_groups是否为一个字典,如果不为字典就将它变成键为"params",值为param_groups的字典。

并且在接下来的for loop中对param_groups进行迭代,然后调用函数add_param_group将每一个param_groups中的pram_group传入其中。所以先要弄清楚这个parameter怎么用还得查看add_param_group函数。这个函数有点长,涉及到本问题的核心代码就是:

        for name, default in self.defaults.items():if default is required and name not in param_group:raise ValueError("parameter group didn't specify a value of required optimization parameter " +name)else:param_group.setdefault(name, default)

这是一个遍历self.defaults.items()的一个for loop,从之前的代码可知defaults里面放的也是一些字典,存了一些关于优化器的超参数。for loop中的if语句是用于判断default中是否有与param_group相同的键,如果没有的话就将此时这个default中的item添加到param_group中去。

回看我们最初的问题,SGD初始化时传入的那个含有两个字典的列表,只要列表里的字典中含有与超参数名字相同的键,不就能够做到改变这层的超参数与其他的曾不一样了吗?这里第二个字典中有{'lr':learning_rate * 10}这个键值对就是这个作用!!

总结:其实这是一个比较简单的问题,但是在这个问题的解决过程当中,通过不断追踪参数在源码中的用法,我很好地锻炼了我的解决问题的能力,也增加了我对pytorch库的一些了解。

注:我是刚学深度学习没多久的菜鸟,若有大佬发现不足之处,还请多多指正!


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部