PyTorch DDP Distributed Training Example

The relevant explanations are included as comments in the code block below.

Code Example

'''
File name: DDP.py
Launch command:
    torch version < 1.12.0:  python -m torch.distributed.launch --nproc_per_node=2 DDP.py
    torch version >= 1.12.0: torchrun --nproc_per_node=2 DDP.py
'''
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100
import torch.nn.functional as F
from torch import distributed
from torch.utils.data import DataLoader
from torchvision import models

## Initialize the DDP process group
try:
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    distributed.init_process_group("nccl")
except KeyError:
    rank = 0
    local_rank = 0
    world_size = 1
    distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12584",
        rank=rank,
        world_size=world_size,
    )


def seed_all(seed):
    ## Fix the random seeds globally so every process starts from the same state
    if not seed:
        seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def build_dataloader():
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_transform = transforms.Compose([transforms.ToTensor()])
    trainset = CIFAR100(root='your data root', train=True, download=True, transform=train_transform)
    valset = CIFAR100(root='your data root', train=False, download=True, transform=val_transform)
    ## create samplers
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    val_sampler = torch.utils.data.distributed.DistributedSampler(valset)
    ## batch_size here is the per-process batch size; the total batch size is
    ## this value multiplied by the number of processes (world_size)
    trainloader = DataLoader(trainset, batch_size=16, num_workers=2, sampler=train_sampler,
                             shuffle=False, pin_memory=True, drop_last=True)
    valloader = DataLoader(valset, batch_size=16, num_workers=2, sampler=val_sampler,
                           shuffle=False, pin_memory=True, drop_last=True)
    return trainloader, valloader


def metric(logit, truth):
    prob = F.softmax(logit, 1)
    _, top = prob.topk(1, dim=1, largest=True, sorted=True)
    correct = top.eq(truth.view(-1, 1).expand_as(top))
    correct = correct.data.cpu().numpy()
    correct = np.mean(correct)
    return correct


def main():
    ## Control randomness at the global level
    seed_all(42)
    ## set device
    torch.cuda.set_device(local_rank)
    ## build dataloaders
    trainloader, valloader = build_dataloader()
    ## build model
    model = models.resnet101(pretrained=False, num_classes=100).to(local_rank)
    ## load model
    ckpt_path = 'your model dir'
    if rank == 0 and ckpt_path is not None:
        model.load_state_dict(torch.load(ckpt_path, map_location=torch.device("cuda", local_rank)))
    ## use SyncBatchNorm
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(local_rank)
    ## build DDP model
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        ## When True, all parameters that were not used in the forward pass are
        ## marked as ready ahead of time once the forward finishes.
        ## Defaults to False, because enabling it slows training down.
        find_unused_parameters=True,
    )
    ## get optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    ## build loss function
    loss_func = nn.CrossEntropyLoss().to(local_rank)
    ## train model
    num_epochs = 100
    model.train()
    for epoch in range(num_epochs):
        ## Set the sampler's epoch so that each epoch uses a different shuffle
        trainloader.sampler.set_epoch(epoch)
        for data, label in trainloader:
            data, label = data.to(local_rank), label.to(local_rank)
            optimizer.zero_grad()
            prediction = model(data)
            loss = loss_func(prediction, label)
            ## synchronize processes
            distributed.barrier()
            '''
            There is no need to call distributed.all_reduce to sum and average the loss:
            DDP automatically averages the gradients across processes during backward.
            See the official docs: https://pytorch.org/docs/stable/notes/ddp.html
            '''
            loss.backward()
            ## Inspect the parameter gradients: printing them on every process
            ## verifies that the gradients are identical across processes
            for name, param in model.named_parameters():
                print(f'name = {name}, grad_value = {param.grad}')
            optimizer.step()
        ## What gets saved is model.module (the underlying model, not the DDP wrapper)
        if rank == 0:
            torch.save(model.module.state_dict(), "%d.ckpt" % epoch)
        ## eval
        if (epoch + 1) % 5 == 0:
            total_acc = 0
            for data, label in valloader:
                data, label = data.to(local_rank), label.to(local_rank)
                prediction = model(data)
                ## Gather the predictions from all processes
                _gather_prediction = [torch.zeros_like(prediction).cuda()
                                      for _ in range(world_size)]
                _gather_label = [torch.zeros_like(label).cuda()
                                 for _ in range(world_size)]
                distributed.all_gather(_gather_prediction, prediction)
                distributed.all_gather(_gather_label, label)
                prediction = torch.cat(_gather_prediction)
                label = torch.cat(_gather_label)
                accuracy = metric(prediction, label)
                total_acc += accuracy
            avg_acc = total_acc / len(valloader)
            print(avg_acc)
    ## destroy
    distributed.destroy_process_group()


if __name__ == "__main__":
    main()
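As the comment in build_dataloader notes, batch_size=16 is the per-process batch size: with --nproc_per_node=2 the effective global batch size is 16 * world_size = 32, and DistributedSampler additionally shards the dataset so each process only iterates over its own portion. A minimal sanity check, assuming the trainloader, rank, and world_size defined in the script above:

## Per-process batch size times the number of processes gives the global batch size.
per_process_bs = trainloader.batch_size        # 16 in this script
global_bs = per_process_bs * world_size        # 32 when launched with --nproc_per_node=2
## DistributedSampler shards the dataset, so each process sees roughly
## len(trainset) / world_size samples per epoch.
print(f'rank {rank}: per-process batch = {per_process_bs}, '
      f'global batch = {global_bs}, iters per epoch = {len(trainloader)}')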
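The docstring inside the training loop points out that DDP already averages gradients across processes, so no explicit reduction of the loss is needed for optimization. If you also want to report the loss averaged over all processes rather than each rank's local value, an explicit all_reduce can be used purely for logging. A minimal sketch, assuming the loss, epoch, rank, and world_size from the script above; the log_loss variable is introduced only for this example:

## For logging only: average the per-process loss across all ranks.
## This does not affect training; DDP has already averaged the gradients.
log_loss = loss.detach().clone()
distributed.all_reduce(log_loss, op=distributed.ReduceOp.SUM)
log_loss /= world_size
if rank == 0:
    print(f'epoch {epoch}: global mean loss = {log_loss.item():.4f}')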
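One detail about the evaluation block: distributed.all_gather requires the gathered tensors to have the same shape on every rank, which this script guarantees by setting drop_last=True in the validation loader. The original loop also runs in training mode with gradient tracking enabled; a common variation, sketched below and not part of the original script, switches to eval mode and disables gradients during validation:

## Sketch of the same gather-based validation pass, run in eval mode without gradient tracking.
model.eval()
total_acc = 0
with torch.no_grad():
    for data, label in valloader:
        data, label = data.to(local_rank), label.to(local_rank)
        prediction = model(data)
        gathered_pred = [torch.zeros_like(prediction) for _ in range(world_size)]
        gathered_label = [torch.zeros_like(label) for _ in range(world_size)]
        distributed.all_gather(gathered_pred, prediction)
        distributed.all_gather(gathered_label, label)
        total_acc += metric(torch.cat(gathered_pred), torch.cat(gathered_label))
print(total_acc / len(valloader))
model.train()  ## switch back before the next training epoch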

