您现在的位置是:首页 >技术教程 >经典神经网络(6)ResNet及其在Fashion-MNIST数据集上的应用网站首页技术教程

经典神经网络(6)ResNet及其在Fashion-MNIST数据集上的应用

undo_try 2024-07-12 06:01:02
简介经典神经网络(6)ResNet及其在Fashion-MNIST数据集上的应用

经典神经网络(6)ResNet及其在Fashion-MNIST数据集上的应用

1 ResNet的简述

  1. ResNet 提出了一种残差学习框架来解决网络退化问题,从而训练更深的网络。这种框架可以结合已有的各种网络结构,充分发挥二者的优势。

  2. ResNet以三种方式挑战了传统的神经网络架构:

    • ResNet 通过引入跳跃连接来绕过残差层,这允许数据直接流向任何后续层。

      这与传统的、顺序的pipeline 形成鲜明对比:传统的架构中,网络依次处理低级feature 到高级feature

    • ResNet 的层数非常深,高达1202层。而ALexNet 这样的架构,网络层数要小两个量级。

    • 通过实验发现,训练好的 ResNet 中去掉单个层并不会影响其预测性能。而训练好的AlexNet 等网络中,移除层会导致预测性能损失。

  3. ImageNet分类数据集中,拥有152层的残差网络,以3.75% top-5 的错误率获得了ILSVRC 2015 分类比赛的冠军。

  4. 很多证据表明:残差学习是通用的,不仅可以应用于视觉问题,也可应用于非视觉问题。

  5. 论文地址: https://arxiv.org/pdf/1512.03385.pdf

  6. 卷积神经网络领域的两次技术爆炸,第一次是AlexNet,第二次就是ResNet了。

1.1 网络退化问题

  • 1、理论上来讲网络深度越深越好。网络越深,提取的图片特征越多越丰富,但随之会带来很多的问题(通过Batch Normalization 在很大程度上解决),比如过拟合或者计算量爆炸、梯度消失、梯度爆炸等,导致网络在一定深度下就达到了局部最优解。

  • 2、ResNet 论文作者发现:随着网络的深度的增加,准确率达到饱和之后迅速下降,而这种下降不是由过拟合引起的。这称作网络退化问题。如果更深的网络训练误差更大,则说明是由于优化算法引起的:越深的网络,求解优化问题越难。如下所示:更深的网络导致更高的训练误差和测试误差。

在这里插入图片描述

  • 3、理论上讲,较深的模型不应该比和它对应的、较浅的模型更差。因为较深的模型是较浅的模型的超空间。较深的模型可以这样得到:先构建较浅的模型,然后添加很多恒等映射的网络层。实际上我们的较深的模型后面添加的不是恒等映射,而是一些非线性层。因此,退化问题表明:通过多个非线性层来近似横等映射可能是困难的

在这里插入图片描述

  • 4、针对这⼀问题,何恺明等⼈提出了残差⽹络(ResNet)。它在2015年的ImageNet图像识别挑战赛夺魁,并深刻影响了后来的深度神经⽹络的设计。残差⽹络的核⼼思想是:每个附加层都应该更容易地包含原始函数作为其元素之⼀

1.2 残差块(residual blocks)

1.2.1 残差块的理解

在这里插入图片描述

1、假设需要学习的是映射 y = H(x),残差块使用堆叠的非线性层拟合残差:y = F(x,W) + x 。

其中:

  • x 和 y 是块的输入和输出向量。
  • F(x,W)是要学习的残差映射。因为 F(x,W) = H(x) - x,因此称F为残差。
  • + :通过快捷连接逐个元素相加来执行。快捷连接 指的是那些跳过一层或者更多层的连接。
    • 快捷连接简单的执行恒等映射,并将其输出添加到堆叠层的输出。
    • 快捷连接既不增加额外的参数,也不增加计算复杂度。
  • 相加之后通过非线性激活函数,这可以视作对整个残差块添加非线性,即 relu(y)

2、残差映射易于捕捉恒等映射的细微波动。比如5正常映射为5.1,加入残差后变成 5+0.1。此时输入变成5.2,对于没有残差结构的结果,影响仅为0.1/5.1 = 2%。而对于残差结构,变成 5+0.2 , 由0.1变成了0.2 影响为100%。

3、残差映射 H ( x ) = F ( x ) + x ,在反向传播的时候就变成了 H ′ ( x ) = F ′ ( x ) + 1,这里的加1也可以保证梯度消失现象

4、作者也证明了退化问题在任何数据集上都普遍存在。在imagenet上拿到冠军之后,迁移学习用到了coco同样拿到了好几个赛道的冠军,说明残差结构是普适的。最后又和VGG比了一下,比VGG深了8倍,计算复杂性却还比VGG小 。

