您现在的位置是:首页 > 技术教程 > 剪枝与重参第九课:DBB重参
剪枝与重参第九课:DBB重参
简介剪枝与重参第九课:DBB重参
DBB重参
前言
手写AI推出的全新模型剪枝与重参课程。记录下个人学习笔记,仅供自己参考。
本次课程主要讲解DBB的重参。
课程大纲可看下面的思维导图
1. DBB
Diverse Branch Block(DBB)是继 ACNet 之后对卷积结构重参数化的又一次探索,即 ACNet v2。DBB 设计了一个类似 Inception 的模块,以多分支结构丰富卷积块的特征空间,各分支包括平均池化、多尺度卷积等。在推理前,对多分支结构进行重参数化,将其等价融合为单个主分支卷积:训练时借助多分支结构提升精度,推理时则保持与单个卷积相同的速度。
上图给出了设计的 DBB 结构示意图。类似 Inception,它采用 1x1,1x1-KxK,1x1-AVG 等组合方式对原始 KxK 卷积进行增强。对于 1x1-KxK 分支,设置中间通道数等于输入通道数并将 1x1 卷积初始化为 Identity 矩阵;其他分支则采用常规方式初始化。
此外,在每个卷积后都添加 BN 层,用于提供训练时的非线性,这对于性能提升很有必要。由于推理时 BN 是一个线性变换,它可以通过下文的变换 I 融合进卷积,因此不会增加推理开销。
2. DBB的六种变换
在推理阶段,DBB 通过以下 6 种变换,将多分支结构等价转换为一个常规卷积,如下图所示:
2.1 Transform I: a conv for conv-BN
变换I:将卷积+BN 融合为一个卷积
def transI_fusebn(kernel, bn):
    """Transform I: fold an inference-mode BatchNorm into the preceding bias-free conv.

    BN(conv_W(x)) = gamma * (conv_W(x) - mean) / std + beta
                  = conv_{W * gamma / std}(x) + (beta - mean * gamma / std)

    Args:
        kernel: conv weight, shape (out_channels, in_channels // groups, kH, kW).
        bn: BatchNorm-like object exposing ``weight``, ``bias``, ``running_mean``,
            ``running_var`` and ``eps``.

    Returns:
        (k, b): the fused kernel and the per-output-channel fused bias.
    """
    gamma = bn.weight
    std = (bn.running_var + bn.eps).sqrt()
    k = kernel * ((gamma / std).view(-1, 1, 1, 1))
    # The shift folded into the bias is the running MEAN, not the variance.
    b = bn.bias - bn.running_mean * gamma / std
    return k, b
2.2 Transform II:a conv for branch addition
变换II:将并行分支相加合并为一个卷积(卷积核与偏置分别相加)
def transII_addbranch(kernels, biases):
    """Transform II: merge parallel branches that share input and output shape.

    Convolution is linear, so summing the branch outputs is the same as
    convolving once with the summed kernel and summed bias. Scalars such as
    ``0`` may stand in for an absent branch.

    Returns:
        (sum of kernels, sum of biases)
    """
    return sum(kernels), sum(biases)
2.3 Transform III:a conv for sequential convolutions
变换III:将串联的 1x1 卷积与 KxK 卷积融合为一个卷积
def transIII_1x1_kxk(k1, b1, k2, b2, groups):
    """Transform III: merge a 1x1 conv (k1, b1) followed by a KxK conv (k2, b2).

    The merged kernel is obtained by convolving k2 with the transposed k1.
    The first conv's bias b1, pushed through k2, becomes ``sum(k2 * b1)`` per
    output channel. This is exact when the KxK conv sees b1 at every input
    position, i.e. no zero padding is inserted between the two stages.

    Args:
        k1, b1: 1x1 kernel (D, C // groups, 1, 1) and its bias (D,).
        k2, b2: KxK kernel (E, D // groups, K, K) and its bias (E,).
        groups: group count shared by both convolutions.

    Returns:
        (kernel, bias) of the single equivalent KxK conv.
    """
    if groups == 1:
        merged = F.conv2d(k2, k1.permute(1, 0, 2, 3))
        carried_bias = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3))
        return merged, carried_bias + b2

    # Grouped case: merge each group independently, then stack along the output axis.
    k1_t = k1.permute(1, 0, 2, 3)
    width1 = k1.size(0) // groups
    width2 = k2.size(0) // groups
    kernel_parts, bias_parts = [], []
    for g in range(groups):
        lo1, hi1 = g * width1, (g + 1) * width1
        lo2, hi2 = g * width2, (g + 1) * width2
        k2_part = k2[lo2:hi2, :, :, :]
        kernel_parts.append(F.conv2d(k2_part, k1_t[:, lo1:hi1, :, :]))
        bias_parts.append((k2_part * b1[lo1:hi1].reshape(1, -1, 1, 1)).sum((1, 2, 3)))
    merged, carried_bias = transIV_depthconcat(kernel_parts, bias_parts)
    return merged, carried_bias + b2
