
Pruning and Re-parameterization, Lesson 9: DBB Re-parameterization

爱听歌的周童鞋 2023-07-08 20:00:03

DBB Re-parameterization

Preface

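This is the new model pruning and re-parameterization course released by 手写AI. I am recording my personal study notes here, for my own reference only.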

This lesson covers the re-parameterization of DBB.

The course outline is shown in the mind map below.

[Figure: course outline mind map]

1. DBB

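Diverse Branch Block (DBB) is another exploration of structural re-parameterization following ACNet, and is also referred to as ACNet v2. DBB designs an Inception-like block that enriches the feature space of a convolution block with multiple branches, including average pooling and convolutions at several scales. Before inference, the multi-branch structure is re-parameterized and fused into a single main branch. Inference therefore costs the same as a single convolution, while the richer training-time structure also brings an accuracy gain.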

[Figure: structure of the Diverse Branch Block]

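The figure above shows the design of the DBB block. Like Inception, it enhances the original KxK convolution with branch combinations such as 1x1, 1x1-KxK and 1x1-AVG. For the 1x1-KxK branch, the number of intermediate channels is set equal to the number of input channels, and the 1x1 convolution is initialized as an identity matrix; the other branches are initialized in the usual way.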

In addition, a BN layer is added after every convolution to provide training-time nonlinearity, which is necessary for the performance gain.
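A minimal sketch of the identity initialization described above, assuming plain PyTorch: `torch.nn.init.dirac_` sets `weight[i, i, 0, 0] = 1` and all other entries to 0, so a 1x1 conv whose intermediate channels equal its input channels starts out passing its input through unchanged. This is only an illustration of the idea; the DBB code itself uses `IdentityBasedConv1x1` (Section 3.1) for this purpose.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
channels = 8

# 1x1 conv with as many output (intermediate) channels as input channels,
# as in the 1x1-KxK branch
conv1x1 = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
# dirac_ writes an identity matrix into the (channels, channels, 1, 1) weight
nn.init.dirac_(conv1x1.weight)

x = torch.randn(2, channels, 5, 5)
with torch.no_grad():
    y = conv1x1(x)

print(torch.allclose(y, x))  # expected: True, the fresh 1x1 conv is an identity mapping
```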

2. The six DBB transformations

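For a regular convolutional network, DBB relies on six transformations to convert its training-time branches into a single convolution for inference, as shown in the figure below: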

[Figure: the six transformations used by DBB]

2.1 Transform I: a conv for conv-BN

A convolution followed by a BN layer is replaced by a single convolution with a bias.


def transI_fusebn(kernel, bn):
    gamma = bn.weight
    std   = (bn.running_var + bn.eps).sqrt()
    k     = kernel * ((gamma / std).view(-1, 1, 1, 1))
    b     = bn.bias - bn.running_mean * gamma / std
    return k, b
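Transform I works because BN in inference mode is a per-channel affine map. Writing sigma = sqrt(running_var + eps), BN(W * x) = gamma * (W * x - mu) / sigma + beta, so the fused kernel is W' = (gamma / sigma) W and the fused bias is b' = beta - mu * gamma / sigma. The sketch below checks this numerically; `fuse_conv_bn` is a hypothetical stand-alone helper, and the check assumes a bias-free conv and an eval-mode BN, as in DBB.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def fuse_conv_bn(kernel, bn):
    # Hypothetical helper: BN in eval mode is y = gamma * (x - mean) / std + beta per channel
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std
    fused_kernel = kernel * scale.view(-1, 1, 1, 1)  # scale each output filter
    fused_bias = bn.bias - bn.running_mean * scale    # the mean shift goes into the bias
    return fused_kernel, fused_bias

conv = nn.Conv2d(4, 6, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(6)
# Non-trivial statistics and affine parameters so that the check is meaningful
with torch.no_grad():
    bn.running_mean.uniform_(-0.5, 0.5)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
bn.eval()

x = torch.randn(2, 4, 8, 8)
with torch.no_grad():
    reference = bn(conv(x))
    k, b = fuse_conv_bn(conv.weight, bn)
    fused = F.conv2d(x, k, b, padding=1)

print(torch.allclose(reference, fused, atol=1e-5))  # expected: True
```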

2.2 Transform II: a conv for branch addition

Parallel convolution branches with the same configuration are merged by adding their kernels and biases.

def transII_addbranch(kernels, biases):
    k = sum(kernels)
    b = sum(biases)
    return k, b
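Transform II relies on convolution being linear: when parallel branches share the same kernel size, stride and padding, conv(x, K1) + conv(x, K2) = conv(x, K1 + K2), and the biases add in the same way. A small numerical check, assuming all branches use identical conv settings:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Three parallel 3x3 branches with identical shape, stride and padding
kernels = [torch.randn(6, 4, 3, 3) for _ in range(3)]
biases = [torch.randn(6) for _ in range(3)]
x = torch.randn(2, 4, 8, 8)

# Training-time view: run every branch and add the outputs
branch_sum = sum(F.conv2d(x, k, b, padding=1) for k, b in zip(kernels, biases))

# Inference-time view: a single conv with summed kernels and biases
merged = F.conv2d(x, sum(kernels), sum(biases), padding=1)

print(torch.allclose(branch_sum, merged, atol=1e-4))  # expected: True
```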

2.3 Transform III: a conv for sequential convolutions

A 1x1 convolution followed by a KxK convolution is merged into a single KxK convolution.


def transIII_1x1_kxk(k1, b1, k2, b2, groups):
    if groups == 1:
        k     = F.conv2d(k2, k1.permute(1, 0, 2, 3))
        b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3))
    else:
        
        k_slices = []
        b_slices = []
        k1_T = k1.permute(1, 0, 2, 3)

        k1_group_width = k1.size(0) // groups
        k2_group_width = k2.size(0) // groups
        for g in range(groups):
            k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :]
            k2_slice   = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :]
            k_slices.append(F.conv2d(k2_slice, k1_T_slice))
            b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3)))
        
        k, b_hat = transIV_depthconcat(k_slices, b_slices)
    return k, b_hat + b2
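Why Transform III works: the merged kernel is M[o, i] = sum_m K2[o, m] * K1[m, i], which can be computed by treating K2 as a batch of "images" and convolving it with the transposed 1x1 kernel K1^T. The constant 1x1 bias b1, pushed through the KxK conv, becomes sum over (m, u, v) of K2[o, m, u, v] * b1[m]. The sketch below re-implements that logic as a hypothetical helper `merge_1x1_kxk` and checks it with `groups=2`. It assumes the KxK conv uses no padding; with zero padding the border would miss the 1x1 bias, which is the reason DBB uses `BNAndPadLayer` (Section 3.2).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def merge_1x1_kxk(k1, b1, k2, b2, groups=1):
    # Hypothetical helper. k1: (mid, in/groups, 1, 1), k2: (out, mid/groups, K, K)
    k_slices, b_slices = [], []
    w1 = k1.size(0) // groups  # mid channels per group
    w2 = k2.size(0) // groups  # out channels per group
    for g in range(groups):
        k1_g = k1[g * w1:(g + 1) * w1]
        k2_g = k2[g * w2:(g + 1) * w2]
        # k2_g acts as the input, k1_g^T as the 1x1 kernel -> (w2, in/groups, K, K)
        k_slices.append(F.conv2d(k2_g, k1_g.permute(1, 0, 2, 3)))
        b_slices.append((k2_g * b1[g * w1:(g + 1) * w1].view(1, -1, 1, 1)).sum((1, 2, 3)))
    return torch.cat(k_slices), torch.cat(b_slices) + b2

in_ch, mid_ch, out_ch, K, groups = 4, 8, 6, 3, 2
k1, b1 = torch.randn(mid_ch, in_ch // groups, 1, 1), torch.randn(mid_ch)
k2, b2 = torch.randn(out_ch, mid_ch // groups, K, K), torch.randn(out_ch)
x = torch.randn(2, in_ch, 9, 9)

# Sequential form: 1x1 conv, then KxK conv without padding
sequential = F.conv2d(F.conv2d(x, k1, b1, groups=groups), k2, b2, groups=groups)
k, b = merge_1x1_kxk(k1, b1, k2, b2, groups=groups)
merged = F.conv2d(x, k, b, groups=groups)

print(torch.allclose(sequential, merged, atol=1e-3))  # expected: True
```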

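2.4 Transform IV: a conv for depth concatenation

Convolutions whose outputs are concatenated along the channel dimension are merged into one convolution by concatenating their kernels and biases.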


def transIV_depthconcat(kernels, biases):
    return torch.cat(kernels, dim=0), torch.cat(biases)
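Transform IV holds because each output channel of a conv depends only on its own filter, so stacking filters along dim 0 produces exactly the concatenated outputs. In DBB this is what lets `transIII_1x1_kxk` handle grouped convs group by group. A quick check, assuming both convs read the same input with the same stride and padding:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = torch.randn(2, 4, 8, 8)
# Two convs over the same input whose outputs are concatenated on dim=1
k_a, b_a = torch.randn(3, 4, 3, 3), torch.randn(3)
k_b, b_b = torch.randn(5, 4, 3, 3), torch.randn(5)

concat_of_outputs = torch.cat([F.conv2d(x, k_a, b_a, padding=1),
                               F.conv2d(x, k_b, b_b, padding=1)], dim=1)

# One conv with kernels stacked along the output-channel dim, biases likewise
merged = F.conv2d(x, torch.cat([k_a, k_b], dim=0), torch.cat([b_a, b_b]), padding=1)

print(torch.allclose(concat_of_outputs, merged, atol=1e-4))  # expected: True
```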

2.5 Transform V: a conv for average pooling

A KxK average pooling is expressed as a KxK convolution with a fixed kernel.


def transV_avg(channels, kernel_size, groups):
    input_dim = channels // groups
    k = torch.zeros((channels, input_dim, kernel_size, kernel_size))  
    k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1. / kernel_size**2
    return k
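The kernel built by `transV_avg` puts 1/K^2 at every spatial position of the (c, c) channel pair (within each group) and zeros elsewhere, so each output channel averages only its own input channel. A sketch that checks this against `nn.AvgPool2d`; `avg_kernel` is a hypothetical copy of the same construction. It assumes the default `count_include_pad=True`, which divides by K*K even at the borders and therefore matches a conv over a zero-padded input.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def avg_kernel(channels, kernel_size, groups):
    # Hypothetical copy of transV_avg
    input_dim = channels // groups
    k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
    # Output channel c reads input channel c % input_dim of its own group
    k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
    return k

channels, K, groups = 6, 3, 2
x = torch.randn(2, channels, 8, 8)

pool = nn.AvgPool2d(kernel_size=K, stride=1, padding=1)  # count_include_pad=True by default
as_conv = F.conv2d(x, avg_kernel(channels, K, groups), stride=1, padding=1, groups=groups)

print(torch.allclose(pool(x), as_conv, atol=1e-5))  # expected: True
```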

2.6 Transform VI: a conv for multi-scale convolutions

A smaller kernel (e.g. 1x1) is zero-padded to KxK so that it can be merged with a KxK convolution.

def transVI_multiscale(kernel, target_kernel_size):
    H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2
    W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2
    # F.pad takes [left, right, top, bottom]: the width (last) dimension comes first
    return F.pad(kernel, [W_pixels_to_pad, W_pixels_to_pad, H_pixels_to_pad, H_pixels_to_pad])
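Transform VI works because a 1x1 conv with padding 0 reads exactly the pixel that sits under the centre tap of a KxK conv with padding K // 2; all other taps of the padded kernel are zero. A sketch, assuming an odd target kernel size and matching strides; `pad_kernel` is a hypothetical stand-alone helper:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def pad_kernel(kernel, target_kernel_size):
    # Hypothetical helper: zero-pad a small kernel to target x target.
    # F.pad lists the last dimension first: [left, right, top, bottom].
    pad_h = (target_kernel_size - kernel.size(2)) // 2
    pad_w = (target_kernel_size - kernel.size(3)) // 2
    return F.pad(kernel, [pad_w, pad_w, pad_h, pad_h])

x = torch.randn(2, 4, 8, 8)
k1x1 = torch.randn(6, 4, 1, 1)

out_1x1 = F.conv2d(x, k1x1, padding=0)
out_3x3 = F.conv2d(x, pad_kernel(k1x1, 3), padding=1)

print(out_3x3.shape == out_1x1.shape, torch.allclose(out_1x1, out_3x3, atol=1e-5))  # expected: True True
```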

3. Special structures in DBB

3.1 A 1x1 Conv2d with the identity property

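The DBB network also contains a convolution module that behaves as an identity mapping at initialization. Its implementation is as follows: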

class IdentityBasedConv1x1(nn.Conv2d):
    def __init__(self, channels, groups=1):
        super().__init__(in_channels=channels,
                         out_channels=channels,
                         kernel_size=1,
                         stride=1,
                         padding=0,
                         groups=groups,
                         bias=False)
        assert channels % groups == 0
        input_dim = channels // groups
        id_value = np.zeros((channels, input_dim, 1, 1))
        for i in range(channels):
            id_value[i, i % input_dim, 0, 0] = 1
        self.id_tensor = torch.from_numpy(id_value).type_as(self.weight)
        nn.init.zeros_(self.weight)
    
    def forward(self, input):
        kernel = self.weight + self.id_tensor.to(self.weight.device)
        result = F.conv2d(input, kernel, None, stride=1, padding=0, dilation=self.dilation, groups=self.groups)
        return result

    def get_actual_kernel(self):
        return self.weight + self.id_tensor.to(self.weight.device)
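A usage sketch, assuming the class behaves as written above (a condensed copy is included so the snippet runs on its own). Because `weight` starts at zero, the effective kernel `weight + id_tensor` is exactly the identity, also with `groups > 1`; once training updates `weight`, the layer becomes identity plus a learned residual.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class IdentityBasedConv1x1(nn.Conv2d):
    # Condensed copy of the class above so that this snippet is self-contained
    def __init__(self, channels, groups=1):
        super().__init__(channels, channels, kernel_size=1, stride=1, padding=0,
                         groups=groups, bias=False)
        assert channels % groups == 0
        input_dim = channels // groups
        id_value = np.zeros((channels, input_dim, 1, 1))
        for i in range(channels):
            id_value[i, i % input_dim, 0, 0] = 1
        self.id_tensor = torch.from_numpy(id_value).type_as(self.weight)
        nn.init.zeros_(self.weight)

    def forward(self, input):
        return F.conv2d(input, self.get_actual_kernel(), None, stride=1, padding=0,
                        dilation=self.dilation, groups=self.groups)

    def get_actual_kernel(self):
        return self.weight + self.id_tensor.to(self.weight.device)

torch.manual_seed(0)
conv = IdentityBasedConv1x1(channels=6, groups=2)
x = torch.randn(2, 6, 5, 5)

with torch.no_grad():
    y_init = conv(x)  # weight is all zeros, so the kernel is exactly the identity
    conv.weight.add_(0.01 * torch.randn_like(conv.weight))  # simulate a training update
    y_trained = conv(x)

print(torch.allclose(y_init, x))     # expected: True
print(torch.allclose(y_trained, x))  # expected: False (identity + learned residual)
```

One possible refinement, not part of the original code, is to store `id_tensor` with `register_buffer` so that it follows the module across devices and dtypes automatically.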

3.2 BN+Pad

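A BN layer followed by padding, where the border is filled with the BN output for a zero input instead of with zeros. The implementation is as follows: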

class BNAndPadLayer(nn.Module):
    def __init__(self,
                 pad_pixels,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features=num_features,
                                 eps = eps,
                                 momentum=momentum,
                                 affine=affine,
                                 track_running_stats=track_running_stats)
        self.pad_pixels = pad_pixels
    
    def forward(self, input):
        output = self.bn(input)
        if self.pad_pixels > 0:
            if self.bn.affine:
                pad_values = self.bn.bias.detach() - self.bn.running_mean * self.bn.weight.detach() / torch.sqrt(
                    self.bn.running_var + self.bn.eps
                )
            else:
                pad_values = - self.bn.running_mean / torch.sqrt(self.bn.running_var + self.bn.eps)

            output = F.pad(output, [self.pad_pixels]*4)
            pad_values = pad_values.view(1, -1, 1, 1)
            output[:, :, 0:self.pad_pixels, :] = pad_values
            output[:, :, -self.pad_pixels:, :] = pad_values
            output[:, :, :, 0:self.pad_pixels] = pad_values
            output[:, :, :, -self.pad_pixels:] = pad_values
        return output
    
    @property
    def weight(self):
        return self.bn.weight
    
    @property
    def bias(self):
        return self.bn.bias

    @property
    def running_mean(self):
        return self.bn.running_mean
    
    @property
    def running_var(self):
        return self.bn.running_var
    
    @property
    def eps(self):
        return self.bn.eps
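Why the border is filled with BN values rather than zeros: after fusion, the single KxK conv pads its *input* with zeros. A zero pushed through the bias-free 1x1 conv and the BN becomes beta - mu * gamma / sigma, not zero, so the training-time branch must pad with that value to stay equivalent. The sketch below checks this, assuming eval-mode BN and a bias-free 1x1 conv; `bn_then_pad` is a hypothetical function mirroring `BNAndPadLayer.forward`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def bn_then_pad(bn, t, pad):
    # Hypothetical mirror of BNAndPadLayer.forward (affine BN, eval mode)
    out = bn(t)
    pad_values = bn.bias - bn.running_mean * bn.weight / torch.sqrt(bn.running_var + bn.eps)
    pad_values = pad_values.view(1, -1, 1, 1)
    out = F.pad(out, [pad] * 4)
    out[:, :, :pad, :] = pad_values
    out[:, :, -pad:, :] = pad_values
    out[:, :, :, :pad] = pad_values
    out[:, :, :, -pad:] = pad_values
    return out

conv1x1 = nn.Conv2d(4, 4, kernel_size=1, bias=False)
bn = nn.BatchNorm2d(4)
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-1, 1)
bn.eval()

x = torch.randn(2, 4, 6, 6)
with torch.no_grad():
    dbb_style = bn_then_pad(bn, conv1x1(x), pad=1)        # 1x1 conv -> BN -> pad with BN(0)
    zero_pad_first = bn(conv1x1(F.pad(x, [1] * 4)))       # what the fused conv effectively sees
    naive = F.pad(bn(conv1x1(x)), [1] * 4)                # plain zero padding after BN

print(torch.allclose(dbb_style, zero_pad_first, atol=1e-6))  # expected: True
print(torch.allclose(naive, zero_pad_first, atol=1e-6))      # expected: False
```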

4. Building the DBB network

4.1 conv_bn

First, write a helper function that builds conv + BN:

def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0, padding_mode='zeros', dilation=1, groups=1):
    conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                           stride=stride, padding=padding, dilation=dilation,
                           groups=groups, bias=False, padding_mode=padding_mode)
    bn_layer = nn.BatchNorm2d(num_features=out_channels, affine=True)
    se = nn.Sequential()
    se.add_module('conv', conv_layer)
    se.add_module('bn', bn_layer)
    return se

4.2 branch

Implementation of the branches:

class DiverseBranchBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, internal_channels_1x1_3x3=None, deploy=False, nonlinear=None):
        super().__init__()
        self.deploy = deploy

        if nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            self.nonlinear = nonlinear
    
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.groups = groups
        assert padding == kernel_size // 2

        if deploy:
            self.dbb_reparam = nn.Conv2d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True
            )
        else:
            self.dbb_origin = conv_bn(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, dilation=dilation, groups=groups
            )
        
            self.dbb_avg = nn.Sequential()
            if groups < out_channels:
                self.dbb_avg.add_module(
                    'conv', nn.Conv2d(
                        in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                        stride=1, padding=0, groups=groups, bias=False
                    )
                )

                self.dbb_avg.add_module(
                    'bn', BNAndPadLayer(pad_pixels=padding,
                                        num_features=out_channels)
                )

                self.dbb_avg.add_module(
                    'avg', nn.AvgPool2d(kernel_size=kernel_size,
                                        stride=stride,
                                        padding=0)
                )

                self.dbb_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                                       stride=stride, padding=0, groups=groups)
            else:
                self.dbb_avg.add_module(
                    'avg', nn.AvgPool2d(kernel_size=kernel_size,
                                        stride=stride,
                                        padding=padding)
                )
            self.dbb_avg.add_module(
                'avgbn', nn.BatchNorm2d(out_channels)
            )

            if internal_channels_1x1_3x3 is None:
                internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels

            self.dbb_1x1_kxk = nn.Sequential()
            if internal_channels_1x1_3x3 == in_channels:
                self.dbb_1x1_kxk.add_module('idconv1',
                                            IdentityBasedConv1x1(channels=in_channels, groups=groups))
            else:
                self.dbb_1x1_kxk.add_module('conv1',
                                            nn.Conv2d(in_channels=in_channels,
                                                      out_channels=internal_channels_1x1_3x3,
                                                      kernel_size=1,
                                                      stride=1,
                                                      padding=0,
                                                      groups=groups,
                                                      bias=False))
            self.dbb_1x1_kxk.add_module('bn1',
                                        BNAndPadLayer(pad_pixels=padding,
                                                      num_features=internal_channels_1x1_3x3,
                                                      affine=True))
            self.dbb_1x1_kxk.add_module('conv2',
                                        nn.Conv2d(in_channels=internal_channels_1x1_3x3,
                                                  out_channels=out_channels,
                                                  kernel_size=kernel_size,
                                                  stride=stride,
                                                  padding=0,
                                                  groups=groups,
                                                  bias=False))
            self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels))

4.3 forward

Implementation of the forward pass:

class DiverseBranchBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, internal_channels_1x1_3x3=None, deploy=False, nonlinear=None):
        super().__init__()
        ...

    def forward(self, inputs):
        if hasattr(self, 'dbb_reparam'):
            return self.nonlinear(self.dbb_reparam(inputs))
        
        out = self.dbb_origin(inputs)
        if hasattr(self, 'dbb_1x1'):
            out += self.dbb_1x1(inputs)
        out += self.dbb_avg(inputs)
        out += self.dbb_1x1_kxk(inputs)
        return self.nonlinear(out)

4.4 Implementing the re-parameterization

The re-parameterization procedure:

class DiverseBranchBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, internal_channels_1x1_3x3=None, deploy=False, nonlinear=None):
        super().__init__()
        ...

    def forward(self, inputs):
        ...

    def switch_to_deploy(self):
        if hasattr(self, 'dbb_reparam'):
            return

        kernel, bias = self.get_equivalent_kernel_bias()
        self.dbb_reparam = nn.Conv2d(in_channels=self.dbb_origin.conv.in_channels,
                                     out_channels=self.dbb_origin.conv.out_channels,
                                     kernel_size=self.dbb_origin.conv.kernel_size,
                                     stride=self.dbb_origin.conv.stride,
                                     padding=self.dbb_origin.conv.padding,
                                     dilation=self.dbb_origin.conv.dilation,
                                     groups=self.dbb_origin.conv.groups,
                                     bias=True)
        
        self.dbb_reparam.weight.data = kernel
        self.dbb_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        
        self.__delattr__('dbb_origin')
        self.__delattr__('dbb_avg')
        if hasattr(self, 'dbb_1x1'):
            self.__delattr__('dbb_1x1')
        self.__delattr__('dbb_1x1_kxk')
    
    def get_equivalent_kernel_bias(self):
        k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight,
                                           self.dbb_origin.bn)
        
        if hasattr(self, 'dbb_1x1'):
            k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight,
                                         self.dbb_1x1.bn)
            
            k_1x1 = transVI_multiscale(k_1x1,
                                       self.kernel_size)
        else:
            k_1x1, b_1x1 = 0, 0
        
        if hasattr(self.dbb_1x1_kxk, 'idconv1'):
            k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
        else:
            k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
        
        k_1x1_kxk_first, b_1x1_kxk_first   = transI_fusebn(k_1x1_kxk_first,
                                                           self.dbb_1x1_kxk.bn1)
        k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight,
                                                           self.dbb_1x1_kxk.bn2)

        k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first,
                                                              b_1x1_kxk_first,
                                                              k_1x1_kxk_second,
                                                              b_1x1_kxk_second,
                                                              groups=self.groups)
        
        k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
        k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device),
                                                           self.dbb_avg.avgbn)
        if hasattr(self.dbb_avg, 'conv'):
            k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(self.dbb_avg.conv.weight,
                                                             self.dbb_avg.bn)
            k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first,
                                                                  b_1x1_avg_first,
                                                                  k_1x1_avg_second,
                                                                  b_1x1_avg_second,
                                                                  groups=self.groups)
        else:
            k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
        
        return transII_addbranch((k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged),
                                 (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged))

4.5 Model export

Export the DBB model to ONNX before and after re-parameterization, and compare the two outputs:

if __name__ == '__main__':
    
    x = torch.randn(1, 4, 224, 224)

    model = DiverseBranchBlock(in_channels=4, out_channels=32, kernel_size=3, stride=1, padding=1, groups=2, deploy=False)

    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            nn.init.uniform_(module.running_mean, 0, 0.1)
            nn.init.uniform_(module.running_var, 0, 0.2)
            nn.init.uniform_(module.weight, 0, 0.3)
            nn.init.uniform_(module.bias, 0, 0.4)
        
    model.eval()
    out = model(x)
    torch.onnx.export(model=model, args=x, f='./DBB.onnx', verbose=False)

    model.switch_to_deploy()
    deployout = model(x)

    torch.onnx.export(model=model, args=x, f='./DBB-deploy.onnx', verbose=False)

    print('\nDifference between the outputs of the origin-DBB and rep-DBB is: {}\n'.format(
        ((deployout - out) ** 2).sum()
    ))

5. Complete example code

The complete example code for DBB re-parameterization is as follows:

import torch
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import torch.nn as nn
import torch.nn.functional as F


def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
            padding_mode='zeros'):
    conv_layer = nn.Conv2d(in_channels  = in_channels, 
                           out_channels = out_channels, 
                           kernel_size  = kernel_size,
                           stride       = stride, 
                           padding      = padding, 
                           dilation     = dilation, 
                           groups       = groups,
                           bias         = False, 
                           padding_mode = padding_mode)
    bn_layer = nn.BatchNorm2d(num_features=out_channels, affine=True)
    se = nn.Sequential()
    se.add_module('conv', conv_layer)
    se.add_module('bn', bn_layer)
    return se


def transI_fusebn(kernel, bn):
    '''
    Fuse a conv kernel with the BN layer that follows it (BN in inference mode).

    Returns:
    k: the fused kernel, i.e. each output-channel filter of `kernel` multiplied by
       gamma / std, where gamma is the BN scale and std = sqrt(running_var + eps)
    b: the fused bias, i.e. the BN bias minus running_mean * gamma / std
    '''
    gamma = bn.weight
    std = (bn.running_var + bn.eps).sqrt()
    k = kernel * ((gamma / std).view(-1, 1, 1, 1))
    b = bn.bias - bn.running_mean * gamma / std
    return k, b


def transII_addbranch(kernels, biases):
    '''
    Input:
        kernels: tuple
        biases : tuple
    '''
    k = sum(kernels)
    b = sum(biases)
    return k, b


def transIII_1x1_kxk(k1, b1, k2, b2, groups):
    if groups == 1:
        k = F.conv2d(k2, k1.permute(1, 0, 2, 3))  # treat k2 as the input and k1^T as a 1x1 kernel
        b_hat = (k2 * b1.view(1, -1, 1, 1)).sum((1, 2, 3))
    else:
        # initializes an empty list for storing the results of the 1x1 convolutions.
        k_slices = []
        # initializes an empty list for storing the bias terms for the kxk convolutions
        b_slices = []
        # switch the in_channels and out_channels
        k1_T = k1.permute(1, 0, 2, 3)
        # Number of k1 output channels per group
        k1_group_width = k1.size(0) // groups
        # Number of k2 output channels per group
        k2_group_width = k2.size(0) // groups
        # loops over the number of groups
        for g in range(groups):
            # extracts a slice of k1_T that corresponds to the channels in the current group
            k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :]
            # extracts a slice of k2 that corresponds to the channels in the current group
            k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :]
            k_slices.append(F.conv2d(k2_slice, k1_T_slice))
            b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].view(1, -1, 1, 1)).sum((1, 2, 3)))
        # concatenates the results of the 1x1 convolutions and 
        # the bias terms across the group dimension by calling the transIV_depthconcat function
        k, b_hat = transIV_depthconcat(k_slices, b_slices)
    # returns the concatenated results of the 1x1 convolutions and 
    # the bias terms, with the bias term for the kxk convolution added to b2
    return k, b_hat + b2


def transIV_depthconcat(kernels, biases):
    '''
    Parameters:
        kernels: list
        biases : list
    '''
    return torch.cat(kernels, dim=0), torch.cat(biases)


def transV_avg(channels, kernel_size, groups):
    # Calculate the number of input dimensions for each group
    input_dim = channels // groups
    # Create a tensor of zeros with dimensions (channels, input_dim, kernel_size, kernel_size)
    k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
    # Fill the diagonal blocks of the tensor with the average transform
    k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
    return k


#   This has not been tested with non-square kernels (kernel.size(2) != kernel.size(3)) nor even-size kernels
def transVI_multiscale(kernel, target_kernel_size):
    # Calculate the number of pixels to pad on the height dimension
    H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2
    # Calculate the number of pixels to pad on the width dimension
    W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2
    # F.pad takes [left, right, top, bottom]: the width (last) dimension comes first
    return F.pad(kernel, [W_pixels_to_pad, W_pixels_to_pad, H_pixels_to_pad, H_pixels_to_pad])


class IdentityBasedConv1x1(nn.Conv2d):
    '''
    This module implements a convolution operation that adds an identity matrix to the weight kernel, 
    allowing it to act as an identity operation in addition to the normal convolutional operation.
    '''
    def __init__(self, channels, groups=1):
        super().__init__(in_channels  = channels,
                         out_channels = channels,
                         kernel_size  = 1,
                         stride       = 1,
                         padding      = 0,
                         groups       = groups,
                         bias         = False)
        # Raises an assertion error if the number of input channels is not divisible by the number of groups
        assert channels % groups == 0
        # Calculates the size of input channel per group
        input_dim = channels // groups
        # Creates an identity matrix with the same size as the weight tensor with the value of 1 
        # for the diagonal elements and 0 for other elements.
        id_value  = np.zeros((channels, input_dim, 1, 1))
        for i in range(channels):
            id_value[i, i % input_dim, 0, 0] = 1
        # Initializes the id_tensor attribute with the identity matrix 
        # and initializes the weight attribute with zeros.
        self.id_tensor = torch.from_numpy(id_value).type_as(self.weight)
        nn.init.zeros_(self.weight)

    def forward(self, input):
        # The effective kernel is the learnable weight plus a fixed identity tensor.
        # Because the weight is initialized to zero, the layer starts as an exact
        # identity mapping, and training learns a residual on top of it.
        kernel = self.weight + self.id_tensor.to(self.weight.device)
        result = F.conv2d(input,
                          kernel,
                          None,
                          stride=1,
                          padding=0,
                          dilation=self.dilation,
                          groups=self.groups)
        return result

    def get_actual_kernel(self):
        return self.weight + self.id_tensor.to(self.weight.device)


class BNAndPadLayer(nn.Module):
    def __init__(self,
                 pad_pixels,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features,
                                 eps,
                                 momentum,
                                 affine,
                                 track_running_stats)
        self.pad_pixels = pad_pixels

    def forward(self, input):
        output = self.bn(input)
        if self.pad_pixels > 0:
            # If the BatchNorm2d layer is affine (i.e. has learnable weights)
            if self.bn.affine:
                # Calculate the padding values using the batch normalization statistics
                pad_values = self.bn.bias.detach() - self.bn.running_mean * self.bn.weight.detach() / torch.sqrt(
                    self.bn.running_var + self.bn.eps)
            # If the BatchNorm2d layer is not affine (i.e. has no learnable weights)
            else:
                # Calculate the padding values based on the batch normalization mean and variance
                pad_values = - self.bn.running_mean / torch.sqrt(self.bn.running_var + self.bn.eps)
            # Pad the output tensor with zeros on all sides
            output = F.pad(output, [self.pad_pixels] * 4)
            # Reshape the padding values to have a size of (1, num_features, 1, 1)
            pad_values = pad_values.view(1, -1, 1, 1)
            # Replace the top padding values with the calculated values
            output[:, :, 0:self.pad_pixels, :] = pad_values
            # Replace the bottom padding values with the calculated values
            output[:, :, -self.pad_pixels:, :] = pad_values
            # Replace the left padding values with the calculated values
            output[:, :, :, 0:self.pad_pixels] = pad_values
            # Replace the right padding values with the calculated values
            output[:, :, :, -self.pad_pixels:] = pad_values
        return output

    @property
    def weight(self):
        return self.bn.weight

    @property
    def bias(self):
        return self.bn.bias

    @property
    def running_mean(self):
        return self.bn.running_mean

    @property
    def running_var(self):
        return self.bn.running_var

    @property
    def eps(self):
        return self.bn.eps


class DiverseBranchBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride = 1, padding   = 0, dilation  = 1, groups = 1,
                 internal_channels_1x1_3x3 = None,
                 deploy = False, nonlinear = None
        ):
        super().__init__()
        self.deploy = deploy

        if nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            self.nonlinear = nonlinear

        self.kernel_size   = kernel_size
        self.out_channels  = out_channels
        self.groups        = groups
        assert padding == kernel_size // 2

        if deploy:
            self.dbb_reparam = nn.Conv2d(
                in_channels  = in_channels, out_channels = out_channels, kernel_size  = kernel_size,
                stride       = stride,      padding      = padding,      dilation     = dilation,
                groups       = groups,      bias         = True)
        else:
            self.dbb_origin = conv_bn(
                in_channels  = in_channels, out_channels = out_channels, kernel_size  = kernel_size,
                stride       = stride,      padding      = padding,
                dilation     = dilation,    groups       = groups)

            self.dbb_avg = nn.Sequential()
            if groups < out_channels:
                self.dbb_avg.add_module(
                    'conv', nn.Conv2d(in_channels  = in_channels,
                                      out_channels = out_channels,
                                      kernel_size  = 1,
                                      stride       = 1,
                                      padding      = 0,
                                      groups       = groups,
                                      bias         = False))

                self.dbb_avg.add_module(
                    'bn', BNAndPadLayer(pad_pixels   = padding,
                                        num_features = out_channels))

                self.dbb_avg.add_module(
                    'avg', nn.AvgPool2d(kernel_size = kernel_size,
                                        stride      = stride,
                                        padding     = 0))

                self.dbb_1x1 = conv_bn(in_channels  = in_channels,
                                       out_channels = out_channels,
                                       kernel_size  = 1,
                                       stride       = stride,
                                       padding      = 0,
                                       groups       = groups)
            else:
                self.dbb_avg.add_module('avg',
                                        nn.AvgPool2d(kernel_size = kernel_size,
                                                     stride      = stride,
                                                     padding     = padding))

            self.dbb_avg.add_module('avgbn',
                                    nn.BatchNorm2d(out_channels))

            if internal_channels_1x1_3x3 is None:
                # For mobilenet, it is better to have 2X internal channels
                # internal_channels = in_channels or 2*in_channels
                internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels

            self.dbb_1x1_kxk = nn.Sequential()
            if internal_channels_1x1_3x3 == in_channels:
                self.dbb_1x1_kxk.add_module('idconv1',
                                            IdentityBasedConv1x1(channels=in_channels, groups=groups))
            else:
                self.dbb_1x1_kxk.add_module('conv1',
                                            nn.Conv2d(in_channels=in_channels,
                                                      out_channels=internal_channels_1x1_3x3,
                                                      kernel_size=1,
                                                      stride=1,
                                                      padding=0,
                                                      groups=groups,
                                                      bias=False))
            self.dbb_1x1_kxk.add_module('bn1',
                                        BNAndPadLayer(pad_pixels=padding,
                                                      num_features=internal_channels_1x1_3x3,
                                                      affine=True))
            self.dbb_1x1_kxk.add_module('conv2',
                                        nn.Conv2d(in_channels=internal_channels_1x1_3x3, 
                                                  out_channels=out_channels,
                                                  kernel_size=kernel_size,
                                                  stride=stride,
                                                  padding=0,
                                                  groups=groups,
                                                  bias=False))
            self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels))


    def forward(self, inputs):
    
        if hasattr(self, 'dbb_reparam'):
            return self.nonlinear(self.dbb_reparam(inputs))

        out = self.dbb_origin(inputs)
        if hasattr(self, 'dbb_1x1'):
            out += self.dbb_1x1(inputs)
        out += self.dbb_avg(inputs)
        out += self.dbb_1x1_kxk(inputs)
        return self.nonlinear(out)
    
    
    def switch_to_deploy(self):
        if hasattr(self, 'dbb_reparam'):
            return
        kernel, bias     = self.get_equivalent_kernel_bias()
        self.dbb_reparam = nn.Conv2d(in_channels  = self.dbb_origin.conv.in_channels,
                                     out_channels = self.dbb_origin.conv.out_channels,
                                     kernel_size  = self.dbb_origin.conv.kernel_size,
                                     stride       = self.dbb_origin.conv.stride,
                                     padding      = self.dbb_origin.conv.padding,
                                     dilation     = self.dbb_origin.conv.dilation,
                                     groups       = self.dbb_origin.conv.groups, 
                                     bias         = True)
        self.dbb_reparam.weight.data = kernel
        self.dbb_reparam.bias.data   = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('dbb_origin')
        self.__delattr__('dbb_avg')
        if hasattr(self, 'dbb_1x1'):
            self.__delattr__('dbb_1x1')
        self.__delattr__('dbb_1x1_kxk')
    
    
    def get_equivalent_kernel_bias(self):
        # ================== 1. k_origin, b_origin
        k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight,
                                           self.dbb_origin.bn)

        # ================== 2. k_1x1, b_1x1
        if hasattr(self, 'dbb_1x1'):
            # Transform I: fuse the 1x1 conv with its BN
            k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight,
                                         self.dbb_1x1.bn)
            # Transform VI: zero-pad the 1x1 kernel to kxk (multi-scale merge)
            k_1x1 = transVI_multiscale(k_1x1,
                                       self.kernel_size)
        else:
            k_1x1, b_1x1 = 0, 0

        # ================== 3. k_1x1_kxk_merged, b_1x1_kxk_merged
        # If the 1x1 conv is the identity-initialized variant (idconv1), fetch its effective kernel
        if hasattr(self.dbb_1x1_kxk, 'idconv1'):
            k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
        else:
            k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
        # Transform I: fuse the first (1x1) conv with bn1
        k_1x1_kxk_first, b_1x1_kxk_first   = transI_fusebn(k_1x1_kxk_first,
                                                           self.dbb_1x1_kxk.bn1)
        # Transform I: fuse the second (kxk) conv with bn2
        k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight,
                                                           self.dbb_1x1_kxk.bn2)
        # Transform III: merge the sequential 1x1 and kxk convs into a single kxk conv
        k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first,
                                                              b_1x1_kxk_first,
                                                              k_1x1_kxk_second,
                                                              b_1x1_kxk_second,
                                                              groups=self.groups)
        
        # ================== 4. k_1x1_avg_merged, b_1x1_avg_merged
        # Transform V: express kxk average pooling as an equivalent conv kernel
        k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
        # Transform I: fuse the average-pooling kernel with avgbn
        k_1x1_avg_second, b_1x1_avg_second     = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device),
                                                               self.dbb_avg.avgbn)
        if hasattr(self.dbb_avg, 'conv'):
            # Transform I: fuse the leading 1x1 conv with its BN
            k_1x1_avg_first, b_1x1_avg_first   = transI_fusebn(self.dbb_avg.conv.weight,
                                                               self.dbb_avg.bn)
            # Transform III: merge the 1x1 conv with the kxk pooling kernel
            k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first,
                                                                  b_1x1_avg_first,
                                                                  k_1x1_avg_second,
                                                                  b_1x1_avg_second,
                                                                  groups=self.groups)
        else:
            k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second

        # ================== 5. Final merge (Transform II: add all branch kernels and biases)
        return transII_addbranch((k_origin,
                                  k_1x1,
                                  k_1x1_kxk_merged,
                                  k_1x1_avg_merged),
                                 (b_origin,
                                  b_1x1,
                                  b_1x1_kxk_merged,
                                  b_1x1_avg_merged))

    

if __name__ == '__main__':
    x = torch.randn(1, 4, 224, 224)

    model = DiverseBranchBlock(in_channels=4, out_channels=32, kernel_size=3, stride=1, padding=3//2,
                               groups=2, deploy=False)
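    # Randomize BN statistics and affine parameters so the equivalence check is non-trivial
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            nn.init.uniform_(module.running_mean, 0, 0.1)
            nn.init.uniform_(module.running_var, 0, 0.2)
            nn.init.uniform_(module.weight, 0, 0.3)
            nn.init.uniform_(module.bias, 0, 0.4)

    # eval(): BN uses its running statistics, which is what the conv-BN fusion assumes
    model.eval()
    out = model(x)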
    # print(model)
    torch.onnx.export(model=model, args=x, f='../DBB.onnx', 
                      verbose=False)
    
    
    model.switch_to_deploy()
    deployout = model(x)
    # print(model)
    torch.onnx.export(
        model=model, args=x, f='../DBB-deploy.onnx', 
        verbose=False)

    print('\nDifference between the outputs of the origin-DBB and rep-DBB is: {}\n'.format(
        ((deployout - out) ** 2).sum()
    ))
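The script above checks the whole block end to end. To see why the individual branches can be folded, here is a minimal numpy sketch of Transforms V, VI and II in isolation, assuming groups=1, stride 1 and a single image without a batch dimension. A 1x1 kernel zero-padded to KxK, and KxK average pooling rewritten as a conv kernel, are added to a KxK kernel so that one convolution reproduces the sum of the three branches. The helpers `conv2d`, `avg_pool`, `avg_as_kernel` and `pad_1x1_to_kxk` are hypothetical stand-ins written for this sketch, not the course's `transV_avg` / `transVI_multiscale`.

```python
import numpy as np

def conv2d(x, w, b=None, padding=0):
    """Naive stride-1 2-D cross-correlation (PyTorch convention).
    x: (C_in, H, W), w: (C_out, C_in, k, k), b: (C_out,) or None."""
    x = np.pad(x, ((0, 0), (padding, padding), (padding, padding)))
    c_out, _, k, _ = w.shape
    oh, ow = x.shape[1] - k + 1, x.shape[2] - k + 1
    out = np.zeros((c_out, oh, ow))
    for i in range(oh):
        for j in range(ow):
            patch = x[:, i:i + k, j:j + k]
            out[:, i, j] = np.tensordot(w, patch, axes=([1, 2, 3], [0, 1, 2]))
    if b is not None:
        out += b[:, None, None]
    return out

def avg_pool(x, k, padding):
    """k x k average pooling, stride 1, dividing by k*k even at zero-padded borders."""
    x = np.pad(x, ((0, 0), (padding, padding), (padding, padding)))
    oh, ow = x.shape[1] - k + 1, x.shape[2] - k + 1
    out = np.zeros((x.shape[0], oh, ow))
    for i in range(oh):
        for j in range(ow):
            out[:, i, j] = x[:, i:i + k, j:j + k].mean(axis=(1, 2))
    return out

def avg_as_kernel(channels, k):
    """Transform V sketch: channel-wise k x k average pooling as a dense conv kernel."""
    w = np.zeros((channels, channels, k, k))
    for c in range(channels):
        w[c, c] = 1.0 / (k * k)  # each output channel averages only its own input channel
    return w

def pad_1x1_to_kxk(w1x1, k):
    """Transform VI sketch: zero-pad a 1x1 kernel to k x k, keeping it centered."""
    p = (k - 1) // 2
    return np.pad(w1x1, ((0, 0), (0, 0), (p, p), (p, p)))

rng = np.random.default_rng(0)
C, H, W, K = 3, 6, 6, 3
P = K // 2
x = rng.standard_normal((C, H, W))
w_kxk, b_kxk = rng.standard_normal((C, C, K, K)), rng.standard_normal(C)
w_1x1, b_1x1 = rng.standard_normal((C, C, 1, 1)), rng.standard_normal(C)

# Training-time view: three parallel branches whose outputs are summed
multi_branch = (conv2d(x, w_kxk, b_kxk, padding=P)
                + conv2d(x, w_1x1, b_1x1, padding=0)
                + avg_pool(x, K, padding=P))

# Deploy-time view (Transform II): add the converted kernels and the biases
w_merged = w_kxk + pad_1x1_to_kxk(w_1x1, K) + avg_as_kernel(C, K)
b_merged = b_kxk + b_1x1
single_conv = conv2d(x, w_merged, b_merged, padding=P)

print(np.abs(multi_branch - single_conv).max())
```

Dividing by K*K at the zero-padded borders matches PyTorch's `nn.AvgPool2d` default (`count_include_pad=True`); with the other convention the pooling would no longer be a fixed linear kernel near the borders.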

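Summary

This lesson covered the reparameterization of DBB. Whereas ACNet strengthens a KxK convolution with parallel asymmetric (1xK and Kx1) convolutions, DBB trains a more complex, Inception-like multi-branch structure. Before inference, six transforms fold all branches into a single KxK convolution, which speeds up inference while producing the same output as the multi-branch block (up to floating-point error).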
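As a companion check for the 1x1-KxK branch, here is a minimal numpy sketch of Transform I followed by Transform III, assuming groups=1, stride 1 and no zero padding ("valid" convolution). The fused BN bias subtracts the running mean, and the 1x1 bias is carried through by summing each KxK kernel over its window. With zero padding, border pixels of the KxK conv would see zeros instead of the 1x1 branch's bias, so an exact merge needs that padding filled with the bias; this is presumably why `conv2` in the class above is built with `padding=0`. The names `fuse_bn`, `merge_1x1_kxk` and `batchnorm_eval` are hypothetical helpers for this sketch.

```python
import numpy as np

def conv2d(x, w, b=None):
    """Naive stride-1 'valid' cross-correlation. x: (C_in, H, W), w: (C_out, C_in, k, k)."""
    c_out, _, k, _ = w.shape
    oh, ow = x.shape[1] - k + 1, x.shape[2] - k + 1
    out = np.zeros((c_out, oh, ow))
    for i in range(oh):
        for j in range(ow):
            out[:, i, j] = np.tensordot(w, x[:, i:i + k, j:j + k], axes=([1, 2, 3], [0, 1, 2]))
    if b is not None:
        out += b[:, None, None]
    return out

def batchnorm_eval(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BN applied per channel of y: (C, H, W)."""
    scale = gamma / np.sqrt(var + eps)
    return (y - mean[:, None, None]) * scale[:, None, None] + beta[:, None, None]

def fuse_bn(w, gamma, beta, mean, var, eps=1e-5):
    """Transform I sketch: fold eval-mode BN into a bias-free conv -> (kernel, bias)."""
    std = np.sqrt(var + eps)
    return w * (gamma / std)[:, None, None, None], beta - mean * gamma / std

def merge_1x1_kxk(k1, b1, k2, b2):
    """Transform III sketch (groups=1): 1x1 conv (k1, b1) followed by kxk conv (k2, b2)."""
    # k[o, i] = sum_m k2[o, m] * k1[m, i], where k1 has shape (M, C_in, 1, 1)
    k = np.einsum('omhw,mi->oihw', k2, k1[:, :, 0, 0])
    # The 1x1 bias is a constant map, so each kxk kernel sums it over its whole window
    b = (k2 * b1[None, :, None, None]).sum(axis=(1, 2, 3)) + b2
    return k, b

rng = np.random.default_rng(1)
C_in, M, C_out, K = 2, 3, 4, 3
x = rng.standard_normal((C_in, 7, 7))

def random_bn(c):
    # (gamma, beta, running_mean, running_var), with the variance kept positive
    return (rng.uniform(0.5, 1.5, c), rng.standard_normal(c),
            rng.standard_normal(c), rng.uniform(0.5, 1.5, c))

w1, bn1 = rng.standard_normal((M, C_in, 1, 1)), random_bn(M)
w2, bn2 = rng.standard_normal((C_out, M, K, K)), random_bn(C_out)

# Training-time branch: conv1 -> bn1 -> conv2 -> bn2
ref = batchnorm_eval(conv2d(batchnorm_eval(conv2d(x, w1), *bn1), w2), *bn2)

# Deploy time: fuse each BN (Transform I), then merge the two convs (Transform III)
k1, b1 = fuse_bn(w1, *bn1)
k2, b2 = fuse_bn(w2, *bn2)
k, b = merge_1x1_kxk(k1, b1, k2, b2)
fused = conv2d(x, k, b)

print(np.abs(ref - fused).max())
```

The `einsum` contraction plays the same role as the `F.conv2d(k2, k1.permute(1, 0, 2, 3))` trick in `transIII_1x1_kxk`: composing two linear maps over channels while keeping the KxK spatial layout of the second kernel.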
