您现在的位置是:首页 >技术交流 >【项目实践】猫十二分类网站首页技术交流
【项目实践】猫十二分类
【数据科学项目实践】基于ResNet和Inception v3的猫十二分类迁移学习
一、项目背景
本项目来源于飞浆平台的图像分类学习赛。指路链接
- 代码和结果来源于我的小组同学,没有做任何的改动,我这边仅做一个总结归纳,以便学习和复盘
 
简单把赛题Copy一下:
本场比赛要求参赛选手对十二种猫进行分类,属于CV方向经典的图像分类任务。图像分类任务作为其他图像任务的基石,可以让大家更快上手计算机视觉。
数据集
比赛数据集包含12种猫的图片,并划分为训练集与测试集。
训练集: 提供高清彩色图片以及图片所属的分类,共有2160张猫的图片,含标注文件。
测试集: 仅提供彩色图片,共有240张猫的图片,不含标注文件。
二、Baseline
2.1 准备阶段
主要是导入一些要用到的模块:
import os
import cv2
import torch
import torch.nn as nn
from torchvision import models,transforms
from torch.utils.data import DataLoader,Dataset
import numpy as np
from PIL import Image
from torch.optim import lr_scheduler
import copy
 
2.2 数据读取阶段
这个阶段就是如何将数据读取到模型中来,由于猫猫是图像数据,所以这边将其读取成数字图像一般是通过数组来存在内存中的,考虑到中间过程的可视化,我们通过PIL来读取Image类型的数据。这步可以写作:
x=np.fromfile(imgPath,dtype=np.float32) # 读取成ndarray
x=cv2.imdecode(x,1) # 将区间转化为[0,255]
img=PIL.Image.fromarray(x) # 读取成Image对象
 

上图中,左边的是Image类型的数据,右边是cv读取的数据,可以发现发生了颜色通道的调换。实际上,读取到cv这部分就好了,可以调用多窗口的imshow进行数据可视化。
我们现在拿到了猫猫图像!那么接下来就要拿到猫猫的标签啦,一般情况下,我们会将数据跟标签记录在一个文档里,每一行对应一个数据(图片)路径和一个标签:
# 文件标签
filelist=r"data_split_list.txt"
imgs,labels=[],[] # 存储列表
with open(filelist) as f:
    lines=[_.strip() for _ in f] # 去除空白
    np.random.shuffle(lines) # 随机打乱
    for l in lines:
        img_path,label=l.split('	') # 获取图片路径和标签
        img=Image.fromarray(cv2.imdecode(np.fromfile(img_path,np.float32),1))
        imgs.append(img)
        labels.append(label)
 
我们将这部分工作封装成一个函数,就可以实现数据的读取了。
接下来的工作,就是将数据转化为PyTorch接受的格式啦。众所周知,PyTorch的模型训练跟推理一般是通过迭代一个DataLoader对象来进行的,而DataLoader对象的数据集是一个DataSet类。所以这里我们需要构建一个Dataset类啦:
class myData(Dataset):
    
    def __init__(self):
        super(myData,self).__init__()
        self.data=[]
    
    def __getitem__(self,x):
        return self.data[x]
    
    def __len__(self):
        return len(self.data)
 
嗯,把上面三个函数填完就阔以啦。
对于图像数据,我们需要应用一个transforms,这里做最简单的变换:转为Tensor,尺寸裁剪,标准化。
self.transform=transforms.Compose(
    transforms.ToTensor(),
    transforms.Resize((299,299)),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
)
 
最终的Dataset如下:
class myData(Dataset):
    def __init__(self,kind):
        super(myData, self).__init__()
        self.mode=kind
        self.transform=transforms.Compose(
            transforms.ToTensor(),
            transforms.Resize((299,299)),
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
        )
        if kind=="test":
            self.imgs=self.load_origin_data()
        else:
            self.imgs,self.labels=self.load_origin_data()
    def __getitem__(self, item):
        if self.mode=="test":
            return self.transform(self.imgs[item])
        else:
            return self.transform(self.imgs[item]),torch.tensor(self.labels[item])
    def __len__(self):
        return len(self.imgs)
    def load_origin_data(self):
        filelist = './data/%s_split_list.txt' % self.mode
        imgs,labels=[],[]
        data_dir=os.getcwd()+"/data"
        if self.mode=='train' or self.mode=='val':
            with open(filelist) as f:
                lines=[_.strip() for _ in f]
                if self.mode=='train':
                    np.random.shuffle(lines)
                    for l in lines:
                        img_path,label=l.split('	')
                        img_path=os.path.join(data_dir,img_path)
                        try:
                            img=Image.fromarray(cv2.imdecode(np.fromfile(img_path,dtype=np.float32),1))
                            imgs.append(img)
                            labels.append(int(label))
                        except Exception("The path %s"%img_path+" may be wrong") as e:
                            print(e)
                            continue
                    return imgs,labels
                elif self.mode=="test":
                    full_lines = os.listdir('data/cat_12_test/')
                    lines = [line.strip() for line in full_lines]
                    for img_path in lines:
                        img_path = os.path.join(data_dir, "cat_12_test/", img_path)
                        img = Image.open(img_path)
                        imgs.append(img)
                    return imgs
 