1.2.2 残差函数F的形式的可变性

  • 层数可变:论文中的实验包含有两层堆叠、三层堆叠,实际任务中也可以包含更多层的堆叠。

    如果F只有一层,则残差块退化线性层:y = Wx + x 。此时对网络并没有什么提升。

  • 连接形式可变:不仅可用于全连接层,可也用于卷积层。此时F代表多个卷积层的堆叠,而最终的逐元素加法+ 在两个feature map 上逐通道进行。

    此时 x 也是一个feature map,而不再是一个向量。

1.2.3 残差学习成功的原因

学习残差F(x,W)比学习原始映射H(x)要更容易。

  • 1、当原始映射H就是一个恒等映射时, 就是一个F零映射。此时求解器只需要简单的将堆叠的非线性连接的权重推向零即可。

    实际任务中原始映射 H可能不是一个恒等映射:

    • 如果H 更偏向于恒等映射(而不是更偏向于非恒等映射),则F就是关于恒等映射的抖动,会更容易学习。
    • 如果原始映射H 更偏向于零映射,那么学习 本身要更容易。但是在实际应用中,零映射非常少见,因为它会导致输出全为0。
  • 2、如果原始映射H是一个非恒等映射,则可以考虑对残差模块使用缩放因子。如Inception-Resnet 中:在残差模块与快捷连接叠加之前,对残差进行缩放。注意:ResNet 作者在随后的论文中指出:不应该对恒等映射进行缩放。

  • 3、可以通过观察残差 F的输出来判断:如果F的输出均为0附近的、较小的数,则说明原始映射H更偏向于恒等映射;否则,说明原始映射H更偏向于非横等映射。

1.2.4 残差块代码实现

在这里插入图片描述

from torch import nn
from torch.nn import functional as F
import torch

'''
⼀种是当use_1x1conv=False时,应⽤ReLU⾮线性函数之前,将输⼊添加到输出。
另⼀种是当use_1x1conv=True时,添加通过1 × 1卷积调整通道和分辨率



ResNet沿⽤了VGG完整的3 × 3卷积层设计。
残差块⾥⾸先有2个有相同输出通道数的3 × 3卷积层。
每个卷积层后接⼀个批量规范化层和ReLU激活函数。
然后我们通过跨层数据通路,跳过这2个卷积运算,将输⼊直接加在最后的ReLU激活函数前。
这样的设计要求2个卷积层的输出与输⼊形状⼀样,从⽽使它们可以相加。

如果想改变通道数,就需要引⼊⼀个额外的1 × 1卷积层来将输⼊变换成需要的形状后再做相加运算。
'''
class Residual(nn.Module):

    def __init__(self,input_channels, num_channels,use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels,kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None

        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)


if __name__ == '__main__':
    blk = Residual(3, 3)
    X = torch.rand(4, 3, 6, 6)
    Y = blk(X)
    print(Y.shape)  # 输⼊和输出形状⼀致 torch.Size([4, 3, 6, 6])
    blk = Residual(3, 6, use_1x1conv=True, strides=2)
    Y = blk(X)
    print(Y.shape)  # 在增加输出通道数的同时,减半输出的高和宽 torch.Size([4, 6, 3, 3])

1.3 ResNet网络

1.3.1 四种plain 网络

plain 网络:一些简单网络结构的叠加,如下图所示。图中给出了四种plain 网络,它们的区别主要是网络深度不同。其中,输入图片尺寸 224x224 。

ResNet 简单的在plain 网络上添加快捷连接来实现。

FLOPsfloating point operations 的缩写,意思是浮点运算量,用于衡量算法/模型的复杂度。

FLOPSfloating point per second的缩写,意思是每秒浮点运算次数,用于衡量计算速度。

在这里插入图片描述

相对于输入的feature map,残差块的输出feature map 尺寸可能会发生变化:

  • 输出 feature map 的通道数增加,此时需要扩充快捷连接的输出feature map 。否则快捷连接的输出 feature map 无法和残差块的feature map 累加。

    有两种扩充方式:

    • 直接通过 0 来填充需要扩充的维度。
    • 通过1x1 卷积来扩充维度。
  • 输出 feature map 的尺寸减半。此时需要对快捷连接执行步长为 2 的池化/卷积:如果快捷连接已经采用 1x1 卷积,则该卷积步长为2 ;否则采用步长为 2 的最大池化 。

1.3.2 模型预测能力

VGG-1934层 plain 网络Resnet-34
计算复杂度(FLOPs)19.6 billion3.5 billion3.6 billion