2.4 Transform IV:a conv for depth concatenation
变换IV:将按输出通道拼接(depth concatenation)的卷积合并为一个卷积
def transIV_depthconcat(kernels, biases):
    """Transform IV: stack kernels/biases along the output-channel axis (dim 0).

    Returns:
        (concatenated kernels, concatenated biases)
    """
    stacked_kernel = torch.cat(kernels, dim=0)
    stacked_bias = torch.cat(biases)
    return stacked_kernel, stacked_bias
2.5 Transform V:a conv for average pooling
变换V:将平均池化转换为等价的卷积
def transV_avg(channels, kernel_size, groups):
    """Transform V: build a grouped conv kernel equivalent to KxK average pooling.

    Output channel ``i`` reads only input channel ``i % (channels // groups)``
    of its group, with every tap equal to ``1 / kernel_size**2``.

    Returns:
        Tensor of shape (channels, channels // groups, kernel_size, kernel_size).
    """
    per_group_in = channels // groups
    weight = torch.zeros((channels, per_group_in, kernel_size, kernel_size))
    out_idx = np.arange(channels)
    in_idx = np.tile(np.arange(per_group_in), groups)
    weight[out_idx, in_idx, :, :] = 1. / kernel_size ** 2
    return weight
2.6 Transform VI:a conv for multi-scale convolutions
变换VI:多尺度卷积——将小尺寸卷积核(如 1x1)零填充到 KxK,以便与其他分支相加
def transVI_multiscale(kernel, target_kerne_size):
    """Transform VI: zero-pad a smaller kernel (e.g. 1x1) to a square target size.

    Args:
        kernel: tensor of shape (out, in // groups, kH, kW).
        target_kerne_size: side length of the square target kernel. The
            parameter name is kept unchanged so existing keyword callers still work.

    Returns:
        The centred, zero-padded kernel. It is exactly target-sized when
        (target - kH) and (target - kW) are even.
    """
    H_pixels_to_pad = (target_kerne_size - kernel.size(2)) // 2
    W_pixels_to_pad = (target_kerne_size - kernel.size(3)) // 2
    # F.pad lists padding from the LAST dim backwards: (left, right, top, bottom),
    # so the width pair must come first.
    return F.pad(kernel, [W_pixels_to_pad, W_pixels_to_pad, H_pixels_to_pad, H_pixels_to_pad])
3. DBB特殊结构
3.1 具有Identity性质的1x1Conv2d
DBB 网络中还有一种具有 Identity 性质的 1x1 卷积模块。它的可学习权重初始化为 0,前向时会在权重上加上单位矩阵,因此在初始状态下等价于恒等映射。其实现如下:
class IdentityBasedConv1x1(nn.Conv2d):
    """Grouped 1x1 convolution whose effective kernel is ``weight + identity``.

    The learnable ``weight`` starts at zero, so the module is initially an
    exact identity mapping. ``id_tensor`` is a plain attribute rather than a
    registered buffer; it is moved to the weight's device on every use.
    """

    def __init__(self, channels, groups=1):
        super().__init__(in_channels=channels, out_channels=channels, kernel_size=1,
                         stride=1, padding=0, groups=groups, bias=False)
        assert channels % groups == 0
        per_group = channels // groups
        # Output channel i maps to input channel i % per_group within its group.
        eye = np.zeros((channels, per_group, 1, 1))
        eye[np.arange(channels), np.arange(channels) % per_group, 0, 0] = 1
        self.id_tensor = torch.from_numpy(eye).type_as(self.weight)
        nn.init.zeros_(self.weight)

    def forward(self, input):
        return F.conv2d(input, self.get_actual_kernel(), None, stride=1, padding=0,
                        dilation=self.dilation, groups=self.groups)

    def get_actual_kernel(self):
        """Return the kernel actually applied: learnable weight plus identity."""
        return self.weight + self.id_tensor.to(self.weight.device)