2.3 模型训练
我们刚刚说PyTorch的模型训练跟推理一般是通过迭代一个DataLoader对象来进行的,现在就是需要构建这个东西啦:
def get_Dataloader():
    img_datasets = {x: myData(x) for x in ['train', 'val', 'test']}
    dataset_sizes = {x: len(img_datasets[x]) for x in ['train', 'val', 'test']}
    train_loader = DataLoader(
        dataset=img_datasets['train'],
        batch_size=24,
        shuffle=True
    )
    val_loader = DataLoader(
        dataset=img_datasets['val'],
        batch_size=1,
        shuffle=False
    )
    test_loader = DataLoader(
        dataset=img_datasets['test'],
        batch_size=1,
        shuffle=False
    )
    dataloaders = {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader
    }
    return dataset_sizes,dataloaders
 
接下来就是单纯的训练过程了。步骤总结如下:
- 参数设置阶段 
  
- 设置GPU
 - 设置优化器、损失函数、学习策略
 
 - 训练过程 
  
- 迭代DataLoader
 - 优化器梯度清零
 - 模型推理
 - 误差计算
 - 反向传播
 - 更新优化器、学习率
 
 - 模型评估 
  
- 计算每轮的误差累计值、精度
 - 选择最优精度并进行模型保存
 
 
def Train(model,criterion,optimizer,scheduler,num_epoches=25):
    dataset_sizes,dataloaders=get_Dataloader()
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    best_model_wts=copy.deepcopy(model.state_dict())
    best_acc=0.0
    for epoch in range(num_epoches):
        print("Epoch {}/{}".format(epoch+1,num_epoches))
        for phase in ['train','val']:
            if phase=="train":
                model.train()
            else:
                model.eval()
            trian_loss=0.0
            train_corrects=0
            for inputs,labels in dataloaders[phase]:
                inputs,labels=inputs.to(device),labels.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase=="train"):
                    # 上下文管理器,参数是Bool,用于确定是否对Block内的语句进行求导
                    y_pre=model(inputs)
                    _,y_pre=torch.max(y_pre,1)
                    loss=criterion(y_pre,labels)
                    if phase=="train":
                        loss.backward()
                        optimizer.step()
                trian_loss+=loss.item()*inputs.size(0)
                train_corrects+=torch.sum(y_pre==labels)
            if phase=="train":
                scheduler.step()
            epoch_loss=trian_loss/dataset_sizes[phase]
            epoch_acc=train_corrects.float()/dataset_sizes[phase]
            print("{} Loss :{:.4f} Acc {:.4}".format(phase,epoch_loss,epoch_acc))
            if phase=="val" and epoch_acc>best_acc:
                best_acc=epoch_acc
                best_model_wts=copy.deepcopy(model.state_dict())
    print("Best val Acc : {:4f}".format(best_acc))
    model.load_state_dict(best_model_wts)
    return model
 
三、迁移学习
迁移学习(Transfer Learning)就是利用预训练好的大模型参数去学习其他数据的分布。
这个过程我们一般不希望原始模型参数改变,因而一般需要做如下工作:
for param in model.parameters():
    param.requires_grad=False
 
然后,我们需要构架最后一层全连接层,用来学习新的数据集:
model.fc=nn.Linear(2048,num_classes)
 
也就是最后需要训练的就是这个全连接层了。
def Inception(device):
    # 用训练好的模型进行迁移
    model_ft=models.inception_v3(pretrained=True)
    # model_ft=models.resnet50(pretrained=True)
    # model_ft=models.alexnet(pretrained=True)
    num_ftrs=model_ft.fc.in_features
    model_ft.fc=nn.Linear(num_ftrs,12) # 设置全连接层最终结果
    
    model_ft=model_ft.to(device)
    cirterion=nn.CrossEntropyLoss()
    optimizer_ft=torch.optim.SGD(model_ft.parameters(),lr=0.001,momentum=0.9)
    exp_lr_scheduler=lr_scheduler.StepLR(optimizer_ft,step_size=5,gamma=0.1)
    model_ft=Train(model_ft,cirterion,optimizer_ft,exp_lr_scheduler,num_epoches=30)
 
四、结果分析
-  
Inception
Epoch 30/30 train Loss: 0.1065 Acc: 0.9858 val Loss: 0.3026 Acc: 0.8983 Best val Acc: 0.918336 -  
AlexNet
Epoch 30/30 train Loss: 0.1403 Acc: 0.9601 val Loss: 0.6815 Acc: 0.7750 Best val Acc: 0.779661 -  
ResNet50
Epoch 30/30 train Loss: 0.0480 Acc: 0.9973 val Loss: 0.3157 Acc: 0.9060 Best val Acc: 0.909091 
中间部分特征图的结果如下:

特征图嘛,主打的就是一个抽象。可以发现同一张图经过不同的卷积核作用后,有了全新的高维特征,这些特征也主打的就是一个难以解释,反正就看个乐。

基本上7个epoch就收敛了。
            




U8W/U8W-Mini使用与常见问题解决
QT多线程的5种用法,通过使用线程解决UI主界面的耗时操作代码,防止界面卡死。...
stm32使用HAL库配置串口中断收发数据(保姆级教程)
分享几个国内免费的ChatGPT镜像网址(亲测有效)
Allegro16.6差分等长设置及走线总结