
Remote Sensing Image Detection with Swin Transformer Based on MMDet

cbx0916 2024-06-14 17:18:30

I tried out Swin Transformer for object detection on SAR images, using the SSDD ship dataset and the MSTAR ground-target dataset.

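Installing MMDet

MMDet repository: https://github.com/open-mmlab/mmdetection
Clone it and set up the environment by following the official documentation.
If you clone the conda environment or move the project to another environment, remember to run setup again: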

python setup.py develop
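
After re-running setup it can help to confirm that the packages are visible from the active environment. The snippet below is only a minimal sketch: the package list (torch, mmcv, mmdet) is my assumption about what a typical MMDet setup needs, and it only checks that the packages can be located, not that their versions are compatible.

```python
import importlib.util


def find_missing(packages=("torch", "mmcv", "mmdet")):
    """Return the names in `packages` that cannot be located in this environment."""
    return [name for name in packages if importlib.util.find_spec(name) is None]


missing = find_missing()
if missing:
    print("Not importable in this environment:", ", ".join(missing))
else:
    print("torch, mmcv and mmdet are all importable")
```

If mmdet shows up as missing after cloning or moving the environment, running the setup command above again is the first thing to try.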

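Swin Transformer code

1. Create the config file under configs

Create faster_rcnn_swin_l-p4-w12_coco.py under configs/swin.
The learning rate, training length (epochs), and similar parameters can be changed in this file.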

_base_ = [
    '../_base_/models/faster_rcnn_swin_large_fpn.py',
    '../_base_/datasets/faster_rcnn_coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

optimizer = dict(
    _delete_=True,
    type='AdamW',
    # lr=0.0001,
    lr=0.000051,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))
lr_config = dict(warmup_iters=1000, step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)
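
To see what lr_config and runner above mean in practice, here is a rough sketch of the resulting learning-rate curve. It assumes the step policy with linear warmup inherited from schedule_1x.py, a decay factor gamma=0.1 and warmup_ratio=0.001; those three values are not set in the config above, so treat them as my assumptions about the MMDet 2.x defaults. The function lr_at is a hypothetical helper, not part of MMDet.

```python
def lr_at(epoch, cur_iter, base_lr=0.000051, steps=(27, 33), gamma=0.1,
          warmup_iters=1000, warmup_ratio=0.001):
    """Approximate learning rate for a 0-based epoch and a global iteration count."""
    # Step decay: multiply by gamma once for every milestone already reached.
    passed = sum(1 for s in steps if epoch >= s)
    regular_lr = base_lr * gamma ** passed
    if cur_iter < warmup_iters:
        # Linear warmup from warmup_ratio * lr up to the regular lr.
        k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
        return regular_lr * (1 - k)
    return regular_lr


# Learning rate at the first iteration, after warmup, and after each decay step.
for epoch, it in [(0, 0), (0, 1000), (27, 50000), (33, 50000)]:
    print(epoch, it, lr_at(epoch, it))
```

Under these assumptions the rate ramps up over the first 1000 iterations, stays at 0.000051 until epoch 27, and is divided by 10 at epochs 27 and 33 of the 36-epoch run.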

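2. Create the model file

Create faster_rcnn_swin_large_fpn.py in configs/_base_/models/.
The backbone, neck, and other network settings can be changed in this file. This setup uses the Swin-Large model as the backbone and PAFPN as the neck.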

# model settings
pretrained = 'D:/Project/mmdetection-master/checkpoints/swin_large_patch4_window12_384_22k.pth'

# 1. ROI 0.5-0.7
# 2. pafpn
# 3. albu_train_transforms
# 4. multi-scale

model = dict(
    type='FasterRCNN',
    backbone=dict(
        type='SwinTransformer',
        embed_dims=192,
        depths=[2, 2, 18, 2],
        num_heads=[6, 12, 24, 48],
        window_size=12,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        with_cp=False,
        convert_weights=True,
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(
        type='PAFPN',
        in_channels=[192, 384, 768, 1536],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=10,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            # nms=dict(type='nms', iou_threshold=0.6),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=False,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            # nms=dict(type='nms', iou_threshold=0.5),
            nms=dict(type='nms', iou_threshold=0.6),
            max_per_img=100)
        # soft-nms is also supported for rcnn testing
        # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
    ))
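
When swapping the backbone size, the numbers that have to stay consistent are embed_dims, num_heads, and the neck's in_channels. The sketch below checks them for the values used above; swin_stage_channels is a hypothetical helper, and it relies on the usual Swin design in which the channel width doubles at every stage.

```python
def swin_stage_channels(embed_dims, num_stages=4):
    """Channel width of each Swin stage: it doubles after every patch-merging step."""
    return [embed_dims * 2 ** i for i in range(num_stages)]


embed_dims = 192                          # backbone.embed_dims in the config
num_heads = [6, 12, 24, 48]               # backbone.num_heads
neck_in_channels = [192, 384, 768, 1536]  # neck.in_channels

channels = swin_stage_channels(embed_dims)
# The neck must receive exactly the widths the backbone outputs.
assert channels == neck_in_channels
# Each stage's width must split evenly across its attention heads.
assert all(c % h == 0 for c, h in zip(channels, num_heads))
head_dims = [c // h for c, h in zip(channels, num_heads)]
print(channels, head_dims)  # [192, 384, 768, 1536] [32, 32, 32, 32]
```

The same reasoning explains the stride lists: the four backbone outputs sit at strides 4, 8, 16 and 32, the neck's num_outs=5 adds a fifth level for the RPN (stride 64), and the RoI extractor only uses the first four.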

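3. Create the dataset file

Create faster_rcnn_coco_instance.py in configs/_base_/datasets/.
The code includes Albu (Albumentations) data augmentation; the image size, samples_per_gpu, workers_per_gpu, and so on can be adjusted here.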

# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco_mstar/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

albu_train_transforms = [
    # dict(
    #     type='HorizontalFlip',
    #     p=0.5),
    # dict(
    #     type='VerticalFlip',
    #     p=0.5),

    dict(
        type='ShiftScaleRotate',
        shift_limit=0.0625,
        scale_limit=0.0,
        rotate_limit=180,
        interpolation=1,
        p=0.5),
    # dict(
    #     type='RandomBrightnessContrast',
    #     brightness_limit=[0.1, 0.3],
    #     contrast_limit=[0.1, 0.3],
    #     p=0.2),
    # dict(
    #     type='OneOf',
    #     transforms=[
    #         dict(
    #             type='RGBShift',
    #             r_shift_limit=10,
    #             g_shift_limit=10,
    #             b_shift_limit=10,
    #             p=1.0),
    #         dict(
    #             type='HueSaturationValue',
    #             hue_shift_limit=20,
    #             sat_shift_limit=30,
    #             val_shift_limit=20,
    #             p=1.0)
    #     ],
    #     p=0.1),
    # # dict(type='JpegCompression', quality_lower=85, quality_upper=95, p=0.2),
    #
    # dict(type='ChannelShuffle', p=0.1),
    # dict(
    #     type='OneOf',
    #     transforms=[
    #         dict(type='Blur', blur_limit=3, p=1.0),
    #         dict(type='MedianBlur', blur_limit=3, p=1.0)
    #     ],
    #     p=0.1),
]

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    # dict(type='Resize', img_scale=(448, 448), keep_ratio=True),
    dict(type='Resize', img_scale=[(768, 768), (1333,800)], keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    # dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    # dict(type='DefaultFormatBundle'),
    # dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),

    dict(
            type='Albu',
            transforms=albu_train_transforms,
            bbox_params=dict(
                type='BboxParams',
                format='pascal_voc',
                label_fields=['gt_labels'],
                min_visibility=0.0,
                filter_lost_elements=True),
            keymap={
                'img': 'image',
                'gt_bboxes': 'bboxes'
            },
            update_pad_shape=False,
            skip_img_without_anno=True),

    dict(type='Normalize', **img_norm_cfg),
    dict(type='DefaultFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'gt_bboxes', 'gt_labels'],
        meta_keys=('filename', 'ori_shape', 'img_shape', 'img_norm_cfg',
                   'pad_shape', 'scale_factor')
    )
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        # img_scale=(448, 448),
        img_scale=[(768, 768), (1333, 800)],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        # ann_file=data_root + 'annotations/instances_train2017.json',
        # img_prefix=data_root + 'train2017/',
        ann_file=data_root + 'annotations/train.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/val.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/val.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
evaluation = dict(metric=['bbox'])
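
Since SAR chips are often much smaller than COCO images, it is worth checking what the Resize and Pad steps do to them. The following is a sketch of keep-ratio resizing followed by padding to a multiple of 32, written from my understanding of how MMDet 2.x handles keep_ratio=True; the helper names and the 128x128 and 500x350 input sizes are hypothetical examples, not values from the datasets' documentation.

```python
import math


def rescale_size(width, height, scale):
    """Keep-ratio resize: fit the image inside the (long edge, short edge) given by `scale`."""
    max_long, max_short = max(scale), min(scale)
    factor = min(max_long / max(width, height), max_short / min(width, height))
    return int(width * factor + 0.5), int(height * factor + 0.5)


def pad_size(width, height, divisor=32):
    """Round each side up to the next multiple of `divisor`, as Pad(size_divisor=32) does."""
    return (math.ceil(width / divisor) * divisor,
            math.ceil(height / divisor) * divisor)


for w, h in [(128, 128), (500, 350)]:
    for scale in [(768, 768), (1333, 800)]:
        resized = rescale_size(w, h, scale)
        print((w, h), scale, resized, pad_size(*resized))
```

With two scales in img_scale, training should pick a scale between them at random for each image (as far as I know, the default multi-scale mode is a range), so the two scales printed here are the extremes rather than the only sizes seen.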

4. Modify the dataset classes

List the class names in coco.py under mmdet/datasets/.


    # CLASSES = ('ship',)
    CLASSES = ('2S1', 'BMP2', 'BRDM2', 'BTR60', 'BTR70', 'D7', 'T62', 'T72', 'ZIL131', 'ZSU234',)
    # CLASSES = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9',)
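
The names in CLASSES have to match the category names in the annotation JSON, and their count has to match num_classes in the model file (10 here). A quick sanity check could look like the sketch below; check_categories is a hypothetical helper, and the small in-memory JSON only stands in for annotations/train.json.

```python
import json

CLASSES = ('2S1', 'BMP2', 'BRDM2', 'BTR60', 'BTR70',
           'D7', 'T62', 'T72', 'ZIL131', 'ZSU234')


def check_categories(coco_dict, classes):
    """Return (names only in the annotations, names only in `classes`)."""
    names = {cat["name"] for cat in coco_dict.get("categories", [])}
    return sorted(names - set(classes)), sorted(set(classes) - names)


# Hypothetical stand-in for the real file; in practice use json.load(open(path)).
sample = json.loads(
    '{"categories": [{"id": 1, "name": "2S1"},'
    ' {"id": 2, "name": "T72"}, {"id": 3, "name": "tank"}]}'
)
unknown, unused = check_categories(sample, CLASSES)
print("in annotations but not in CLASSES:", unknown)
print("in CLASSES but not in annotations:", unused)
print("num_classes should be", len(CLASSES))
```

Any name reported in the first list would most likely be ignored when the dataset is loaded, which usually shows up as missing ground-truth boxes rather than an explicit error.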

5. Train the model

Run the following in the selected conda environment:

python tools/train.py configs/swin/faster_rcnn_swin_l-p4-w12_coco.py

This writes .pth checkpoint files to work_dirs in the repository root. If it raises an error, create the folder yourself.
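
To find where the checkpoints end up: as far as I remember, tools/train.py in MMDet 2.x derives the output folder from the config file name when no --work-dir is given. The sketch below reproduces that rule; default_work_dir is a hypothetical helper name, and the rule itself is my recollection of the script rather than something stated in this post.

```python
import os.path as osp


def default_work_dir(config_path, root="./work_dirs"):
    """Output folder derived from the config name: basename without the .py suffix."""
    return osp.join(root, osp.splitext(osp.basename(config_path))[0])


print(default_work_dir("configs/swin/faster_rcnn_swin_l-p4-w12_coco.py"))
```

If the checkpoints were saved somewhere else (for example because --work-dir was passed or the folder was created by hand), the checkpoint paths in later commands need to point to that location instead.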

6. Test the checkpoint

In the selected conda environment, run:

python demo/image_demo.py data/coco/train2017/000092.jpg configs/swin/faster_rcnn_swin_l-p4-w12_coco.py work_dirs/faster_rcnn_swin_l-p4-w12_coco/epoch_12.pth

7. Test the model

python tools/test.py configs/swin/faster_rcnn_swin_l-p4-w12_coco.py work_dirs/faster_rcnn_swin_l-p4-w12_coco/epoch_12.pth --eval bbox --out work_dirs/faster_rcnn_swin_l-p4-w12_coco/result.pkl --show-score-thr 0.5 --show-dir work_dirs/faster_rcnn_swin_l-p4-w12_coco/eval/ --eval-options "classwise=True"
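
The --show-score-thr 0.5 option only affects which boxes are drawn. To apply the same kind of filter to saved results yourself, something like the sketch below could work. It assumes the bbox-only result format of MMDet 2.x, where each image's result is a list with one entry per class and each row is [x1, y1, x2, y2, score]; keep_confident and the sample detections are hypothetical.

```python
def keep_confident(result, class_names, score_thr=0.5):
    """Flatten per-class detections and keep those scoring above `score_thr`."""
    kept = []
    for name, boxes in zip(class_names, result):
        for x1, y1, x2, y2, score in boxes:
            if score > score_thr:
                kept.append((name, float(score), (x1, y1, x2, y2)))
    # Highest-confidence detections first.
    return sorted(kept, key=lambda det: det[1], reverse=True)


# Hypothetical result for one image with two classes.
classes = ('2S1', 'BMP2')
result = [
    [[10, 10, 50, 50, 0.9], [0, 0, 5, 5, 0.2]],  # detections for '2S1'
    [[20, 20, 60, 60, 0.55]],                    # detections for 'BMP2'
]
for name, score, box in keep_confident(result, classes):
    print(name, score, box)
```

The file written by --out should be loadable with pickle as a list with one such result per test image, though I have not verified that for every MMDet version.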