您现在的位置是:首页 >其他 >BEVFormer组件分析网站首页其他

BEVFormer组件分析

FourierTransformer 2024-08-15 12:01:03
简介BEVFormer组件分析

BEVFormerEncoder中的get_reference_points


@staticmethod
    def get_reference_points(H, W, Z=8, num_points_in_pillar=4, dim='3d', bs=1, device='cuda', dtype=torch.float):
        """Get the reference points used in SCA and TSA.
        Args:
            H, W: spatial shape of bev.
            Z: hight of pillar.
            D: sample D points uniformly from each pillar.
            device (obj:`device`): The device where
                reference_points should be.
        Returns:
            Tensor: reference points used in decoder, has 
                shape (bs, num_keys, num_levels, 2).
        """

        # reference points in 3D space, used in spatial cross-attention (SCA)
        if dim == '3d':
            zs = torch.linspace(0.5, Z - 0.5, num_points_in_pillar, dtype=dtype,
                                device=device).view(-1, 1, 1).expand(num_points_in_pillar, H, W) / Z
            xs = torch.linspace(0.5, W - 0.5, W, dtype=dtype,
                                device=device).view(1, 1, W).expand(num_points_in_pillar, H, W) / W
            ys = torch.linspace(0.5, H - 0.5, H, dtype=dtype,
                                device=device).view(1, H, 1).expand(num_points_in_pillar, H, W) / H
            ref_3d = torch.stack((xs, ys, zs), -1)
            ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
            ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)
            return ref_3d

        # reference points on 2D bev plane, used in temporal self-attention (TSA).
        elif dim == '2d':
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(
                    0.5, H - 0.5, H, dtype=dtype, device=device),
                torch.linspace(
                    0.5, W - 0.5, W, dtype=dtype, device=device)
            )
            ref_y = ref_y.reshape(-1)[None] / H
            ref_x = ref_x.reshape(-1)[None] / W
            ref_2d = torch.stack((ref_x, ref_y), -1)
            ref_2d = ref_2d.repeat(bs, 1, 1).unsqueeze(2)
            return ref_2d

根据上面的代码可以看出来,如果输入的是3d, 则是
按照:

  • X方向: 从0.5, 到W-0.5分成W份.
  • Y方向: 从0.5, 到H-0.5分成H份.
  • Z方向: 从0.5, 到Z-0.5, 分成 num_points_in_pillar份.
    其中num_points_in_pillar 默认给的是4.

配置文件里面给的其实也是4.
在这里插入图片描述

BEVFormerEncoder中的point_sampling

  # This function must use fp32!!!
    @force_fp32(apply_to=('reference_points', 'img_metas'))
    def point_sampling(self, reference_points, pc_range,  img_metas):
        lidar2img = []
        for img_meta in img_metas:
            lidar2img.append(img_meta['lidar2img'])
        lidar2img = np.asarray(lidar2img)
        lidar2img = reference_points.new_tensor(lidar2img)  # (B, N, 4, 4)
        reference_points = reference_points.clone()

        # 变换到点云的范围内. 这也是为何get_reference_points中会/H, /W, /Z, 先化到[0, 1]变成ratio.
        reference_points[..., 0:1] = reference_points[..., 0:1] * 
            (pc_range[3] - pc_range[0]) + pc_range[0]
        reference_points[..., 1:2] = reference_points[..., 1:2] * 
            (pc_range[4] - pc_range[1]) + pc_range[1]
        reference_points[..., 2:3] = reference_points[..., 2:3] * 
            (pc_range[5] - pc_range[2]) + pc_range[2]

        # 由(x, y, z) 变成(x, y, z, 1) 便于与4*4的参数矩阵相乘.
        reference_points = torch.cat(
            (reference_points, torch.ones_like(reference_points[..., :1])), -1)
        # 此时reference_points可以当成是点云的点了.

        reference_points = reference_points.permute(1, 0, 2, 3)
        # num_query等于H*W*Z. 等于grid_points的数量.
        D, B, num_query = reference_points.size()[:3]
        num_cam = lidar2img.size(1)

        # 要往每个相机上去投影. 因此先申请num_cam份.
        # reference_points的shape就变成了, (D, b, num_cam, num_query, 4, 1) 便于和4*4的矩阵做matmul.
        reference_points = reference_points.view(
            D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(-1)

        # 相机参数由(b,num_cam, 4, 4) 变成(1, b, num_cam, 1, 4, 4) 再变成(D,b,num_cam,num_query,4,4)
        lidar2img = lidar2img.view(
            1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)

        reference_points_cam = torch.matmul(lidar2img.to(torch.float32),
                                            reference_points.to(torch.float32)).squeeze(-1)
        eps = 1e-5

        # 把每个相机后面的点mask掉. 因为相机后面的点投过来之后第三位是负的.
        bev_mask = (reference_points_cam[..., 2:3] > eps)
        # 再做齐次化. 得到像素坐标.
        reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
            reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps)

        # 由像素坐标转成相对于图像的ratio..
        # NOTE 这里如果不同相机size不一样的话.要除以对应的相机的size
        reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
        reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]

        # 再把超出图像fov范围的点给去掉.
        bev_mask = (bev_mask & (reference_points_cam[..., 1:2] > 0.0)
                    & (reference_points_cam[..., 1:2] < 1.0)
                    & (reference_points_cam[..., 0:1] < 1.0)
                    & (reference_points_cam[..., 0:1] > 0.0))
        if digit_version(TORCH_VERSION) >= digit_version('1.8'):
            bev_mask = torch.nan_to_num(bev_mask)
        else:
            bev_mask = bev_mask.new_tensor(
                np.nan_to_num(bev_mask.cpu().numpy()))

        # 由(D, b, num_cam, num_query, 2) 变成 (num_cam, b, num_query, D, 2)
        reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
        bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)

        # 至此. reference_points_cam代表的就是像素点相对于各个相机的ratio.
        # bev_mask就代表哪些点是有效的
        return reference_points_cam, bev_mask

SpatialCrossAttention

个人理解SpatialCrossAttention其实就是正常的Deformable Attention, 只不过原始Deformable Attention中的
refer points是由网络产生的,
而现在的refer points 是由 虚拟的grid points往图像上投影得到的. 在相机参数固定的情况下, 此时的refer points是固定的.

下面是 SpatialCrossAttention这个模块的forward函数的部分代码

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

问题: 给固定的这些refer points 的收益是多大? 文章好像并没有提. 这一块儿感觉不充分.

另外, 显然这样虚拟的grid points 是不合理的, 因为有些地方可能就没有点, 但是还是能够投影到图像上的. 这里用真值的点应该会更好,
比如用lidar的points. 但是BEVFormer paper里面没有对比加入lidar后的效果.

在这里插入图片描述

风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。