DDP额外进程显存占用
DDP额外进程显存占用
在我们使用DDP做并行训练时,时常会碰到0号卡有额外的进程显存占用,常规的问题是在读取预训练模型时在进程0反复读取,这种问题的解决方案可以通过将预训练权重读取至CPU或者在读取权重时设置map_location,例如:
torch.jit.load('xxx.pt', map_location=torch.device(f'cuda:{rank}'))
这里的rank就是你的GPU号。
但是有时候这种方式可能并不能解决问题,此时可以尝试将find_unused_parameters设置为False,即
model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank], find_unused_parameters=False)
本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!