3.2 BN+Pad
BN 层加 Pad:先做 BN,再用 BN 对零输入的输出值(β − μ·γ/σ)而不是 0 来填充边界。这等价于在其前面的无偏置卷积的输入上做零填充,因此与融合后卷积对偏置的处理一致,从而保证重参前后输出相同。其实现如下:
class BNAndPadLayer(nn.Module):
    """BatchNorm2d followed by border padding filled with the BN response to zero.

    Filling the border with ``BN(0)`` instead of ``0`` matches zero-padding the
    input of a preceding bias-free conv, which keeps the KxK conv that follows
    fusable. The fill values always use the running statistics, even in
    training mode. ``weight``/``bias``/``running_mean``/``running_var``/``eps``
    forward to the inner ``bn``, so ``transI_fusebn`` can consume this layer
    directly.
    """

    def __init__(self,
                 pad_pixels,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features=num_features, eps=eps, momentum=momentum,
                                 affine=affine, track_running_stats=track_running_stats)
        self.pad_pixels = pad_pixels

    def forward(self, input):
        output = self.bn(input)
        p = self.pad_pixels
        if not p > 0:
            return output
        bn = self.bn
        std = torch.sqrt(bn.running_var + bn.eps)
        if bn.affine:
            fill = bn.bias.detach() - bn.running_mean * bn.weight.detach() / std
        else:
            fill = - bn.running_mean / std
        output = F.pad(output, [p] * 4)
        fill = fill.view(1, -1, 1, 1)
        output[:, :, 0:p, :] = fill     # top
        output[:, :, -p:, :] = fill     # bottom
        output[:, :, :, 0:p] = fill     # left
        output[:, :, :, -p:] = fill     # right
        return output

    @property
    def weight(self):
        return self.bn.weight

    @property
    def bias(self):
        return self.bn.bias

    @property
    def running_mean(self):
        return self.bn.running_mean

    @property
    def running_var(self):
        return self.bn.running_var

    @property
    def eps(self):
        return self.bn.eps
4. DBB网络搭建
4.1 conv_bn
先写一个辅助函数,构建由 conv(无 bias)和 BN 组成的 nn.Sequential:
def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0, padding_mode='zeros', dilation=1, groups=1):
    """Build ``nn.Sequential`` with children ``conv`` (bias-free Conv2d) and ``bn`` (BatchNorm2d).

    The conv has no bias because the following BN provides the shift; the pair
    is later folded into a single conv by ``transI_fusebn``.
    """
    parts = (
        ('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                           kernel_size=kernel_size, stride=stride, padding=padding,
                           dilation=dilation, groups=groups, bias=False,
                           padding_mode=padding_mode)),
        ('bn', nn.BatchNorm2d(num_features=out_channels, affine=True)),
    )
    block = nn.Sequential()
    for child_name, child in parts:
        block.add_module(child_name, child)
    return block