ImageNet 验证集上执行10-crop 测试的结果。

  • A 类模型:快捷连接中,所有需要扩充的维度的填充 0 。
  • B 类模型:快捷连接中,所有需要扩充的维度通过1x1 卷积来扩充。
  • C 类模型:所有快捷连接都通过1x1 卷积来执行线性变换。

C 优于BB 优于A。但是 C 引入更多的参数,相对于这种微弱的提升,性价比较低。所以后续的ResNet 均采用 B 类模型。

模型top-1 误差率top-5 误差率
VGG-1628.07%9.33%
GoogleNet-9.15%
PReLU-net24.27%7.38%
plain-3428.54%10.02%
ResNet-34 A25.03%7.76%
ResNet-34 B24.52%7.46%
ResNet-34 C24.19%7.40%
ResNet-5022.85%6.71%
ResNet-10121.75%6.05%
ResNet-15221.43%5.71%

1.3.3 ResNet-18实现

import torch.nn as nn
import torch
from _06_Residual import Residual


class ResNet18(nn.Module):

    def __init__(self):
        super(ResNet18, self).__init__()
        self.model = self.get_net()

    def forward(self, X):
        X = self.model(X)
        return X


    def get_net(self):
        '''
        ResNet的前两层跟GoogLeNet中的⼀样:
           在输出通道数为64、步幅为2的7 × 7卷积层后,接步幅为2的3 × 3的最⼤汇聚层。
           不同之处在于ResNet每个卷积层后增加了批量规范化层。
        '''
        b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                           nn.BatchNorm2d(64), nn.ReLU(),
                           nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        '''
        GoogLeNet在后⾯接了4个由Inception块组成的模块。
        
        ResNet则使⽤4个由残差块组成的模块,每个模块使⽤若⼲个同样输出通道数的残差块。
        第⼀个模块的通道数同输⼊通道数⼀致。由于之前已经使⽤了步幅为2的最⼤汇聚层,所以⽆须减⼩⾼和宽。
        之后的每个模块在第⼀个残差块⾥将上⼀个模块的通道数翻倍,并将⾼和宽减半。
        '''
        b2 = nn.Sequential(*self.resnet_block(64, 64, 2, first_block=True))
        b3 = nn.Sequential(*self.resnet_block(64, 128, 2))
        b4 = nn.Sequential(*self.resnet_block(128, 256, 2))
        b5 = nn.Sequential(*self.resnet_block(256, 512, 2))
        net = nn.Sequential(b1, b2, b3, b4, b5,
                            nn.AdaptiveAvgPool2d((1, 1)),
                            nn.Flatten(), nn.Linear(512, 10))
        return net

    def resnet_block(self, input_channels, num_channels, num_residuals, first_block=False):
        blk = []
        for i in range(num_residuals):
            if i == 0 and not first_block:
                blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))
            else:
                blk.append(Residual(num_channels, num_channels))
        return blk

if __name__ == '__main__':
    net = ResNet18()
    X = torch.rand(size=(1, 1, 224, 224), dtype=torch.float32)
    for layer in net.model:
        X = layer(X)
        print(layer.__class__.__name__, 'output shape:', X.shape)
Sequential output shape: torch.Size([1, 64, 56, 56])
Sequential output shape: torch.Size([1, 64, 56, 56])
Sequential output shape: torch.Size([1, 128, 28, 28])
Sequential output shape: torch.Size([1, 256, 14, 14])
Sequential output shape: torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape: torch.Size([1, 512, 1, 1])
Flatten output shape: torch.Size([1, 512])
Linear output shape: torch.Size([1, 10])

2 ResNet-18在Fashion-MNIST数据集上的应用示例

2.1 创建ResNet网络模型

如1.2.4及1.3.3代码所示。

2.2 读取Fashion-MNIST数据集

其他所有的函数,与经典神经网络(1)LeNet及其在Fashion-MNIST数据集上的应用完全一致。

batch_size = 256

# 为了使Fashion-MNIST上的训练短⼩精悍,将输⼊的⾼和宽从224降到96,简化计算
train_iter,test_iter = get_mnist_data(batch_size,resize=96)

2.3 在GPU上进行模型训练

from _06_ResNet18 import ResNet18

# 初始化模型
net = ResNet18()

lr, num_epochs = 0.05, 10
train_ch(net, train_iter, test_iter, num_epochs, lr, try_gpu())

在这里插入图片描述

3 ResNet18微调进行水果图像分类

3.1 爬取水果图像数据

import requests
import urllib3
urllib3.disable_warnings()

