您现在的位置是:首页 >技术交流 >(即插即用模块-特征处理部分) 二十二、(ICCV 2021) CrossNorm / SelfNorm 交叉/自 归一化网站首页技术交流

(即插即用模块-特征处理部分) 二十二、(ICCV 2021) CrossNorm / SelfNorm 交叉/自 归一化

御宇w 2025-03-06 12:01:02
简介(即插即用模块-特征处理部分) 二十二、(ICCV 2021) CrossNorm / SelfNorm 交叉/自 归一化

在这里插入图片描述

文章目录

  • 1、CrossNorm / SelfNorm
  • 2、代码实现

paper:CrossNorm and SelfNorm for Generalization under Distribution Shifts

Code:https://github.com/amazon-research/crossnorm-selfnorm


1、CrossNorm / SelfNorm

现有的传统归一化方法(如 Batch Normalization 和 Instance Normalization),其假设训练数据和测试数据来自同一分布,这在实际应用中往往不成立。这篇论文提出两种归一化方式 交叉归一化(CrossNorm )和 自归一化(SelfNorm ),CrossNorm 和 SelfNorm 旨在解决深度学习模型在面对数据分布变化时泛化能力不足的问题。CrossNorm 和 SelfNorm 从两个角度出发,CrossNorm 通过交换特征图之间的通道均值和方差来扩充训练数据的分布,使模型对不同的外观变化更具鲁棒性。而 SelfNorm 使用注意力机制重新校准特征图的统计信息,缩小训练数据和测试数据之间的分布差异,使模型在测试数据上也能取得更好的表现。

CrossNorm 和 SelfNorm 都是基于 Instance Normalization 的扩展,但它们使用通道均值和方差的方式有所不同。具体来说,CrossNorm: 交换特征图之间的通道均值和方差,使模型学习到更多样化的风格信息,从而提高对数据外观变化的鲁棒性。SelfNorm: 使用注意力机制重新校准特征图的通道均值和方差,突出训练数据和测试数据之间共享的判别性风格信息,抑制无关的风格信息,从而缩小分布差异。

对于输入X,CrossNorm 和 SelfNorm 的实现过程如下:

CrossNorm:

  1. 选择一对特征图 A 和 B。
  2. 计算 A 和 B 的通道均值和方差。
  3. 交换 A 和 B 的通道均值和方差。
  4. 对交换后的特征图进行归一化和仿射变换。

SelfNorm:

  1. 对特征图 A 进行归一化。
  2. 使用注意力机制学习一个可学习的函数 f 和 g,分别对通道均值和方差进行缩放。
  3. 使用缩放后的均值和方差对归一化后的特征图进行仿射变换。

CrossNorm / SelfNorm 结构图:
在这里插入图片描述

2、代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
import numpy as np


def calc_ins_mean_std(x, eps=1e-5):
    """extract feature map statistics"""
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = x.size()
    assert (len(size) == 4)
    N, C = size[:2]
    var = x.contiguous().view(N, C, -1).var(dim=2) + eps
    std = var.sqrt().view(N, C, 1, 1)
    mean = x.contiguous().view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return mean, std


def instance_norm_mix(content_feat, style_feat):
    """replace content statistics with style statistics"""
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    size = content_feat.size()
    style_mean, style_std = calc_ins_mean_std(style_feat)
    content_mean, content_std = calc_ins_mean_std(content_feat)

    normalized_feat = (content_feat - content_mean.expand(
        size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)


def cn_rand_bbox(size, beta, bbx_thres):
    """sample a bounding box for cropping."""
    W = size[2]
    H = size[3]
    while True:
        ratio = np.random.beta(beta, beta)
        cut_rat = np.sqrt(ratio)
        cut_w = np.int(W * cut_rat)
        cut_h = np.int(H * cut_rat)

        # uniform
        cx = np.random.randint(W)
        cy = np.random.randint(H)

        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)

        ratio = float(bbx2 - bbx1) * (bby2 - bby1) / (W * H)
        if ratio > bbx_thres:
            break

    return bbx1, bby1, bbx2, bby2


def cn_op_2ins_space_chan(x, crop='neither', beta=1, bbx_thres=0.1, lam=None, chan=False):
    """2-instance crossnorm with cropping."""

    assert crop in ['neither', 'style', 'content', 'both']
    ins_idxs = torch.randperm(x.size()[0]).to(x.device)

    if crop in ['style', 'both']:
        bbx3, bby3, bbx4, bby4 = cn_rand_bbox(x.size(), beta=beta, bbx_thres=bbx_thres)
        x2 = x[ins_idxs, :, bbx3:bbx4, bby3:bby4]
    else:
        x2 = x[ins_idxs]

    if chan:
        chan_idxs = torch.randperm(x.size()[1]).to(x.device)
        x2 = x2[:, chan_idxs, :, :]

    if crop in ['content', 'both']:
        x_aug = torch.zeros_like(x)
        bbx1, bby1, bbx2, bby2 = cn_rand_bbox(x.size(), beta=beta, bbx_thres=bbx_thres)
        x_aug[:, :, bbx1:bbx2, bby1:bby2] = instance_norm_mix(content_feat=x[:, :, bbx1:bbx2, bby1:bby2],
                                                              style_feat=x2)

        mask = torch.ones_like(x, requires_grad=False)
        mask[:, :, bbx1:bbx2, bby1:bby2] = 0.
        x_aug = x * mask + x_aug
    else:
        x_aug = instance_norm_mix(content_feat=x, style_feat=x2)

    if lam is not None:
        x = x * lam + x_aug * (1-lam)
    else:
        x = x_aug

    return x


class CrossNorm(nn.Module):
    """CrossNorm module"""
    def __init__(self, crop=None, beta=None):
        super(CrossNorm, self).__init__()

        self.active = False
        self.cn_op = functools.partial(cn_op_2ins_space_chan,
                                       crop=crop, beta=beta)

    def forward(self, x):
        if self.training and self.active:

            x = self.cn_op(x)

        self.active = False

        return x


class SelfNorm(nn.Module):
    """SelfNorm module"""
    def __init__(self, chan_num, is_two=False):
        super(SelfNorm, self).__init__()

        # channel-wise fully connected layer
        self.g_fc = nn.Conv1d(chan_num, chan_num, kernel_size=2,
                              bias=False, groups=chan_num)
        self.g_bn = nn.BatchNorm1d(chan_num)

        if is_two is True:
            self.f_fc = nn.Conv1d(chan_num, chan_num, kernel_size=2,
                                  bias=False, groups=chan_num)
            self.f_bn = nn.BatchNorm1d(chan_num)
        else:
            self.f_fc = None

    def forward(self, x):
        b, c, _, _ = x.size()

        mean, std = calc_ins_mean_std(x, eps=1e-12)

        statistics = torch.cat((mean.squeeze(3), std.squeeze(3)), -1)

        g_y = self.g_fc(statistics)
        g_y = self.g_bn(g_y)
        g_y = torch.sigmoid(g_y)
        g_y = g_y.view(b, c, 1, 1)

        if self.f_fc is not None:
            f_y = self.f_fc(statistics)
            f_y = self.f_bn(f_y)
            f_y = torch.sigmoid(f_y)
            f_y = f_y.view(b, c, 1, 1)

            return x * g_y.expand_as(x) + mean.expand_as(x) * (f_y.expand_as(x)-g_y.expand_as(x))
        else:
            return x * g_y.expand_as(x)


if __name__ == '__main__':
    x = torch.randn(4, 512, 7, 7).cuda()
    # model = CrossNorm().cuda()
    model = SelfNorm(512).cuda()
    out = model(x)
    print(out.shape)

风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。