4.2 branch
各分支的定义(构造函数)如下:
class DiverseBranchBlock(nn.Module):
    """Diverse Branch Block: a KxK conv trained as several parallel branches.

    Training-time branches (summed in ``forward``):
      * ``dbb_origin``  - KxK conv + BN
      * ``dbb_1x1``     - 1x1 conv + BN (only when ``groups < out_channels``)
      * ``dbb_avg``     - [1x1 conv + BN-and-pad] + KxK avg pooling + BN
      * ``dbb_1x1_kxk`` - 1x1 conv (identity-based when internal == in_channels)
                          + BN-and-pad + KxK conv + BN
    With ``deploy=True`` only the single fused conv ``dbb_reparam`` is built.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, internal_channels_1x1_3x3=None, deploy=False, nonlinear=None):
        super().__init__()
        self.deploy = deploy
        self.nonlinear = nn.Identity() if nonlinear is None else nonlinear
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.groups = groups
        # Branches are fused into one centred KxK kernel, so padding must be K // 2.
        assert padding == kernel_size // 2
        if deploy:
            self.dbb_reparam = nn.Conv2d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True
            )
        else:
            # Must be named dbb_origin: forward() and the fusion code look it up by this name.
            self.dbb_origin = conv_bn(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, dilation=dilation, groups=groups
            )
            self.dbb_avg = nn.Sequential()
            if groups < out_channels:
                # bias=False: the fusion only folds the following BN, so a conv bias
                # here would be silently dropped by re-parameterisation.
                self.dbb_avg.add_module(
                    'conv', nn.Conv2d(
                        in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                        stride=1, padding=0, groups=groups, bias=False
                    )
                )
                self.dbb_avg.add_module(
                    'bn', BNAndPadLayer(pad_pixels=padding, num_features=out_channels)
                )
                self.dbb_avg.add_module(
                    'avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0)
                )
                # Must be named dbb_1x1: forward() and the fusion check for this name.
                self.dbb_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels,
                                       kernel_size=1, stride=stride, padding=0, groups=groups)
            else:
                self.dbb_avg.add_module(
                    'avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=padding)
                )
            self.dbb_avg.add_module('avgbn', nn.BatchNorm2d(out_channels))

            if internal_channels_1x1_3x3 is None:
                internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels
            self.dbb_1x1_kxk = nn.Sequential()
            if internal_channels_1x1_3x3 == in_channels:
                self.dbb_1x1_kxk.add_module(
                    'idconv1', IdentityBasedConv1x1(channels=in_channels, groups=groups))
            else:
                self.dbb_1x1_kxk.add_module(
                    'conv1', nn.Conv2d(in_channels=in_channels,
                                       out_channels=internal_channels_1x1_3x3,
                                       kernel_size=1, stride=1, padding=0,
                                       groups=groups, bias=False))
            self.dbb_1x1_kxk.add_module(
                'bn1', BNAndPadLayer(pad_pixels=padding,
                                     num_features=internal_channels_1x1_3x3,
                                     affine=True))
            # bias=False for the same reason as dbb_avg.conv: only bn2 is folded.
            self.dbb_1x1_kxk.add_module(
                'conv2', nn.Conv2d(in_channels=internal_channels_1x1_3x3,
                                   out_channels=out_channels,
                                   kernel_size=kernel_size, stride=stride, padding=0,
                                   groups=groups, bias=False))
            self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels))
4.3 forward
前向传播的实现:各分支输出相加后经过非线性激活;若已完成重参,则直接使用融合后的卷积:
class DiverseBranchBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, internal_channels_1x1_3x3=None, deploy=False, nonlinear=None):
        super().__init__()
        ...  # branch construction elided here; see section 4.2

    def forward(self, inputs):
        """Sum every training-time branch, then apply ``self.nonlinear``.

        If the fused conv ``dbb_reparam`` exists (deploy mode), only it is used.
        """
        if hasattr(self, 'dbb_reparam'):
            return self.nonlinear(self.dbb_reparam(inputs))
        out = self.dbb_origin(inputs)
        # The plain 1x1 branch only exists when groups < out_channels.
        if hasattr(self, 'dbb_1x1'):
            out += self.dbb_1x1(inputs)
        out += self.dbb_avg(inputs)
        # The last branch is 1x1 -> KxK; dbb_1x1 must not be added a second time.
        out += self.dbb_1x1_kxk(inputs)
        return self.nonlinear(out)
4.4 重参的实现
重参实现过程:get_equivalent_kernel_bias 利用上述变换计算等价的卷积核与偏置,switch_to_deploy 据此构建单个卷积,并删除训练时的各分支:
class DiverseBranchBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, internal_channels_1x1_3x3=None, deploy=False, nonlinear=None):
        super().__init__()
        ...  # branch construction elided here; see section 4.2

    def forward(self, inputs):
        ...  # elided here; see section 4.3

    def switch_to_deploy(self):
        """Replace all training-time branches with one fused KxK conv ``dbb_reparam``."""
        if hasattr(self, 'dbb_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.dbb_reparam = nn.Conv2d(in_channels=self.dbb_origin.conv.in_channels,
                                     out_channels=self.dbb_origin.conv.out_channels,
                                     kernel_size=self.dbb_origin.conv.kernel_size,
                                     stride=self.dbb_origin.conv.stride,
                                     padding=self.dbb_origin.conv.padding,
                                     dilation=self.dbb_origin.conv.dilation,
                                     groups=self.dbb_origin.conv.groups,
                                     bias=True)
        self.dbb_reparam.weight.data = kernel
        self.dbb_reparam.bias.data = bias
        # detach_() works in place; plain detach() returns a new tensor and
        # leaves the parameters untouched.
        for para in self.parameters():
            para.detach_()
        self.__delattr__('dbb_origin')
        self.__delattr__('dbb_avg')
        if hasattr(self, 'dbb_1x1'):
            self.__delattr__('dbb_1x1')
        self.__delattr__('dbb_1x1_kxk')
        # Keep the flag consistent with the module's new structure.
        self.deploy = True

    def get_equivalent_kernel_bias(self):
        """Return ``(kernel, bias)`` of one KxK conv equal to the sum of all branches.

        BN layers are folded with their running statistics, so the result
        matches the block in eval mode.
        """
        # Branch 1: KxK conv + its own BN (Transform I).
        k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight,
                                           self.dbb_origin.bn)
        # Branch 2: 1x1 conv + BN, zero-padded to KxK (Transforms I and VI).
        if hasattr(self, 'dbb_1x1'):
            k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight,
                                         self.dbb_1x1.bn)
            k_1x1 = transVI_multiscale(k_1x1,
                                       self.kernel_size)
        else:
            # Scalar zeros are neutral when the branches are summed below.
            k_1x1, b_1x1 = 0, 0
        # Branch 3: 1x1 conv + BN-and-pad, then KxK conv + BN (Transforms I and III).
        if hasattr(self.dbb_1x1_kxk, 'idconv1'):
            k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
        else:
            k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
        k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(k_1x1_kxk_first,
                                                         self.dbb_1x1_kxk.bn1)
        k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight,
                                                           self.dbb_1x1_kxk.bn2)
        k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first,
                                                              b_1x1_kxk_first,
                                                              k_1x1_kxk_second,
                                                              b_1x1_kxk_second,
                                                              groups=self.groups)
        # Branch 4: [1x1 conv + BN-and-pad] + avg pooling + BN (Transforms V, I, III).
        k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
        k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device),
                                                           self.dbb_avg.avgbn)
        if hasattr(self.dbb_avg, 'conv'):
            k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(self.dbb_avg.conv.weight,
                                                             self.dbb_avg.bn)
            k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first,
                                                                  b_1x1_avg_first,
                                                                  k_1x1_avg_second,
                                                                  b_1x1_avg_second,
                                                                  groups=self.groups)
        else:
            k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
        # Final merge of the four parallel branches (Transform II).
        return transII_addbranch((k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged),
                                 (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged))
4.5 模型导出
DBB网络模型的导出和对比
if __name__ == '__main__':
    x = torch.randn(1, 4, 224, 224)
    model = DiverseBranchBlock(in_channels=4, out_channels=32, kernel_size=3, stride=1, padding=1,
                               groups=2, deploy=False)
    # Give every BN non-trivial statistics and affine parameters so the
    # equivalence check below actually exercises the BN folding.
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            nn.init.uniform_(module.running_mean, 0, 0.1)
            nn.init.uniform_(module.running_var, 0, 0.2)
            nn.init.uniform_(module.weight, 0, 0.3)
            nn.init.uniform_(module.bias, 0, 0.4)
    # The fusion folds BN running statistics, so compare in eval mode.
    model.eval()
    out = model(x)
    torch.onnx.export(model=model, args=x, f='./DBB.onnx', verbose=False)
    model.switch_to_deploy()
    deployout = model(x)
    torch.onnx.export(model=model, args=x, f='./DBB-deploy.onnx', verbose=False)
    print('\nDifference between the outputs of the origin-DBB and rep-DBB is: {}\n'.format(
        ((deployout - out) ** 2).sum()
    ))
5. 完整示例代码
DBB网络重参的完整示例代码如下:
import torch
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
            padding_mode='zeros'):
    """Return ``nn.Sequential`` with children ``conv`` (Conv2d, no bias) and ``bn`` (BatchNorm2d).

    The conv is bias-free because the BN that follows supplies the shift;
    ``transI_fusebn`` later folds the pair into one conv.
    """
    children = [
        ('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                           kernel_size=kernel_size, stride=stride, padding=padding,
                           dilation=dilation, groups=groups, bias=False,
                           padding_mode=padding_mode)),
        ('bn', nn.BatchNorm2d(num_features=out_channels, affine=True)),
    ]
    seq = nn.Sequential()
    for label, layer in children:
        seq.add_module(label, layer)
    return seq
def transI_fusebn(kernel, bn):
    """Transform I: fold an inference-mode BatchNorm into the conv kernel feeding it.

    With std = sqrt(running_var + eps):
        k = kernel * (gamma / std)              (scaled per output channel)
        b = beta - running_mean * gamma / std

    Args:
        kernel: conv weight, shape (out_channels, in_channels // groups, kH, kW).
        bn: BatchNorm-like object exposing weight, bias, running_mean, running_var, eps.

    Returns:
        (k, b): fused kernel and fused per-output-channel bias.
    """
    gamma, beta = bn.weight, bn.bias
    std = (bn.running_var + bn.eps).sqrt()
    fused_kernel = kernel * (gamma / std).reshape(-1, 1, 1, 1)
    fused_bias = beta - bn.running_mean * gamma / std
    return fused_kernel, fused_bias
def transII_addbranch(kernels, biases):
    """Transform II: merge parallel same-shape branches by summation.

    Args:
        kernels: tuple of kernels (scalars such as 0 stand in for absent branches).
        biases: tuple of biases, same convention.

    Returns:
        (summed kernel, summed bias)
    """
    return sum(kernels), sum(biases)
def transIII_1x1_kxk(k1, b1, k2, b2, groups):
    """Transform III: fuse a 1x1 conv (k1, b1) followed by a KxK conv (k2, b2).

    Kernel: k2 convolved with k1 transposed (in/out swapped). Bias: b1 pushed
    through k2 (``sum(k2 * b1)``) plus b2. This is exact when the KxK conv sees
    b1 at every input position, i.e. there is no zero padding between the stages.

    Args:
        k1, b1: 1x1 kernel (D, C // groups, 1, 1) and its bias (D,).
        k2, b2: KxK kernel (E, D // groups, K, K) and its bias (E,).
        groups: group count shared by both convs.

    Returns:
        (kernel, bias) of the equivalent single KxK conv.
    """
    if groups == 1:
        fused = F.conv2d(k2, k1.permute(1, 0, 2, 3))
        pushed_bias = (k2 * b1.view(1, -1, 1, 1)).sum((1, 2, 3))
        return fused, pushed_bias + b2

    # Grouped case: fuse each group separately and stack along the output axis.
    k1_swapped = k1.permute(1, 0, 2, 3)
    span1 = k1.size(0) // groups
    span2 = k2.size(0) // groups
    fused_parts, bias_parts = [], []
    for g in range(groups):
        s1 = slice(g * span1, (g + 1) * span1)
        s2 = slice(g * span2, (g + 1) * span2)
        k2_group = k2[s2, :, :, :]
        fused_parts.append(F.conv2d(k2_group, k1_swapped[:, s1, :, :]))
        bias_parts.append((k2_group * b1[s1].view(1, -1, 1, 1)).sum((1, 2, 3)))
    fused, pushed_bias = transIV_depthconcat(fused_parts, bias_parts)
    return fused, pushed_bias + b2
def transIV_depthconcat(kernels, biases):
    """Transform IV: concatenate kernels and biases along the output-channel axis.

    Args:
        kernels: list of kernels, each (out_i, in, kH, kW).
        biases: list of 1-D biases, each (out_i,).

    Returns:
        (kernel of shape (sum(out_i), in, kH, kW), bias of shape (sum(out_i),))
    """
    joined_kernel = torch.cat(kernels, dim=0)
    joined_bias = torch.cat(biases)
    return joined_kernel, joined_bias
def transV_avg(channels, kernel_size, groups):
    """Transform V: express KxK average pooling as a grouped convolution kernel.

    Each output channel ``i`` averages only input channel
    ``i % (channels // groups)`` inside its group, weighting every tap by
    ``1 / kernel_size**2``.

    Returns:
        Tensor of shape (channels, channels // groups, kernel_size, kernel_size).
    """
    in_per_group = channels // groups
    avg_kernel = torch.zeros((channels, in_per_group, kernel_size, kernel_size))
    rows = np.arange(channels)
    cols = np.tile(np.arange(in_per_group), groups)
    avg_kernel[rows, cols, :, :] = 1.0 / kernel_size ** 2
    return avg_kernel
# This has not been tested with non-square kernels (kernel.size(2) != kernel.size(3)) nor even-size kernels
def transVI_multiscale(kernel, target_kernel_size):
# Calculate the number of pixels to pad on the height dimension
H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2
# Calculate the number of pixels to pad on the width dimension
W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2
return F.pad(kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad])
class IdentityBasedConv1x1(nn.Conv2d):
    """Grouped, bias-free 1x1 conv whose effective kernel is ``weight + identity``.

    ``weight`` is zero-initialised, so the module starts out as an exact
    identity map and learns a residual on top of it. ``id_tensor`` is a plain
    attribute (not a registered buffer) and is moved to the weight's device
    each time it is used.
    """

    def __init__(self, channels, groups=1):
        super().__init__(in_channels=channels, out_channels=channels, kernel_size=1,
                         stride=1, padding=0, groups=groups, bias=False)
        # channels must split evenly into groups.
        assert channels % groups == 0
        width = channels // groups
        # Identity pattern in grouped layout: output i <- input (i % width) of its group.
        identity = np.zeros((channels, width, 1, 1))
        channel_ids = np.arange(channels)
        identity[channel_ids, channel_ids % width, 0, 0] = 1
        self.id_tensor = torch.from_numpy(identity).type_as(self.weight)
        nn.init.zeros_(self.weight)

    def forward(self, input):
        effective = self.get_actual_kernel()
        return F.conv2d(input, effective, None, stride=1, padding=0,
                        dilation=self.dilation, groups=self.groups)

    def get_actual_kernel(self):
        """Return the kernel used in ``forward``: learnable weight plus identity."""
        return self.weight + self.id_tensor.to(self.weight.device)
class BNAndPadLayer(nn.Module):
    """BatchNorm2d, then pad the border with the BN output for a zero input.

    With a bias-free conv in front, padding by ``BN(0)`` is equivalent to
    zero-padding that conv's input, which keeps a following un-padded KxK
    conv fusable. Fill values always use running statistics, even in
    training mode. The ``weight``, ``bias``, ``running_mean``,
    ``running_var`` and ``eps`` properties forward to the inner ``bn`` so
    ``transI_fusebn`` accepts this layer like a plain BatchNorm.
    """

    def __init__(self,
                 pad_pixels,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features=num_features, eps=eps, momentum=momentum,
                                 affine=affine, track_running_stats=track_running_stats)
        self.pad_pixels = pad_pixels

    def forward(self, input):
        normed = self.bn(input)
        n = self.pad_pixels
        if not n > 0:
            return normed
        norm = self.bn
        denom = torch.sqrt(norm.running_var + norm.eps)
        # BN(0): beta - mean * gamma / std when affine, else -mean / std.
        if norm.affine:
            border = norm.bias.detach() - norm.running_mean * norm.weight.detach() / denom
        else:
            border = - norm.running_mean / denom
        normed = F.pad(normed, [n] * 4)
        border = border.view(1, -1, 1, 1)
        normed[:, :, 0:n, :] = border
        normed[:, :, -n:, :] = border
        normed[:, :, :, 0:n] = border
        normed[:, :, :, -n:] = border
        return normed

    @property
    def weight(self):
        return self.bn.weight

    @property
    def bias(self):
        return self.bn.bias

    @property
    def running_mean(self):
        return self.bn.running_mean

    @property
    def running_var(self):
        return self.bn.running_var

    @property
    def eps(self):
        return self.bn.eps
class DiverseBranchBlock(nn.Module):
    """Diverse Branch Block (DBB): a KxK conv trained as parallel branches, fused for inference.

    Training-time branches (summed in ``forward``, then passed through ``nonlinear``):
      * ``dbb_origin``  - KxK conv + BN
      * ``dbb_1x1``     - 1x1 conv + BN (only when ``groups < out_channels``)
      * ``dbb_avg``     - [1x1 conv + BN-and-pad] + KxK average pooling + BN
      * ``dbb_1x1_kxk`` - 1x1 conv (identity-based when internal == in_channels)
                          + BN-and-pad + KxK conv + BN
    ``switch_to_deploy`` folds them all into a single conv ``dbb_reparam``;
    constructing with ``deploy=True`` builds only that conv.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, groups: as for nn.Conv2d.
        padding: must equal ``kernel_size // 2`` (asserted).
        internal_channels_1x1_3x3: width between the 1x1 and KxK convs of the
            1x1-KxK branch. Defaults to ``in_channels``, or ``2 * in_channels``
            when ``groups >= out_channels``.
        deploy: build the fused conv directly.
        nonlinear: module applied after the branch sum (``nn.Identity`` if None).
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1,
                 internal_channels_1x1_3x3=None,
                 deploy=False, nonlinear=None):
        super().__init__()
        self.deploy = deploy
        if nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            self.nonlinear = nonlinear
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.groups = groups
        # All branches are fused into one centred KxK kernel.
        assert padding == kernel_size // 2
        if deploy:
            self.dbb_reparam = nn.Conv2d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, dilation=dilation,
                groups=groups, bias=True)
        else:
            self.dbb_origin = conv_bn(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding,
                dilation=dilation, groups=groups)

            self.dbb_avg = nn.Sequential()
            if groups < out_channels:
                self.dbb_avg.add_module(
                    'conv', nn.Conv2d(in_channels=in_channels,
                                      out_channels=out_channels,
                                      kernel_size=1,
                                      stride=1,
                                      padding=0,
                                      groups=groups,
                                      bias=False))
                self.dbb_avg.add_module(
                    'bn', BNAndPadLayer(pad_pixels=padding,
                                        num_features=out_channels))
                self.dbb_avg.add_module(
                    'avg', nn.AvgPool2d(kernel_size=kernel_size,
                                        stride=stride,
                                        padding=0))
                self.dbb_1x1 = conv_bn(in_channels=in_channels,
                                       out_channels=out_channels,
                                       kernel_size=1,
                                       stride=stride,
                                       padding=0,
                                       groups=groups)
            else:
                self.dbb_avg.add_module('avg',
                                        nn.AvgPool2d(kernel_size=kernel_size,
                                                     stride=stride,
                                                     padding=padding))
            self.dbb_avg.add_module('avgbn',
                                    nn.BatchNorm2d(out_channels))

            if internal_channels_1x1_3x3 is None:
                # For MobileNet-style (depthwise) blocks, 2x internal channels work better.
                internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels

            self.dbb_1x1_kxk = nn.Sequential()
            if internal_channels_1x1_3x3 == in_channels:
                self.dbb_1x1_kxk.add_module('idconv1',
                                            IdentityBasedConv1x1(channels=in_channels, groups=groups))
            else:
                self.dbb_1x1_kxk.add_module('conv1',
                                            nn.Conv2d(in_channels=in_channels,
                                                      out_channels=internal_channels_1x1_3x3,
                                                      kernel_size=1,
                                                      stride=1,
                                                      padding=0,
                                                      groups=groups,
                                                      bias=False))
            self.dbb_1x1_kxk.add_module('bn1',
                                        BNAndPadLayer(pad_pixels=padding,
                                                      num_features=internal_channels_1x1_3x3,
                                                      affine=True))
            self.dbb_1x1_kxk.add_module('conv2',
                                        nn.Conv2d(in_channels=internal_channels_1x1_3x3,
                                                  out_channels=out_channels,
                                                  kernel_size=kernel_size,
                                                  stride=stride,
                                                  padding=0,
                                                  groups=groups,
                                                  bias=False))
            self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels))

    def forward(self, inputs):
        """Sum all branches (or run the fused conv) and apply ``self.nonlinear``."""
        if hasattr(self, 'dbb_reparam'):
            return self.nonlinear(self.dbb_reparam(inputs))
        out = self.dbb_origin(inputs)
        if hasattr(self, 'dbb_1x1'):
            out += self.dbb_1x1(inputs)
        out += self.dbb_avg(inputs)
        out += self.dbb_1x1_kxk(inputs)
        return self.nonlinear(out)

    def switch_to_deploy(self):
        """Fuse every branch into ``dbb_reparam`` and delete the training-time branches."""
        if hasattr(self, 'dbb_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.dbb_reparam = nn.Conv2d(in_channels=self.dbb_origin.conv.in_channels,
                                     out_channels=self.dbb_origin.conv.out_channels,
                                     kernel_size=self.dbb_origin.conv.kernel_size,
                                     stride=self.dbb_origin.conv.stride,
                                     padding=self.dbb_origin.conv.padding,
                                     dilation=self.dbb_origin.conv.dilation,
                                     groups=self.dbb_origin.conv.groups,
                                     bias=True)
        self.dbb_reparam.weight.data = kernel
        self.dbb_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('dbb_origin')
        self.__delattr__('dbb_avg')
        if hasattr(self, 'dbb_1x1'):
            self.__delattr__('dbb_1x1')
        self.__delattr__('dbb_1x1_kxk')
        # Keep the flag consistent with the module's new (fused) structure.
        self.deploy = True

    def get_equivalent_kernel_bias(self):
        """Return ``(kernel, bias)`` of one KxK conv equal to the branch sum (eval-mode BN)."""
        # ================== 1. KxK branch: Transform I (conv + BN)
        k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight,
                                           self.dbb_origin.bn)
        # ================== 2. 1x1 branch
        if hasattr(self, 'dbb_1x1'):
            # Transform I: fuse conv + BN
            k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight,
                                         self.dbb_1x1.bn)
            # Transform VI: zero-pad the 1x1 kernel to KxK
            k_1x1 = transVI_multiscale(k_1x1,
                                       self.kernel_size)
        else:
            # Scalar zeros are neutral in the final summation.
            k_1x1, b_1x1 = 0, 0
        # ================== 3. 1x1 -> KxK branch
        if hasattr(self.dbb_1x1_kxk, 'idconv1'):
            k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
        else:
            k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
        # Transform I: fuse conv + BN (first stage)
        k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(k_1x1_kxk_first,
                                                         self.dbb_1x1_kxk.bn1)
        # Transform I: fuse conv + BN (second stage)
        k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight,
                                                           self.dbb_1x1_kxk.bn2)
        # Transform III: merge the sequential 1x1 and KxK convs
        k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first,
                                                              b_1x1_kxk_first,
                                                              k_1x1_kxk_second,
                                                              b_1x1_kxk_second,
                                                              groups=self.groups)
        # ================== 4. [1x1 ->] avg-pool branch
        # Transform V: average pooling as a conv kernel
        k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
        # Transform I: fold the BN after pooling
        k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device),
                                                           self.dbb_avg.avgbn)
        if hasattr(self.dbb_avg, 'conv'):
            # Transform I: fuse the leading 1x1 conv + BN-and-pad
            k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(self.dbb_avg.conv.weight,
                                                             self.dbb_avg.bn)
            # Transform III: merge 1x1 conv with the pooling kernel
            k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first,
                                                                  b_1x1_avg_first,
                                                                  k_1x1_avg_second,
                                                                  b_1x1_avg_second,
                                                                  groups=self.groups)
        else:
            k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
        # ================== 5. Final merge: Transform II (sum the branches)
        return transII_addbranch((k_origin,
                                  k_1x1,
                                  k_1x1_kxk_merged,
                                  k_1x1_avg_merged),
                                 (b_origin,
                                  b_1x1,
                                  b_1x1_kxk_merged,
                                  b_1x1_avg_merged))