import time
import os
import random
import pandas as pd
import shutil
import matplotlib.pyplot as plt
%matplotlib inline
import cv2
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler



# 忽略烦人的红色提示
import warnings
warnings.filterwarnings("ignore")

# 进度条库
from tqdm import tqdm
# http请求参数
cookies = {
    'BDqhfp': '%E7%8B%97%E7%8B%97%26%26NaN-1undefined%26%2618880%26%2621',
    'BIDUPSID': '06338E0BE23C6ADB52165ACEB972355B',
    'PSTM': '1646905430',
    'BAIDUID': '104BD58A7C408DABABCAC9E0A1B184B4:FG=1',
    'BDORZ': 'B490B5EBF6F3CD402E515D22BCDA1598',
    'H_PS_PSSID': '35836_35105_31254_36024_36005_34584_36142_36120_36032_35993_35984_35319_26350_35723_22160_36061',
    'BDSFRCVID': '8--OJexroG0xMovDbuOS5T78igKKHJQTDYLtOwXPsp3LGJLVgaSTEG0PtjcEHMA-2ZlgogKK02OTH6KF_2uxOjjg8UtVJeC6EG0Ptf8g0M5',
    'H_BDCLCKID_SF': 'tJPqoKtbtDI3fP36qR3KhPt8Kpby2D62aKDs2nopBhcqEIL4QTQM5p5yQ2c7LUvtynT2KJnz3Po8MUbSj4QoDjFjXJ7RJRJbK6vwKJ5s5h5nhMJSb67JDMP0-4F8exry523ioIovQpn0MhQ3DRoWXPIqbN7P-p5Z5mAqKl0MLPbtbb0xXj_0D6bBjHujtT_s2TTKLPK8fCnBDP59MDTjhPrMypomWMT-0bFH_-5L-l5js56SbU5hW5LSQxQ3QhLDQNn7_JjOX-0bVIj6Wl_-etP3yarQhxQxtNRdXInjtpvhHR38MpbobUPUDa59LUvEJgcdot5yBbc8eIna5hjkbfJBQttjQn3hfIkj0DKLtD8bMC-RDjt35n-Wqxobbtof-KOhLTrJaDkWsx7Oy4oTj6DD5lrG0P6RHmb8ht59JROPSU7mhqb_3MvB-fnEbf7r-2TP_R6GBPQtqMbIQft20-DIeMtjBMJaJRCqWR7jWhk2hl72ybCMQlRX5q79atTMfNTJ-qcH0KQpsIJM5-DWbT8EjHCet5DJJn4j_Dv5b-0aKRcY-tT5M-Lf5eT22-usy6Qd2hcH0KLKDh6gb4PhQKuZ5qutLTb4QTbqWKJcKfb1MRjvMPnF-tKZDb-JXtr92nuDal5TtUthSDnTDMRhXfIL04nyKMnitnr9-pnLJpQrh459XP68bTkA5bjZKxtq3mkjbPbDfn02eCKuj6tWj6j0DNRabK6aKC5bL6rJabC3b5CzXU6q2bDeQN3OW4Rq3Irt2M8aQI0WjJ3oyU7k0q0vWtvJWbbvLT7johRTWqR4enjb3MonDh83Mxb4BUrCHRrzWn3O5hvvhKoO3MA-yUKmDloOW-TB5bbPLUQF5l8-sq0x0bOte-bQXH_E5bj2qRCqVIKa3f',
    'BDSFRCVID_BFESS': '8--OJexroG0xMovDbuOS5T78igKKHJQTDYLtOwXPsp3LGJLVgaSTEG0PtjcEHMA-2ZlgogKK02OTH6KF_2uxOjjg8UtVJeC6EG0Ptf8g0M5',
    'H_BDCLCKID_SF_BFESS': 'tJPqoKtbtDI3fP36qR3KhPt8Kpby2D62aKDs2nopBhcqEIL4QTQM5p5yQ2c7LUvtynT2KJnz3Po8MUbSj4QoDjFjXJ7RJRJbK6vwKJ5s5h5nhMJSb67JDMP0-4F8exry523ioIovQpn0MhQ3DRoWXPIqbN7P-p5Z5mAqKl0MLPbtbb0xXj_0D6bBjHujtT_s2TTKLPK8fCnBDP59MDTjhPrMypomWMT-0bFH_-5L-l5js56SbU5hW5LSQxQ3QhLDQNn7_JjOX-0bVIj6Wl_-etP3yarQhxQxtNRdXInjtpvhHR38MpbobUPUDa59LUvEJgcdot5yBbc8eIna5hjkbfJBQttjQn3hfIkj0DKLtD8bMC-RDjt35n-Wqxobbtof-KOhLTrJaDkWsx7Oy4oTj6DD5lrG0P6RHmb8ht59JROPSU7mhqb_3MvB-fnEbf7r-2TP_R6GBPQtqMbIQft20-DIeMtjBMJaJRCqWR7jWhk2hl72ybCMQlRX5q79atTMfNTJ-qcH0KQpsIJM5-DWbT8EjHCet5DJJn4j_Dv5b-0aKRcY-tT5M-Lf5eT22-usy6Qd2hcH0KLKDh6gb4PhQKuZ5qutLTb4QTbqWKJcKfb1MRjvMPnF-tKZDb-JXtr92nuDal5TtUthSDnTDMRhXfIL04nyKMnitnr9-pnLJpQrh459XP68bTkA5bjZKxtq3mkjbPbDfn02eCKuj6tWj6j0DNRabK6aKC5bL6rJabC3b5CzXU6q2bDeQN3OW4Rq3Irt2M8aQI0WjJ3oyU7k0q0vWtvJWbbvLT7johRTWqR4enjb3MonDh83Mxb4BUrCHRrzWn3O5hvvhKoO3MA-yUKmDloOW-TB5bbPLUQF5l8-sq0x0bOte-bQXH_E5bj2qRCqVIKa3f',
    'indexPageSugList': '%5B%22%E7%8B%97%E7%8B%97%22%5D',
    'cleanHistoryStatus': '0',
    'BAIDUID_BFESS': '104BD58A7C408DABABCAC9E0A1B184B4:FG=1',
    'BDRCVFR[dG2JNJb_ajR]': 'mk3SLVN4HKm',
    'BDRCVFR[-pGxjrCMryR]': 'mk3SLVN4HKm',
    'ab_sr': '1.0.1_Y2YxZDkwMWZkMmY2MzA4MGU0OTNhMzVlNTcwMmM2MWE4YWU4OTc1ZjZmZDM2N2RjYmVkMzFiY2NjNWM4Nzk4NzBlZTliYWU0ZTAyODkzNDA3YzNiMTVjMTllMzQ0MGJlZjAwYzk5MDdjNWM0MzJmMDdhOWNhYTZhMjIwODc5MDMxN2QyMmE1YTFmN2QyY2M1M2VmZDkzMjMyOThiYmNhZA==',
    'delPer': '0',
    'PSINO': '2',
    'BA_HECTOR': '8h24a024042g05alup1h3g0aq0q'
}

