Upgrading SRGAN with ConvNeXt
I have recently been studying super-resolution reconstruction, and I made some structural upgrades to the SRGAN network using ConvNeXt. The results are quite good.
Code
With trained weights and a small dataset (414 MB):
Link: https://pan.baidu.com/s/1KYyHyE5BpCTjKNuwBh52Wg?pwd=dnmd
Extraction code: dnmd
Code only (5.8 MB):
Link: https://pan.baidu.com/s/18kqXsh6NnKKMNg3WHdhRUA?pwd=89ay
Extraction code: 89ay
Video
Full SRGAN code, optimized with ConvNeXt - bilibili
Below is a comparison between my ConvNeXt-based SRGANext and the original SRGAN.

Below are SRGAN's results at training epoch 100: the low-resolution image, the high-resolution image generated from it, and the ground-truth high-resolution image.



Below are SRGANext's results at training epoch 100: the low-resolution image, the high-resolution image generated from it, and the ground-truth high-resolution image.



This is the structure of SRGAN:
This is the generator structure of my SRGANext:

This is the discriminator structure of my SRGANext:

This is the block used by both the generator and the discriminator:
Below is the SRGAN code:
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    '''Two convolutions that change neither the spatial size nor the number
    of channels, so the residual branch needs no 1x1 convolution to match
    dimensions.'''
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, 0.8),  # 0.8 is passed positionally as eps
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, 0.8),
        )

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    def __init__(self, scale_factor=2, num_residual_blocks=16):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, stride=1, padding=4,
                               padding_mode='reflect', bias=True)
        self.prelu = nn.PReLU()
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(num_residual_blocks)])
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(64)
        # Two PixelShuffle stages of scale_factor each, i.e. 4x upscaling
        # overall with the default scale_factor=2.
        self.upsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1,
                      padding_mode='reflect', bias=True),
            nn.PixelShuffle(scale_factor),
            nn.PReLU(),
            nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1,
                      padding_mode='reflect', bias=True),
            nn.PixelShuffle(scale_factor),
            nn.PReLU(),
            nn.Conv2d(64, 3, kernel_size=9, stride=1, padding=4, bias=True)
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.prelu(out)
        residual = out
        out = self.residual_blocks(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual  # long skip connection around the residual trunk
        out = self.upsample(out)
        return out


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # VGG-style trunk: channels double while stride-2 convs halve resolution.
        self.layer = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1, bias=True),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1024),
            nn.Dropout(),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        out = self.layer(x)
        return out
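A quick shape check that can be appended at the bottom of this file (my own addition, not part of the original script): with the default scale_factor=2 applied across two PixelShuffle stages, the generator upscales 4x, and the discriminator returns one probability per image.

if __name__ == '__main__':
    g = Generator()
    d = Discriminator()
    lr = torch.randn(1, 3, 24, 24)   # dummy low-resolution batch
    sr = g(lr)
    print(sr.shape)                  # torch.Size([1, 3, 96, 96])
    print(d(sr).shape)               # torch.Size([1, 1])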
Below is the SRGANext code. Note that the LayerNorm here expects the channel dimension last: nn.LayerNorm normalizes over the last axis, so the feature map is permuted to channels-last before normalization and back afterwards, as illustrated in the sketch below.
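A minimal, self-contained illustration of that permute trick (my own example, not part of the model code):

import torch
import torch.nn as nn

# nn.LayerNorm(dim) normalizes over the LAST dimension of its input,
# so a channels-first feature map must be permuted to channels-last.
x = torch.randn(1, 64, 32, 32)       # [N, C, H, W]
norm = nn.LayerNorm(64)
y = norm(x.permute(0, 2, 3, 1))      # [N, H, W, C]: normalize over C
y = y.permute(0, 3, 1, 2)            # back to [N, C, H, W]
print(y.shape)                       # torch.Size([1, 64, 32, 32])

With that in mind, the full SRGANext code: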
import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    '''[b,c,h,w] --> permute [b,h,w,c] --> LN --> permute [b,c,h,w]'''
    def __init__(self, dim):
        super(LayerNorm, self).__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)
        return x


class SRGANextBlock(nn.Module):
    '''ConvNeXt-style block: 7x7 depthwise conv -> LayerNorm -> 1x1 expansion
    (4x) -> GELU -> 1x1 projection, with layer scale and a residual connection.'''
    def __init__(self, dim, layer_scale_init_value=1e-6):
        super(SRGANextBlock, self).__init__()
        self.dconv1 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=7,
                                stride=1, padding=3, groups=dim)  # depthwise
        self.norm1 = LayerNorm(dim)
        self.pconv1 = nn.Conv2d(dim, 4 * dim, kernel_size=1, stride=1)  # pointwise expand
        self.act = nn.GELU()
        self.pconv2 = nn.Conv2d(4 * dim, dim, kernel_size=1, stride=1)  # pointwise project
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim,)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None

    def forward(self, x):
        shortcut = x
        x = self.dconv1(x)
        x = self.norm1(x)
        x = self.pconv1(x)
        x = self.act(x)
        x = self.pconv2(x)
        x = x.permute(0, 2, 3, 1)  # [N, C, H, W] -> [N, H, W, C]
        if self.gamma is not None:
            x = self.gamma * x  # per-channel layer scale
        x = x.permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]
        x = shortcut + x
        return x


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.stem = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3)
        self.norm1 = LayerNorm(64)
        self.block1 = nn.Sequential(*[SRGANextBlock(64) for _ in range(3)])
        self.norm2 = LayerNorm(64)
        self.pconv1 = nn.Conv2d(64, 128, kernel_size=1, stride=1)
        self.block2 = nn.Sequential(*[SRGANextBlock(128) for _ in range(3)])
        self.norm3 = LayerNorm(128)
        self.pconv2 = nn.Conv2d(128, 256, kernel_size=1, stride=1)
        self.block3 = nn.Sequential(*[SRGANextBlock(256) for _ in range(9)])
        self.norm4 = LayerNorm(256)
        self.pconv3 = nn.Conv2d(256, 512, kernel_size=1, stride=1)
        self.block4 = nn.Sequential(*[SRGANextBlock(512) for _ in range(3)])
        # Two 2x PixelShuffle stages for 4x upscaling; channels go 512 -> 256 -> 128.
        self.upsample = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU()
        )
        self.conv1 = nn.Conv2d(128, 3, kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        x = self.stem(x)
        x = self.norm1(x)
        x = self.block1(x)
        x = self.norm2(x)
        x = self.pconv1(x)
        x = self.block2(x)
        x = self.norm3(x)
        x = self.pconv2(x)
        x = self.block3(x)
        x = self.norm4(x)
        x = self.pconv3(x)
        x = self.block4(x)
        x = self.upsample(x)
        x = self.conv1(x)
        return x


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # ConvNeXt-style stages: 96 -> 192 -> 384 -> 768 with 3/3/9/3 blocks.
        self.stem = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=7, stride=2, padding=3),
            LayerNorm(96))
        self.block1 = nn.Sequential(*[SRGANextBlock(96) for _ in range(3)])
        self.downsample1 = nn.Sequential(
            LayerNorm(96),
            nn.Conv2d(96, 192, kernel_size=2, stride=2))
        self.block2 = nn.Sequential(*[SRGANextBlock(192) for _ in range(3)])
        self.downsample2 = nn.Sequential(
            LayerNorm(192),
            nn.Conv2d(192, 384, kernel_size=2, stride=2))
        self.block3 = nn.Sequential(*[SRGANextBlock(384) for _ in range(9)])
        self.downsample3 = nn.Sequential(
            LayerNorm(384),
            nn.Conv2d(384, 768, kernel_size=2, stride=2))
        self.block4 = nn.Sequential(*[SRGANextBlock(768) for _ in range(3)])
        self.endLayer = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(768, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid())

    def forward(self, x):
        x = self.stem(x)
        shortcut = x
        x = shortcut + self.block1(x)  # extra stage-level residual
        x = self.downsample1(x)
        shortcut = x
        x = shortcut + self.block2(x)
        x = self.downsample2(x)
        shortcut = x
        x = shortcut + self.block3(x)
        x = self.downsample3(x)
        shortcut = x
        x = shortcut + self.block4(x)
        x = self.endLayer(x)
        return x
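The same kind of smoke test works for SRGANext (again my own addition, appended at the bottom of the file): this generator is also a 4x upscaler, ending at 128 channels before the final 7x7 convolution back to RGB.

if __name__ == '__main__':
    g = Generator()
    d = Discriminator()
    lr = torch.randn(1, 3, 24, 24)   # dummy low-resolution batch
    sr = g(lr)
    print(sr.shape)                  # torch.Size([1, 3, 96, 96])
    print(d(sr).shape)               # torch.Size([1, 1])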