if __name__ == '__main__':
    x = torch.randn(1, 4, 224, 224)
    model = DiverseBranchBlock(in_channels=4, out_channels=32, kernel_size=3, stride=1, padding=3 // 2,
                               groups=2, deploy=False)
    # Give every BN non-trivial statistics and affine parameters so the
    # equivalence check below actually exercises the BN folding.
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            nn.init.uniform_(module.running_mean, 0, 0.1)
            nn.init.uniform_(module.running_var, 0, 0.2)
            nn.init.uniform_(module.weight, 0, 0.3)
            nn.init.uniform_(module.bias, 0, 0.4)
    # The fusion folds BN running statistics, so compare in eval mode.
    model.eval()
    out = model(x)
    torch.onnx.export(model=model, args=x, f='../DBB.onnx',
                      verbose=False)
    model.switch_to_deploy()
    deployout = model(x)
    torch.onnx.export(
        model=model, args=x, f='../DBB-deploy.onnx',
        verbose=False)
    print('\nDifference between the outputs of the origin-DBB and rep-DBB is: {}\n'.format(
        ((deployout - out) ** 2).sum()
    ))
总结
本次课程学习了 DBB 网络的重参,与 ACNet 的卷积替换相比,DBB 网络提出了更为复杂的类似 Inception 的多分支结构,并在推理阶段采用6种变换进行重参数化,融合成一个主分支,加快推理速度。
风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。