headers = {
    'Connection': 'keep-alive',
    'sec-ch-ua': '" Not;A Brand";v="99", "Google Chrome";v="97", "Chromium";v="97"',
    'Accept': 'text/plain, */*; q=0.01',
    'X-Requested-With': 'XMLHttpRequest',
    'sec-ch-ua-mobile': '?0',
    'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.99 Safari/537.36',
    'sec-ch-ua-platform': '"macOS"',
    'Sec-Fetch-Site': 'same-origin',
    'Sec-Fetch-Mode': 'cors',
    'Sec-Fetch-Dest': 'empty',
    'Referer': 'https://image.baidu.com/search/index?tn=baiduimage&ipn=r&ct=201326592&cl=2&lm=-1&st=-1&fm=result&fr=&sf=1&fmq=1647837998851_R&pv=&ic=&nc=1&z=&hd=&latest=&copyright=&se=1&showtab=0&fb=0&width=&height=&face=0&istype=2&dyTabStr=MCwzLDIsNiwxLDUsNCw4LDcsOQ%3D%3D&ie=utf-8&sid=&word=%E7%8B%97%E7%8B%97',
    'Accept-Language': 'zh-CN,zh;q=0.9'
}


def download_single_class(file_path, keyword, DOWNLOAD_NUM=100):

    if not os.path.exists(file_path + "/dataset"):
        os.makedirs(file_path + "/dataset")
        print(f'新建{file_path}/dataset文件夹')

    if not os.path.exists(file_path + "/dataset/" + keyword):
        os.makedirs(file_path + "/dataset/"+ keyword)
        print('新建文件夹:{}/dataset/{}'.format(file_path, keyword))
    else:
        print('文件夹:{}/dataset/{}已经存在,之后将爬取的图片保存到该文件夹中'.format(file_path, keyword))


    count = 1

    with tqdm(total=DOWNLOAD_NUM, position=0, leave=True) as pbar:
        # 爬取第几张
        num = 1
        # 是否继续爬取
        FLAG = True
        while FLAG:
            page = 30 * count
            params = (
                ('tn', 'resultjson_com'),
                ('logid', '12508239107856075440'),
                ('ipn', 'rj'),
                ('ct', '201326592'),
                ('is', ''),
                ('fp', 'result'),
                ('fr', ''),
                ('word', f'{keyword}'),
                ('queryWord', f'{keyword}'),
                ('cl', '2'),
                ('lm', '-1'),
                ('ie', 'utf-8'),
                ('oe', 'utf-8'),
                ('adpicid', ''),
                ('st', '-1'),
                ('z', ''),
                ('ic', ''),
                ('hd', ''),
                ('latest', ''),
                ('copyright', ''),
                ('s', ''),
                ('se', ''),
                ('tab', ''),
                ('width', ''),
                ('height', ''),
                ('face', '0'),
                ('istype', '2'),
                ('qc', ''),
                ('nc', '1'),
                ('expermode', ''),
                ('nojc', ''),
                ('isAsync', ''),
                ('pn', f'{page}'),
                ('rn', '30'),
                ('gsm', '1e'),
                ('1647838001666', ''),
            )

            response = requests.get('https://image.baidu.com/search/acjson', headers=headers, params=params,
                                    cookies=cookies)
            if response.status_code == 200:
                try:
                    json_data = response.json().get("data")

                    if json_data:
                        for x in json_data:
                            type = x.get("type")
                            if type not in ["gif"]:
                                img = x.get("thumbURL")
                                fromPageTitleEnc = x.get("fromPageTitleEnc")
                                try:
                                    resp = requests.get(url=img, verify=False)
                                    time.sleep(1)
                                    # print(f"链接 {img}")

                                    # 保存文件名
                                    # file_save_path = f'dataset/{keyword}/{num}-{fromPageTitleEnc}.{type}'
                                    file_save_path = file_path + f'/dataset/{keyword}/{num}.{type}'
                                    with open(file_save_path, 'wb') as f:
                                        f.write(resp.content)
                                        f.flush()
                                        # print('第 {} 张图像 {} 爬取完成'.format(num, fromPageTitleEnc))
                                        num += 1
                                        pbar.update(1)  # 进度条更新

                                    # 爬取数量达到要求
                                    if num > DOWNLOAD_NUM:
                                        FLAG = False
                                        print('{} 张图像爬取完毕'.format(num - 1))
                                        break

                                except Exception:
                                    pass
                except:
                    pass
            else:
                break

            count += 1
