A Walkthrough of DIAL-Filters (IAPM Module + LGF Module) in the IA-SEG Project

The IA-SEG project comes from the paper Improving Nighttime Driving-Scene Segmentation via Dual Image-adaptive Learnable Filters. Its core idea is to attach DIAL-Filters to an existing semantic segmentation model. DIAL-Filters consist of two parts: an image-adaptive processing module (IAPM, i.e. the CNN-PP + DIF modules known from IA-YOLO) and a learnable guided filter (LGF). The project is implemented in PyTorch; I analyze it here because I want to implement a domain-adaptive detection algorithm under PyTorch. IA-SEG targets semantic segmentation in nighttime scenes and contains both supervised and unsupervised parts; this post only covers the core pieces, the IAPM module (CNN-PP plus DIF) and the LGF module. Sections 3 and 4 contain the corresponding code-usage examples.

Besides DIAL-Filters, the IA-SEG paper also proposes an unsupervised learning framework, which is only touched on at the end of this post; interested readers can consult the original paper, or my translated walkthrough of the IA-SEG paper.

IA-SEG project address: https://github.com/wenyyu/IA-Seg#arxiv

1. The CNN-PP Module

1.1 Overview

The CNN-PP module predicts the filter parameters that the DIP module uses to optimize the image. It is essentially a compact convolutional network whose input is a low-resolution version of the original image and whose output is the DIP parameter set: the input is resized to 256×256 and passed through five stride-2 convolutions (256→128→64→32→16→8), so the final 8×8 convolution collapses the feature map to 1×1 with one channel per filter parameter.
In IA-SEG the CNN-PP module has more parameters (predicting 4 filter parameters, 278K weights) than in IA-YOLO (predicting 15 filter parameters, 165K weights). Note: IA-SEG and IA-YOLO are implemented by the same authors.

1.2 Implementation

Code address: https://github.com/wenyyu/IA-Seg/blob/main/network/dip.py
The full code is given below. It references an external cfg object, the training configuration, whose fields num_filter_parameters and cfg.filters are used inside CNN_PP.

#! /usr/bin/env python
# coding=utf-8
import torch
import torch.nn as nn
import numpy as np
from configs.train_config import cfg
import time

def conv_downsample(in_filters, out_filters, normalization=False):
    layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]
    layers.append(nn.LeakyReLU(0.2))
    if normalization:
        layers.append(nn.InstanceNorm2d(out_filters, affine=True))
    return layers

class CNN_PP(nn.Module):
    def __init__(self, in_channels=3):
        super(CNN_PP, self).__init__()
        self.model = nn.Sequential(
            nn.Upsample(size=(256, 256), mode='bilinear'),
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
            *conv_downsample(16, 32, normalization=True),
            *conv_downsample(32, 64, normalization=True),
            *conv_downsample(64, 128, normalization=True),
            *conv_downsample(128, 128),
            nn.Dropout(p=0.5),
            # one output channel per filter parameter; the 8x8 kernel collapses
            # the remaining 8x8 feature map to 1x1
            nn.Conv2d(128, cfg.num_filter_parameters, 8, padding=0),
        )

    def forward(self, img_input):
        self.Pr = self.model(img_input)
        self.filtered_image_batch = img_input
        filters = cfg.filters
        filters = [x(img_input, cfg) for x in filters]
        self.filter_parameters = []
        self.filtered_images = []
        # apply the filters sequentially, each consuming its own slice of Pr
        for j, filter in enumerate(filters):
            self.filtered_image_batch, filter_parameter = filter.apply(self.filtered_image_batch, self.Pr)
            self.filter_parameters.append(filter_parameter)
            self.filtered_images.append(self.filtered_image_batch)
        return self.filtered_image_batch, self.filtered_images, self.Pr, self.filter_parameters

def DIP():
    model = CNN_PP()
    return model

1.3 Related code

The DIP module depends on the cfg object, whose code lives at:
https://github.com/wenyyu/IA-Seg/blob/main/configs/train_config.py

The parts relevant to CNN-PP and the DIF filters are:


import argparse
from network.filters import *

cfg.filters = [ExposureFilter, GammaFilter, ContrastFilter, UsmFilter]
# cfg.filters = []
cfg.num_filter_parameters = 4
# All settings below are consumed by the DIF module's filtering operations
cfg.exposure_begin_param = 0
cfg.gamma_begin_param = 1
cfg.contrast_begin_param = 2
cfg.usm_begin_param = 3
# Gamma = 1/x ~ x
cfg.curve_steps = 8
cfg.gamma_range = 3
cfg.exposure_range = 3.5
cfg.wb_range = 1.1
cfg.color_curve_range = (0.90, 1.10)
cfg.lab_curve_range = (0.90, 1.10)
cfg.tone_curve_range = (0.5, 2)
cfg.defog_range = (0.1, 1.0)
cfg.usm_range = (0.0, 5)
cfg.cont_range = (0.0, 1.0)

In addition, the config is tied to the DIF implementation code, described later.
As a plug-and-play head module, CNN-PP does not need to be added to the model architecture itself; it is enough to wire its data flow into the train function. IA-SEG uses CNN-PP as shown below. Two things stand out: the input to CNNPP is a normalized image in [0, 1] (but not standardized with mean/std), and the CNN-PP output is never compared against any reference image by a dedicated loss. CNN-PP is optimized entirely by the loss computed at the end of the forward pass, which differs from the design in IA-YOLO.

See the original author's code for more usage details.

CNNPP = dip.DIP().to(device)
CNNPP.train()
model = PSPNet(num_classes=args.num_classes, dgf=args.DGF_FLAG).to(device)
model.train()
optimizer = optim.SGD(list(model.parameters()) + list(CNNPP.parameters()),
                      lr=args.learning_rate, momentum=args.momentum,
                      weight_decay=args.weight_decay)
optimizer.zero_grad()

for i_iter in range(args.num_steps):
    for sub_i in range(args.iter_size):
        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)
        enhanced_images_pre, ci_map, Pr, filter_parameters = CNNPP(images)
        enhanced_images = enhanced_images_pre
        # standard_transforms is torchvision.transforms in the original script
        for i_pre in range(enhanced_images_pre.shape[0]):
            enhanced_images[i_pre, ...] = standard_transforms.Normalize(
                [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(enhanced_images_pre[i_pre, ...])
        pred_c = model(enhanced_images)

2. The DIF Module

2.1 Overview

DIF stands for Differentiable Image Filters. It consists of several differentiable filters with adjustable hyperparameters, covering exposure, gamma, contrast, and sharpness. The DIF code in IA-SEG was adapted from the DIP code in IA-YOLO: the original TensorFlow implementation was ported to PyTorch syntax, and the filter modules not needed in IA-SEG (the Tone Filter and the Defog Filter) were commented out.
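
For concreteness, the four filters implement the following pixel-wise mappings, which can be read directly off the process() methods in the code below ($P_i$: input pixel value, $P_o$: output value; the parameter ranges come from the cfg excerpt in section 1.3):

$$
\begin{aligned}
\text{Exposure:}\quad & P_o = P_i \cdot 2^{E}, && E \in (-3.5,\ 3.5)\\
\text{Gamma:}\quad & P_o = P_i^{\,G}, && G \in (1/3,\ 3)\\
\text{Contrast:}\quad & P_o = (1-\alpha)\,P_i + \alpha\,\frac{P_i}{\mathrm{Lum}(P_i)}\cdot\frac{1-\cos\big(\pi\,\mathrm{Lum}(P_i)\big)}{2}, && \alpha \in (0,\ 1)\\
\text{Sharpen (USM):}\quad & P_o = P_i + \lambda\,\big(P_i - \mathrm{Gau}(P_i)\big), && \lambda \in (0,\ 5)
\end{aligned}
$$

Here $\mathrm{Gau}(\cdot)$ is the fixed 5×5 Gaussian blur used for unsharp masking and $\mathrm{Lum}(\cdot)$ is the image luminance clamped to $[0, 1]$.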

2.2 Implementation

Code address: https://github.com/wenyyu/IA-Seg/blob/main/network/filters.py
The implementation follows; the commented-out code (i.e. the original TensorFlow implementations of the Tone Filter, Defog Filter, etc.) has been filtered out here.

Note that all the differentiable filters inherit from Filter. Of the two constructor arguments, net and cfg, only cfg actually does anything (see the annotated config in section 1.3). The functions rgb2lum, tanh_range, and lerp are imported to give the Filter objects their basic data operations.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from network.util_filters import rgb2lum, tanh_range, lerp
from network.util_filters import *
import cv2
import math

# device = torch.device("cuda")

class Filter(nn.Module):
    def __init__(self, net, cfg):
        super(Filter, self).__init__()
        self.cfg = cfg
        # Specified in child classes
        self.num_filter_parameters = None
        self.short_name = None
        self.filter_parameters = None

    def get_short_name(self):
        assert self.short_name
        return self.short_name

    def get_num_filter_parameters(self):
        assert self.num_filter_parameters
        return self.num_filter_parameters

    def get_begin_filter_parameter(self):
        return self.begin_filter_parameter

    def extract_parameters(self, features):
        # slice this filter's parameters out of the CNN-PP output
        return features[:, self.get_begin_filter_parameter():(self.get_begin_filter_parameter() + self.get_num_filter_parameters())], \
               features[:, self.get_begin_filter_parameter():(self.get_begin_filter_parameter() + self.get_num_filter_parameters())]

    # Should be implemented in child classes
    def filter_param_regressor(self, features):
        assert False

    # Process the whole image, without masking
    # Should be implemented in child classes
    def process(self, img, param, defog, IcA):
        assert False

    def debug_info_batched(self):
        return False

    def no_high_res(self):
        return False

    # Apply the whole filter with masking
    def apply(self, img, img_features=None, defog_A=None, IcA=None,
              specified_parameter=None, high_res=None):
        assert (img_features is None) ^ (specified_parameter is None)
        if img_features is not None:
            filter_features, mask_parameters = self.extract_parameters(img_features)
            filter_parameters = self.filter_param_regressor(filter_features)
        else:
            assert not self.use_masking()
            filter_parameters = specified_parameter
        if high_res is not None:
            # working on high res...
            pass
        debug_info = {}
        # We only debug the first image of this batch
        if self.debug_info_batched():
            debug_info['filter_parameters'] = filter_parameters
        else:
            debug_info['filter_parameters'] = filter_parameters[0]
        low_res_output = self.process(img, filter_parameters, defog_A, IcA)
        if high_res is not None:
            if self.no_high_res():
                high_res_output = high_res
            else:
                self.high_res_mask = self.get_mask(high_res, mask_parameters)
        else:
            high_res_output = None
        return low_res_output, filter_parameters

    def use_masking(self):
        return self.cfg.masking

    def get_num_mask_parameters(self):
        return 6

    # Input: no need for tanh or sigmoid
    # Closer-to-1 values are applied by the filter more strongly
    # NOTE: still written against TensorFlow (tf.*); dead code in this PyTorch
    # port because masking stays disabled
    def get_mask(self, img, mask_parameters):
        if not self.use_masking():
            print('* Masking Disabled')
            return tf.ones(shape=(1, 1, 1, 1), dtype=tf.float32)
        else:
            print('* Masking Enabled')
        with tf.name_scope(name='mask'):
            # Six parameters for one filter
            filter_input_range = 5
            assert mask_parameters.shape[1] == self.get_num_mask_parameters()
            mask_parameters = tanh_range(
                l=-filter_input_range, r=filter_input_range, initial=0)(mask_parameters)
            size = list(map(int, img.shape[1:3]))
            grid = np.zeros(shape=[1] + size + [2], dtype=np.float32)
            shorter_edge = min(size[0], size[1])
            for i in range(size[0]):
                for j in range(size[1]):
                    grid[0, i, j, 0] = (i + (shorter_edge - size[0]) / 2.0) / shorter_edge - 0.5
                    grid[0, i, j, 1] = (j + (shorter_edge - size[1]) / 2.0) / shorter_edge - 0.5
            grid = tf.constant(grid)
            # Ax + By + C * L + D
            inp = grid[:, :, :, 0, None] * mask_parameters[:, None, None, 0, None] + \
                  grid[:, :, :, 1, None] * mask_parameters[:, None, None, 1, None] + \
                  mask_parameters[:, None, None, 2, None] * (rgb2lum(img) - 0.5) + \
                  mask_parameters[:, None, None, 3, None] * 2
            # Sharpness and inversion
            inp *= self.cfg.maximum_sharpness * mask_parameters[:, None, None, 4, None] / filter_input_range
            mask = tf.sigmoid(inp)
            # Strength
            mask = mask * (mask_parameters[:, None, None, 5, None] / filter_input_range * 0.5 + 0.5) \
                   * (1 - self.cfg.minimum_strength) + self.cfg.minimum_strength
            print('mask', mask.shape)
            return mask

    def visualize_mask(self, debug_info, res):
        return cv2.resize(debug_info['mask'] * np.ones((1, 1, 3), dtype=np.float32),
                          dsize=res, interpolation=cv2.INTER_NEAREST)

    def draw_high_res_text(self, text, canvas):
        cv2.putText(canvas, text, (30, 128), cv2.FONT_HERSHEY_SIMPLEX,
                    0.8, (0, 0, 0), thickness=5)
        return canvas

class ExposureFilter(Filter):
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'E'
        self.begin_filter_parameter = cfg.exposure_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):
        # param is in (-self.cfg.exposure_range, self.cfg.exposure_range)
        return tanh_range(-self.cfg.exposure_range, self.cfg.exposure_range, initial=0)(features)

    def process(self, img, param, defog, IcA):
        # exposure: multiply by 2^param
        return img * torch.exp(param * np.log(2))

class UsmFilter(Filter):
    # Usm_param is in usm_range
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'UF'
        self.begin_filter_parameter = cfg.usm_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):
        return tanh_range(*self.cfg.usm_range)(features)

    def process(self, img, param, defog_A, IcA):
        self.channels = 3
        # fixed 5x5 Gaussian kernel for unsharp masking
        kernel = [[0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633],
                  [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
                  [0.01330373, 0.11098164, 0.22508352, 0.11098164, 0.01330373],
                  [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
                  [0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633]]
        kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
        kernel = np.repeat(kernel, self.channels, axis=0)
        kernel = kernel.to(img.device)
        output = F.conv2d(img, kernel, padding=2, groups=self.channels)
        img_out = (img - output) * param + img
        return img_out

class ContrastFilter(Filter):
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'Ct'
        self.begin_filter_parameter = cfg.contrast_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):
        return tanh_range(*self.cfg.cont_range)(features)

    def process(self, img, param, defog, IcA):
        # clamp luminance to [0, 1]
        luminance = rgb2lum(img)
        zero = torch.zeros_like(luminance)
        one = torch.ones_like(luminance)
        luminance = torch.where(luminance < 0, zero, luminance)
        luminance = torch.where(luminance > 1, one, luminance)
        contrast_lum = -torch.cos(math.pi * luminance) * 0.5 + 0.5
        contrast_image = img / (luminance + 1e-6) * contrast_lum
        return lerp(img, contrast_image, param)

class ToneFilter(Filter):
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.curve_steps = cfg.curve_steps
        self.short_name = 'T'
        self.begin_filter_parameter = cfg.tone_begin_param
        self.num_filter_parameters = cfg.curve_steps

    def filter_param_regressor(self, features):
        tone_curve = tanh_range(*self.cfg.tone_curve_range)(features)
        return tone_curve

    def process(self, img, param, defog, IcA):
        param = torch.unsqueeze(param, 3)
        tone_curve = param
        tone_curve_sum = torch.sum(tone_curve, axis=1) + 1e-30
        total_image = img * 0
        # piecewise-linear tone curve with curve_steps segments
        for i in range(self.cfg.curve_steps):
            total_image += torch.clamp(img - 1.0 * i / self.cfg.curve_steps, 0, 1.0 / self.cfg.curve_steps) \
                           * param[:, i, :, :]
        total_image *= self.cfg.curve_steps / tone_curve_sum
        img = total_image
        return img

class GammaFilter(Filter):
    # gamma_param is in [1/gamma_range, gamma_range]
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'G'
        self.begin_filter_parameter = cfg.gamma_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):
        log_gamma_range = np.log(self.cfg.gamma_range)
        return torch.exp(tanh_range(-log_gamma_range, log_gamma_range)(features))

    def process(self, img, param, defog_A, IcA):
        # avoid pow on non-positive values
        zero = torch.zeros_like(img) + 0.00001
        img = torch.where(img <= 0, zero, img)
        return torch.pow(img, param)

2.3 Related code

util_filters provides some basic helper functions for the filters, namely rgb2lum, tanh_range, and lerp.
Full code: https://github.com/wenyyu/IA-Seg/blob/main/network/util_filters.py
The main code is:

import math
import cv2
import torch
import torch.nn as nn

def rgb2lum(image):
    # note: channel-last (NHWC) indexing kept from the original TensorFlow version
    image = 0.27 * image[:, :, :, 0] + 0.67 * image[:, :, :, 1] + 0.06 * image[:, :, :, 2]
    return image[:, :, :, None]

def tanh01(x):
    return torch.tanh(x) * 0.5 + 0.5

def tanh_range(l, r, initial=None):
    def get_activation(left, right, initial):
        def activation(x):
            if initial is not None:
                bias = math.atanh(2 * (initial - left) / (right - left) - 1)
            else:
                bias = 0
            return tanh01(x + bias) * (right - left) + left
        return activation
    return get_activation(l, r, initial)

def lerp(a, b, l):
    return (1 - l) * a + l * b
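
As a quick sanity check (my own snippet, not from the repo): tanh_range squashes the unbounded CNN-PP outputs into the configured parameter range.

import torch
from network.util_filters import tanh_range

squash = tanh_range(0.0, 5.0)            # same bounds as cfg.usm_range
x = torch.tensor([-10.0, 0.0, 10.0])     # raw, unbounded network outputs
print(squash(x))                         # ≈ tensor([0.0000, 2.5000, 5.0000]) — always inside (0, 5)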

3. Using the IPAM Module

The IPAM module (the paper's IAPM) is simply the combination of the CNN-PP and DIF modules described above. It is singled out here, and its usage discussed separately, so that the code can be lifted out of the IA-SEG project and used standalone.
In IA-SEG the DIF module is effectively already embedded inside the CNN-PP model, forming the IPAM module, but the relevant functions are scattered across several .py files, which is inconvenient to reuse; hence the consolidation below.

3.1 Integrated code

Install the one extra dependency: pip install easydict
The integrated code is shown below; only the cfg block at the very bottom needs editing. It builds an IPAM class through which image domain adaptation can be run directly.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import math

# ----- Basic helper functions for the filters -----
def rgb2lum(image):
    # note: channel-last (NHWC) indexing kept from the original TensorFlow version
    image = 0.27 * image[:, :, :, 0] + 0.67 * image[:, :, :, 1] + 0.06 * image[:, :, :, 2]
    return image[:, :, :, None]

def tanh01(x):
    return torch.tanh(x) * 0.5 + 0.5

def tanh_range(l, r, initial=None):
    def get_activation(left, right, initial):
        def activation(x):
            if initial is not None:
                bias = math.atanh(2 * (initial - left) / (right - left) - 1)
            else:
                bias = 0
            return tanh01(x + bias) * (right - left) + left
        return activation
    return get_activation(l, r, initial)

def lerp(a, b, l):
    return (1 - l) * a + l * b

# ----- Filter implementations -----
# (the legacy TensorFlow masking/visualization helpers from filters.py are
# omitted here: masking is disabled, so they are never reached)
class Filter(nn.Module):
    def __init__(self, net, cfg):
        super(Filter, self).__init__()
        self.cfg = cfg
        self.num_filter_parameters = None
        self.short_name = None
        self.filter_parameters = None

    def get_short_name(self):
        assert self.short_name
        return self.short_name

    def get_num_filter_parameters(self):
        assert self.num_filter_parameters
        return self.num_filter_parameters

    def get_begin_filter_parameter(self):
        return self.begin_filter_parameter

    def extract_parameters(self, features):
        # slice this filter's parameters out of the CNN-PP output
        return features[:, self.get_begin_filter_parameter():(self.get_begin_filter_parameter() + self.get_num_filter_parameters())], \
               features[:, self.get_begin_filter_parameter():(self.get_begin_filter_parameter() + self.get_num_filter_parameters())]

    # Should be implemented in child classes
    def filter_param_regressor(self, features):
        assert False

    # Process the whole image, without masking
    # Should be implemented in child classes
    def process(self, img, param, defog, IcA):
        assert False

    def debug_info_batched(self):
        return False

    def no_high_res(self):
        return False

    # Apply the whole filter
    def apply(self, img, img_features=None, defog_A=None, IcA=None,
              specified_parameter=None, high_res=None):
        assert (img_features is None) ^ (specified_parameter is None)
        if img_features is not None:
            filter_features, mask_parameters = self.extract_parameters(img_features)
            filter_parameters = self.filter_param_regressor(filter_features)
        else:
            filter_parameters = specified_parameter

        debug_info = {}
        # We only debug the first image of this batch
        if self.debug_info_batched():
            debug_info['filter_parameters'] = filter_parameters
        else:
            debug_info['filter_parameters'] = filter_parameters[0]

        low_res_output = self.process(img, filter_parameters, defog_A, IcA)
        return low_res_output, filter_parameters

class ExposureFilter(Filter):
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'E'
        self.begin_filter_parameter = cfg.exposure_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):
        # param is in (-self.cfg.exposure_range, self.cfg.exposure_range)
        return tanh_range(-self.cfg.exposure_range, self.cfg.exposure_range, initial=0)(features)

    def process(self, img, param, defog, IcA):
        return img * torch.exp(param * np.log(2))

class UsmFilter(Filter):
    # Usm_param is in usm_range
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'UF'
        self.begin_filter_parameter = cfg.usm_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):
        return tanh_range(*self.cfg.usm_range)(features)

    def process(self, img, param, defog_A, IcA):
        self.channels = 3
        # fixed 5x5 Gaussian kernel for unsharp masking
        kernel = [[0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633],
                  [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
                  [0.01330373, 0.11098164, 0.22508352, 0.11098164, 0.01330373],
                  [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
                  [0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633]]
        kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
        kernel = np.repeat(kernel, self.channels, axis=0)
        kernel = kernel.to(img.device)
        output = F.conv2d(img, kernel, padding=2, groups=self.channels)
        img_out = (img - output) * param + img
        return img_out

class ContrastFilter(Filter):
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'Ct'
        self.begin_filter_parameter = cfg.contrast_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):
        return tanh_range(*self.cfg.cont_range)(features)

    def process(self, img, param, defog, IcA):
        # clamp luminance to [0, 1]
        luminance = rgb2lum(img)
        zero = torch.zeros_like(luminance)
        one = torch.ones_like(luminance)
        luminance = torch.where(luminance < 0, zero, luminance)
        luminance = torch.where(luminance > 1, one, luminance)
        contrast_lum = -torch.cos(math.pi * luminance) * 0.5 + 0.5
        contrast_image = img / (luminance + 1e-6) * contrast_lum
        return lerp(img, contrast_image, param)

class ToneFilter(Filter):
    # note: unused in the filter list below; adding it back also requires
    # cfg.tone_begin_param, which the integrated cfg does not define
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.curve_steps = cfg.curve_steps
        self.short_name = 'T'
        self.begin_filter_parameter = cfg.tone_begin_param
        self.num_filter_parameters = cfg.curve_steps

    def filter_param_regressor(self, features):
        tone_curve = tanh_range(*self.cfg.tone_curve_range)(features)
        return tone_curve

    def process(self, img, param, defog, IcA):
        param = torch.unsqueeze(param, 3)
        tone_curve = param
        tone_curve_sum = torch.sum(tone_curve, axis=1) + 1e-30
        total_image = img * 0
        for i in range(self.cfg.curve_steps):
            total_image += torch.clamp(img - 1.0 * i / self.cfg.curve_steps, 0, 1.0 / self.cfg.curve_steps) \
                           * param[:, i, :, :]
        total_image *= self.cfg.curve_steps / tone_curve_sum
        img = total_image
        return img

class GammaFilter(Filter):
    # gamma_param is in [1/gamma_range, gamma_range]
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'G'
        self.begin_filter_parameter = cfg.gamma_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):
        log_gamma_range = np.log(self.cfg.gamma_range)
        return torch.exp(tanh_range(-log_gamma_range, log_gamma_range)(features))

    def process(self, img, param, defog_A, IcA):
        # avoid pow on non-positive values
        zero = torch.zeros_like(img) + 0.00001
        img = torch.where(img <= 0, zero, img)
        return torch.pow(img, param)

# ---------- Filter-module parameters ----------
from easydict import EasyDict as edict
cfg=edict()
cfg.num_filter_parameters = 4
# All settings below are consumed by the DIF module's filtering operations
cfg.exposure_begin_param = 0
cfg.gamma_begin_param = 1
cfg.contrast_begin_param = 2
cfg.usm_begin_param = 3
# Gamma = 1/x ~ x
cfg.curve_steps = 8
cfg.gamma_range = 3
cfg.exposure_range = 3.5
cfg.wb_range = 1.1
cfg.color_curve_range = (0.90, 1.10)
cfg.lab_curve_range = (0.90, 1.10)
cfg.tone_curve_range = (0.5, 2)
cfg.defog_range = (0.1, 1.0)
cfg.usm_range = (0.0, 5)
cfg.cont_range = (0.0, 1.0)

# ---------- DIF module ----------
class DIF(nn.Module):
    def __init__(self, Filters):
        super(DIF, self).__init__()
        self.Filters = Filters

    def forward(self, img_input, Pr):
        self.filtered_image_batch = img_input
        filters = [x(img_input, cfg) for x in self.Filters]
        self.filter_parameters = []
        self.filtered_images = []
        # apply the filters sequentially, each consuming its own slice of Pr
        for j, filter in enumerate(filters):
            self.filtered_image_batch, filter_parameter = filter.apply(self.filtered_image_batch, Pr)
            self.filter_parameters.append(filter_parameter)
            self.filtered_images.append(self.filtered_image_batch)
        return self.filtered_image_batch, self.filtered_images, Pr, self.filter_parameters

# ---------- IPAM module ----------
def conv_downsample(in_filters, out_filters, normalization=False):
    layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]
    layers.append(nn.LeakyReLU(0.2))
    if normalization:
        layers.append(nn.InstanceNorm2d(out_filters, affine=True))
    return layers

class IPAM(nn.Module):
    def __init__(self):
        super(IPAM, self).__init__()
        self.CNN_PP = nn.Sequential(
            nn.Upsample(size=(256, 256), mode='bilinear'),
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
            *conv_downsample(16, 32, normalization=True),
            *conv_downsample(32, 64, normalization=True),
            *conv_downsample(64, 128, normalization=True),
            *conv_downsample(128, 128),
            nn.Dropout(p=0.5),
            nn.Conv2d(128, cfg.num_filter_parameters, 8, padding=0),
        )
        Filters = [ExposureFilter, GammaFilter, ContrastFilter, UsmFilter]
        self.dif = DIF(Filters)

    def forward(self, img_input):
        self.Pr = self.CNN_PP(img_input)
        out = self.dif(img_input, self.Pr)
        return out

3.2 Usage

Usage is as follows:

model = IPAM()
print(model)
x=torch.rand((1,3,256,256))
filtered_image_batch,filtered_images,Pr,filter_parameters=model(x)

The output is shown below. filtered_image_batch is the enhanced image; filtered_images is a list of length 4 holding the intermediate image after each enhancement step; Pr is the output of CNN-PP; filter_parameters are the actual DIF parameters.

IPAM(
  (CNN_PP): Sequential(
    (0): Upsample(size=(256, 256), mode='bilinear')
    (1): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (2): LeakyReLU(negative_slope=0.2)
    (3): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): LeakyReLU(negative_slope=0.2)
    (6): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2)
    (9): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (11): LeakyReLU(negative_slope=0.2)
    (12): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (13): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (14): LeakyReLU(negative_slope=0.2)
    (15): Dropout(p=0.5, inplace=False)
    (16): Conv2d(128, 4, kernel_size=(8, 8), stride=(1, 1))
  )
  (dif): DIF()
)
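
To tie these outputs to concrete numbers (a small check of my own, run against the snippet above), the tensor shapes and the parameter count can be inspected; the total matches the ~278K figure from section 1.1:

print(filtered_image_batch.shape)            # torch.Size([1, 3, 256, 256])
print(len(filtered_images))                  # 4 — one intermediate image per filter
print(Pr.shape)                              # torch.Size([1, 4, 1, 1]) — one value per filter parameter
print([tuple(p.shape) for p in filter_parameters])  # [(1, 1, 1, 1)] * 4
print(sum(p.numel() for p in model.parameters()))   # 278276 ≈ 278K, as stated in section 1.1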

3.3 Notes on usage

Because the IPAM parameters here are untrained, the generated image is fairly random.
The code for ImgUilt can be found in my earlier post "python工具方法 28". Note that the image fed into the IPAM module must be normalized to [0, 1], which can be verified against the IA-SEG author's dataset source code.

import cv2,torch
from ImgUilt import *
import numpy as np
p=r'D:\YOLO_seq\helmet_yolo\images\train\000092.jpg'
im_tensor, img = read_img_as_tensor(p)  # read_img_as_tensor is assumed to return a CUDA tensor
model = IPAM().cuda()
im_tensor=im_tensor/255
filtered_image_batch, filtered_images, Pr, filter_parameters = model(im_tensor)
new_img = tensor2img(filtered_image_batch.detach() * 255)
myimshows([img,new_img])

In this run, the image is randomly enhanced after passing through IPAM; the visible effect was mainly local edge enhancement.
Following the IA-SEG author's usage, optimizing the IPAM parameters requires no extra loss; it suffices to connect the module to the forward pass of the normal model. The training code below shows that the CNNPP output only feeds the model's input and has no direct connection to any loss term.

enhanced_images_pre, ci_map, Pr, filter_parameters = CNNPP(images)
enhanced_images = enhanced_images_pre
for i_pre in range(enhanced_images_pre.shape[0]):
    enhanced_images[i_pre, ...] = standard_transforms.Normalize(
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(enhanced_images_pre[i_pre, ...])
if args.model == 'RefineNet' or args.model.startswith('deeplabv3'):
    pred_c = model(enhanced_images)
else:
    _, pred_c = model(enhanced_images)
pred_c = interp(pred_c)
loss_seg = seg_loss(pred_c, labels)
loss = loss_seg  # + loss_seg_dark_dynamic + loss_seg_mix + loss_enhance
loss_s = loss / args.iter_size
loss_s.backward(retain_graph=True)
loss_seg_value += loss_seg.item() / args.iter_size

Alternatively, one can follow the usage in IA-YOLO: feed a noise-degraded image to IPAM and compute a loss between the original clean image and the IPAM-enhanced output.
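
A minimal sketch of that idea (my own illustration; neither repo ships this exact code, and the degradation/loss choices here are assumptions):

import torch
import torch.nn.functional as F

ipam = IPAM()                                  # the integrated module from section 3.1
clean = torch.rand(1, 3, 256, 256)             # stand-in for a clear training image in [0, 1]
noisy = torch.clamp(clean + 0.1 * torch.randn_like(clean), 0, 1)  # synthetic degradation

enhanced, _, _, _ = ipam(noisy)                # IPAM tries to recover the clean image
loss_enhance = F.mse_loss(enhanced, clean)     # recovery loss, to be summed with the task loss
loss_enhance.backward()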

4. The LGF Module

4.1 Overview

The guided filter is an edge-preserving, gradient-preserving image operation that uses the object boundaries in a guidance image to detect object saliency. It can suppress saliency outside the target and improve downstream detection or segmentation performance; in effect it fine-tunes the output feature map. The LGF computation is summarized below: f_mean denotes a mean filter with window radius r, and the abbreviations corr, var, and cov keep their usual meanings (correlation, variance, covariance). See the related papers for more detail.

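In place of the original pseudocode figure, here is the standard guided-filter recurrence (He et al.), which is exactly what the GuidedFilter class in section 4.2 computes, for guidance $I$, filtering input $p$, and regularizer $\epsilon$:

$$
\begin{aligned}
\mathrm{mean}_I &= f_{mean}(I), \qquad \mathrm{mean}_p = f_{mean}(p)\\
\mathrm{cov}_{Ip} &= f_{mean}(I \cdot p) - \mathrm{mean}_I \cdot \mathrm{mean}_p\\
\mathrm{var}_I &= f_{mean}(I \cdot I) - \mathrm{mean}_I^{\,2}\\
A &= \mathrm{cov}_{Ip} \,/\, (\mathrm{var}_I + \epsilon), \qquad b = \mathrm{mean}_p - A \cdot \mathrm{mean}_I\\
q &= f_{mean}(A) \cdot I + f_{mean}(b)
\end{aligned}
$$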

4.2 Implementation

Code address: https://github.com/wenyyu/IA-Seg/blob/d6393cc87e5ca95ab3b27dee4ec31293256ab9a4/network/guided_filter.py
The code is reproduced below; note that guided_filter depends on no external functions. It contains two classes, GuidedFilter and FastGuidedFilter; IA-SEG does not use FastGuidedFilter (when its three inputs lr_x, lr_y, hr_x are given with lr_x identical to hr_x, it behaves exactly like GuidedFilter).

The highlight of this code is its differentiable box filter.

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable

def diff_x(input, r):
    assert input.dim() == 4
    left   = input[:, :,         r:2 * r + 1]
    middle = input[:, :, 2 * r + 1:         ] - input[:, :,           :-2 * r - 1]
    right  = input[:, :,        -1:         ] - input[:, :, -2 * r - 1:    -r - 1]
    output = torch.cat([left, middle, right], dim=2)
    return output

def diff_y(input, r):
    assert input.dim() == 4
    left   = input[:, :, :,         r:2 * r + 1]
    middle = input[:, :, :, 2 * r + 1:         ] - input[:, :, :,           :-2 * r - 1]
    right  = input[:, :, :,        -1:         ] - input[:, :, :, -2 * r - 1:    -r - 1]
    output = torch.cat([left, middle, right], dim=3)
    return output

class BoxFilter(nn.Module):
    def __init__(self, r):
        super(BoxFilter, self).__init__()
        self.r = r

    def forward(self, x):
        assert x.dim() == 4
        # differentiable box filter: cumulative sums along H and W, then differences
        return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)

class FastGuidedFilter(nn.Module):
    def __init__(self, r, eps=1e-8):
        super(FastGuidedFilter, self).__init__()
        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)

    def forward(self, lr_x, lr_y, hr_x):
        n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
        n_lry, c_lry, h_lry, w_lry = lr_y.size()
        n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()

        assert n_lrx == n_lry and n_lry == n_hrx
        assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
        assert h_lrx == h_lry and w_lrx == w_lry
        assert h_lrx > 2 * self.r + 1 and w_lrx > 2 * self.r + 1

        ## N
        N = self.boxfilter(Variable(lr_x.data.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0)))
        ## mean_x
        mean_x = self.boxfilter(lr_x) / N
        ## mean_y
        mean_y = self.boxfilter(lr_y) / N
        ## cov_xy
        cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
        ## var_x
        var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x
        ## A
        A = cov_xy / (var_x + self.eps)
        ## b
        b = mean_y - A * mean_x
        ## mean_A; mean_b
        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)

        return mean_A * hr_x + mean_b

class GuidedFilter(nn.Module):
    def __init__(self, r, eps=1e-8):
        super(GuidedFilter, self).__init__()
        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)

    def forward(self, x, y):
        n_x, c_x, h_x, w_x = x.size()
        n_y, c_y, h_y, w_y = y.size()

        assert n_x == n_y
        assert c_x == 1 or c_x == c_y
        assert h_x == h_y and w_x == w_y
        assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1

        # N
        N = self.boxfilter(Variable(x.data.new().resize_((1, 1, h_x, w_x)).fill_(1.0)))
        # mean_x
        mean_x = self.boxfilter(x) / N
        # mean_y
        mean_y = self.boxfilter(y) / N
        # cov_xy
        cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y
        # var_x
        var_x = self.boxfilter(x * x) / N - mean_x * mean_x
        # A
        A = cov_xy / (var_x + self.eps)
        # b
        b = mean_y - A * mean_x
        # mean_A; mean_b
        mean_A = self.boxfilter(A) / N
        mean_b = self.boxfilter(b) / N

        return mean_A * x + mean_b

The code above can be saved as guided_filter.py.
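
A minimal smoke test of the class (my own example; the shapes are arbitrary but must satisfy the asserts, i.e. H, W > 2r + 1):

import torch
from guided_filter import GuidedFilter

gf = GuidedFilter(r=4, eps=1e-2)
guide = torch.rand(1, 1, 64, 64)     # single-channel guidance map (c_x == 1 is allowed)
logits = torch.rand(1, 19, 64, 64)   # e.g. segmentation logits to refine
out = gf(guide, logits)
print(out.shape)                     # torch.Size([1, 19, 64, 64])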

4.3 Usage

I have no semantic segmentation project underway at the moment, so only the usage inside IA-SEG is analyzed here.
GuidedFilter takes two inputs (a guidance map and the prediction to be refined), so an extra network branch is needed to produce the guidance map.
The code below wraps an ordinary semantic segmentation model into a model containing the LGF; it returns x1 and x2, where x1 is the plain segmentation prediction and x2 is the LGF-refined result.

class LGFModel(nn.Module):
    def __init__(self, num_classes, dgf, dgf_r, dgf_eps):
        self.inplanes = 64
        super(LGFModel, self).__init__()
        self.model = SegModel()  # any ordinary semantic segmentation model
        self.dgf = dgf
        if self.dgf:
            # small network that turns the RGB image into a guidance map
            self.guided_map_conv1 = nn.Conv2d(3, 64, 1)
            self.guided_map_relu1 = nn.ReLU(inplace=True)
            self.guided_map_conv2 = nn.Conv2d(64, num_classes, 1)
            self.guided_filter = GuidedFilter(dgf_r, dgf_eps)

    def forward(self, x1):
        im = x1
        x1 = self.model(x1)
        if self.dgf:
            g = self.guided_map_relu1(self.guided_map_conv1(im))
            g = self.guided_map_conv2(g)
            x2 = F.interpolate(x1, im.size()[2:], mode='bilinear', align_corners=True)
            x2 = self.guided_filter(g, x2)
        else:
            x2 = x1
        return x1, x2

When using LGFModel, it is usually enough to compute the loss on x2 only; backpropagating a loss on x1 can be skipped. If the model converges slowly, a loss on x1 can be added as well.
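
As a sketch of that training recipe (seg_loss, images, and labels are hypothetical placeholders; the LGFModel signature is the one from 4.3):

model = LGFModel(num_classes=19, dgf=True, dgf_r=4, dgf_eps=1e-2)
x1, x2 = model(images)
loss = seg_loss(x2, labels)      # supervise the guided-filter output only
# optional, if convergence is slow — also supervise the raw prediction:
# loss = loss + seg_loss(F.interpolate(x1, labels.shape[-2:], mode='bilinear',
#                                      align_corners=True), labels)
loss.backward()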