# 测试爬取香蕉
file_path = 'D:pythonkagglepictures_classfication_data'
download_single_class(file_path,'香蕉', DOWNLOAD_NUM=2)
# 爬取多类水果
class_list = ['苹果','梨','葡萄','火龙果','大枣',
              '柑橘','柚子','桃','杏','西瓜',
              '荔枝','甘蔗','柿子','羊角蜜','香蕉',
              '菠萝','芒果','哈密瓜','石榴','椰子'
            ]

for class_name in class_list:
    download_single_class(file_path,class_name)
新建文件夹:D:pythonkagglepictures_classfication_data/dataset/苹果


100%|██████████| 100/100 [03:17<00:00,  1.98s/it]


100 张图像爬取完毕

新建文件夹:D:pythonkagglepictures_classfication_data/dataset/椰子


100%|██████████| 100/100 [02:57<00:00,  1.77s/it]

100 张图像爬取完毕

3.2 划分为训练集和测试集

file_path = file_path + '/dataset'

classes = os.listdir(file_path)

# 创建 train 文件夹
os.mkdir(os.path.join(file_path, 'train'))

# 创建 test 文件夹
os.mkdir(os.path.join(file_path, 'val'))

# 在 train 和 test 文件夹中创建各类别子文件夹
for fruit in classes:
    os.mkdir(os.path.join(file_path, 'train', fruit))
    os.mkdir(os.path.join(file_path, 'val', fruit))
test_frac = 0.2  # 测试集比例
random.seed(123) # 随机数种子,便于复现
df = pd.DataFrame()

print('{:^18} {:^18} {:^18}'.format('类别', '训练集数据个数', '测试集数据个数'))

for fruit in classes: # 遍历每个类别

    # 读取该类别的所有图像文件名
    old_dir = os.path.join(file_path, fruit)
    images_filename = os.listdir(old_dir)
    random.shuffle(images_filename) # 随机打乱


    # 划分训练集和测试集
    testset_numer = int(len(images_filename) * test_frac) # 测试集图像个数
    testset_images = images_filename[:testset_numer]      # 获取拟移动至 test 目录的测试集图像文件名
    trainset_images = images_filename[testset_numer:]     # 获取拟移动至 train 目录的训练集图像文件名

    # 移动图像至 test 目录
    for image in testset_images:
        old_img_path = os.path.join(file_path, fruit, image)         # 获取原始文件路径
        new_test_path = os.path.join(file_path, 'val', fruit, image) # 获取 test 目录的新文件路径
        shutil.move(old_img_path, new_test_path) # 移动文件

    # 移动图像至 train 目录
    for image in trainset_images:
        old_img_path = os.path.join(file_path, fruit, image)           # 获取原始文件路径
        new_train_path = os.path.join(file_path, 'train', fruit, image) # 获取 train 目录的新文件路径
        shutil.move(old_img_path, new_train_path) # 移动文件

     # 删除旧文件夹
    assert len(os.listdir(old_dir)) == 0 # 确保旧文件夹中的所有图像都被移动走
    shutil.rmtree(old_dir) # 删除文件夹

    # 输出每一类别的数据个数
    print('{:^18} {:^18} {:^18}'.format(fruit, len(trainset_images), len(testset_images)))
    # 保存到表格中
    df = df.append({'class':fruit, 'trainset':len(trainset_images), 'testset':len(testset_images)}, ignore_index=True)
        类别              训练集数据个数            测试集数据个数      
# 数据集各类别数量统计表格,导出为 csv 文件
df['total'] = df['trainset'] + df['testset']
df.head()
df.to_csv('数据量统计.csv', index=False)
classtestsettrainsettotal
0哈密瓜20.080.0100.0
1大枣20.080.0100.0
220.080.0100.0
3柑橘20.080.0100.0
4柚子20.080.0100.0

3.3 可视化图像文件

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

#读取图像,解决imread不能读取中文路径路径的问题
def cv_imread(file_path):
    cv_img = cv2.imdecode(np.fromfile(file_path,dtype=np.uint8),-1)
    return cv_img
# 读取训练集【西瓜】文件夹所有的图像
folder_path = os.path.join(file_path ,'train' ,  '西瓜')

images = []
for each_img in os.listdir(folder_path):
    img_path = os.path.join(folder_path, each_img)
    img_bgr = cv_imread(img_path)
    img_rgb = cv2.cvtColor(img_bgr,cv2.COLOR_BGR2RGB)
    images.append(img_rgb)

show_images([images[i] for i in range(32)],num_rows=4, num_cols=8, scale=1.0)

在这里插入图片描述

3.4 利用ResNet-18迁移学习,训练数据集

'''
将dataset水果分类打成zip压缩包,上传的linux机器上,用GPU训练

Linux下的默认编码是UTF8,Windows下生成的zip文件中的编码是GBK/GB2312等.zip文件

在Linux下解压时出现乱码问题.执行一下命令:
unzip -O GB18030 dataset.zip
'''
file_path = '/root/autodl-fs/data/fruit20/dataset'
def try_gpu(i=0):
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')

'''
1、图像预处理
'''
from torchvision import transforms

# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                     ])

# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
                                    ])
'''
2、载入水果图像分类数据集
'''
train_path = os.path.join(file_path, 'train')
test_path = os.path.join(file_path, 'val')
print('训练集路径', train_path)
print('测试集路径', test_path)

from torchvision import datasets
# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)
# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)

print('训练集图像数量', len(train_dataset))
print('类别个数', len(train_dataset.classes))
print('各类别名称', train_dataset.classes)
print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)
训练集路径 /root/autodl-fs/data/fruit20/dataset/train
测试集路径 /root/autodl-fs/data/fruit20/dataset/val
训练集图像数量 1600
类别个数 20
各类别名称 ['哈密瓜', '大枣', '杏', '柑橘', '柚子', '柿子', '桃', '梨', '椰子', '火龙果', '甘蔗', '石榴', '羊角蜜', '芒果', '苹果', '荔枝', '菠萝', '葡萄', '西瓜', '香蕉']
测试集图像数量 400
类别个数 20
各类别名称 ['哈密瓜', '大枣', '杏', '柑橘', '柚子', '柿子', '桃', '梨', '椰子', '火龙果', '甘蔗', '石榴', '羊角蜜', '芒果', '苹果', '荔枝', '菠萝', '葡萄', '西瓜', '香蕉']
'''
3、类别索引  映射字典
'''
# 各类别名称
class_names = train_dataset.classes
n_class = len(class_names)
# 映射关系:类别 到 索引号
train_dataset.class_to_idx
# 映射关系:索引号 到 类别
idx_to_labels = {y:x for x,y in train_dataset.class_to_idx.items()}
idx_to_labels
{0: '哈密瓜',
 1: '大枣',
 2: '杏',
 3: '柑橘',
 4: '柚子',
 5: '柿子',
 6: '桃',
 7: '梨',
 8: '椰子',
 9: '火龙果',
 10: '甘蔗',
 11: '石榴',
 12: '羊角蜜',
 13: '芒果',
 14: '苹果',
 15: '荔枝',
 16: '菠萝',
 17: '葡萄',
 18: '西瓜',
 19: '香蕉'}
# 保存为本地的 npy 文件
np.save('idx_to_labels.npy', idx_to_labels)
np.save('labels_to_idx.npy', train_dataset.class_to_idx)
'''
4、加载数据集
'''
from torch.utils.data import DataLoader

BATCH_SIZE = 32

# 训练集的数据加载器
train_iter = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=0
                         )

# 测试集的数据加载器
test_iter = DataLoader(test_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=0
                        )
'''
5、微调最后一层,创建resnet-18模型
'''
net = torchvision.models.resnet18(pretrained=True) # 载入预训练模型

# 修改全连接层,使得全连接层的输出与当前数据集类别数对应
# 新建的层默认 requires_grad=True
net.fc = nn.Linear(net.fc.in_features, n_class)
'''
6、模型训练
'''

import torch.nn as nn
from AccumulatorClass import Accumulator

def accuracy(y_hat, y):
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
         y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())



def evaluate_accuracy_gpu(net, data_iter, device=None):
    """使⽤GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        net.eval() # 设置为评估模式
        if not device:
            device = next(iter(net.parameters())).device
    # 正确预测的数量,总预测的数量
    metric = Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # BERT微调所需的
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]
from AnimatorClass import Animator
from TimerClass import Timer


def train_ch(net, train_iter, test_iter, num_epochs, lr, device):
    """⽤GPU训练模型"""
    print('training on', device)

    net.to(device)
    optimizer = torch.optim.SGD(net.fc.parameters(), lr=lr)
    
    # 只微调训练最后一层全连接层的参数,其它层冻结
    # optimizer = torch.optim.Adam(net.fc.parameters())
    # 学习率降低策略
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    # 交叉熵损失
    loss = nn.CrossEntropyLoss()

    animator = Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = Timer(), len(train_iter)
    num_batches = len(train_iter)
    best_test_accuracy = 0.0

    for epoch in range(num_epochs):
        # 训练损失之和,训练准确率之和,样本数
        metric = Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            # lr_scheduler.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))

        test_acc = evaluate_accuracy_gpu(net, test_iter)
        
        if test_acc > best_test_accuracy:
            # 删除旧的最佳模型文件(如有)
            old_best_checkpoint_path = 'checkpoint/best-{:.3f}.pth'.format(best_test_accuracy)
            if os.path.exists(old_best_checkpoint_path):
                os.remove(old_best_checkpoint_path)
            # 保存新的最佳模型文件
            new_best_checkpoint_path = 'checkpoint/best-{:.3f}.pth'.format(test_acc)
            torch.save(net, new_best_checkpoint_path)
            print('保存新的最佳模型', 'checkpoint/best-{:.3f}.pth'.format(test_acc))
            best_test_accuracy = test_acc
            
        
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'best_test_accuracy = {best_test_accuracy:.3f}')    
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')
# 初始化模型
lr, num_epochs = 0.1, 10
train_ch(net, train_iter, test_iter, num_epochs, lr, try_gpu())
best_test_accuracy = 0.840
loss 0.546, train acc 0.832, test acc 0.810
565.2 examples/sec on cuda:0

在这里插入图片描述

3.5 模型预测

best_test_accuracy = 0.840

# 载入最佳模型作为当前模型
net = torch.load('checkpoint/best-{:.3f}.pth'.format(best_test_accuracy))
net.to(try_gpu())



test_iter = DataLoader(test_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=True,
                         num_workers=0
                        )


def get_fruit_labels(labels):
    """返回fruit20数据集的⽂本标签"""
    text_labels = test_dataset.classes
    return [text_labels[int(i)] for i in labels]


def predict(net,test_iter, n=10):

    for X,y in test_iter:
        trues = get_fruit_labels(y[0:n])

        outputs = net(X.to(try_gpu())) # 输入模型,执行前向预测
        _, preds = torch.max(outputs, 1)
    
        preds = get_fruit_labels(
        preds.cpu().numpy()[0:n]
    )

        print('trues:',trues)
        print('preds:',preds)
        break


predict(net,test_iter)
trues: ['菠萝', '梨', '柿子', '大枣', '芒果', '菠萝', '芒果', '苹果', '桃', '菠萝']
preds: ['菠萝', '梨', '柿子', '大枣', '芒果', '菠萝', '芒果', '苹果', '桃', '菠萝']
风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。