[Model Complexity] Using torchsummary, torchstat, and profile
Complexity analysis is an important yardstick when comparing models; it covers parameter count, floating-point operations (FLOPs), memory footprint, memory read/write traffic, and more. This post records a few tools for measuring model complexity.
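Before reaching for any dedicated tool, the raw parameter count can be read straight out of PyTorch itself; a minimal cross-check sketch (any nn.Module works in place of resnet18):

import torchvision.models as models

model = models.resnet18(pretrained=True)

# Total and trainable parameter counts, straight from the module itself;
# a handy cross-check against whatever a profiling tool reports
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total, trainable)  # 11689512 11689512 for resnet18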
1. torchsummary
torchsummary reports the model's total parameter count and the parameters of each layer, but it cannot compute FLOPs.
Taking resnet18 as an example:
import torchsummary
import torchvision.models as models

# Load a pretrained ResNet-18 and print a per-layer summary on the CPU
model = models.resnet18(pretrained=True)
torchsummary.summary(model, (3, 224, 224), device='cpu')
It can equivalently be run on the GPU (torchsummary's device argument defaults to 'cuda'; the input size never includes the batch dimension):
torchsummary.summary(model.cuda(), (3, 224, 224))
Output:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
Conv2d-5 [-1, 64, 56, 56] 36,864
BatchNorm2d-6 [-1, 64, 56, 56] 128
ReLU-7 [-1, 64, 56, 56] 0
Conv2d-8 [-1, 64, 56, 56] 36,864
BatchNorm2d-9 [-1, 64, 56, 56] 128
ReLU-10 [-1, 64, 56, 56] 0
BasicBlock-11 [-1, 64, 56, 56] 0
Conv2d-12 [-1, 64, 56, 56] 36,864
BatchNorm2d-13 [-1, 64, 56, 56] 128
ReLU-14 [-1, 64, 56, 56] 0
Conv2d-15 [-1, 64, 56, 56] 36,864
BatchNorm2d-16 [-1, 64, 56, 56] 128
ReLU-17 [-1, 64, 56, 56] 0
BasicBlock-18 [-1, 64, 56, 56] 0
Conv2d-19 [-1, 128, 28, 28] 73,728
BatchNorm2d-20 [-1, 128, 28, 28] 256
ReLU-21 [-1, 128, 28, 28] 0
Conv2d-22 [-1, 128, 28, 28] 147,456
BatchNorm2d-23 [-1, 128, 28, 28] 256
Conv2d-24 [-1, 128, 28, 28] 8,192
BatchNorm2d-25 [-1, 128, 28, 28] 256
ReLU-26 [-1, 128, 28, 28] 0
BasicBlock-27 [-1, 128, 28, 28] 0
Conv2d-28 [-1, 128, 28, 28] 147,456
BatchNorm2d-29 [-1, 128, 28, 28] 256
ReLU-30 [-1, 128, 28, 28] 0
Conv2d-31 [-1, 128, 28, 28] 147,456
BatchNorm2d-32 [-1, 128, 28, 28] 256
ReLU-33 [-1, 128, 28, 28] 0
BasicBlock-34 [-1, 128, 28, 28] 0
Conv2d-35 [-1, 256, 14, 14] 294,912
BatchNorm2d-36 [-1, 256, 14, 14] 512
ReLU-37 [-1, 256, 14, 14] 0
Conv2d-38 [-1, 256, 14, 14] 589,824
BatchNorm2d-39 [-1, 256, 14, 14] 512
Conv2d-40 [-1, 256, 14, 14] 32,768
BatchNorm2d-41 [-1, 256, 14, 14] 512
ReLU-42 [-1, 256, 14, 14] 0
BasicBlock-43 [-1, 256, 14, 14] 0
Conv2d-44 [-1, 256, 14, 14] 589,824
BatchNorm2d-45 [-1, 256, 14, 14] 512
ReLU-46 [-1, 256, 14, 14] 0
Conv2d-47 [-1, 256, 14, 14] 589,824
BatchNorm2d-48 [-1, 256, 14, 14] 512
ReLU-49 [-1, 256, 14, 14] 0
BasicBlock-50 [-1, 256, 14, 14] 0
Conv2d-51 [-1, 512, 7, 7] 1,179,648
BatchNorm2d-52 [-1, 512, 7, 7] 1,024
ReLU-53 [-1, 512, 7, 7] 0
Conv2d-54 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-55 [-1, 512, 7, 7] 1,024
Conv2d-56 [-1, 512, 7, 7] 131,072
BatchNorm2d-57 [-1, 512, 7, 7] 1,024
ReLU-58 [-1, 512, 7, 7] 0
BasicBlock-59 [-1, 512, 7, 7] 0
Conv2d-60 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-61 [-1, 512, 7, 7] 1,024
ReLU-62 [-1, 512, 7, 7] 0
Conv2d-63 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-64 [-1, 512, 7, 7] 1,024
ReLU-65 [-1, 512, 7, 7] 0
BasicBlock-66 [-1, 512, 7, 7] 0
AdaptiveAvgPool2d-67 [-1, 512, 1, 1] 0
Linear-68 [-1, 1000] 513,000
================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 62.79
Params size (MB): 44.59
Estimated Total Size (MB): 107.96
----------------------------------------------------------------
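The size lines at the bottom follow from simple arithmetic, assuming float32 tensors (4 bytes per element) and binary megabytes; a quick check against the resnet18 numbers above:

# 1 MB = 1024**2 bytes; every element is a 4-byte float32
print(3 * 224 * 224 * 4 / 1024 ** 2)   # ~0.57, matches "Input size (MB)"
print(11689512 * 4 / 1024 ** 2)        # ~44.59, matches "Params size (MB)"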
The same works for a Vision Transformer:
import torchsummary
import timm

# Summarize a pretrained ViT from timm in the same way
model = timm.create_model('vit_small_patch16_224', pretrained=True)
torchsummary.summary(model.cuda(), (3, 224, 224))
Output:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 768, 14, 14] 590,592
Identity-2 [-1, 196, 768] 0
PatchEmbed-3 [-1, 196, 768] 0
Dropout-4 [-1, 197, 768] 0
LayerNorm-5 [-1, 197, 768] 1,536
Linear-6 [-1, 197, 2304] 1,769,472
Dropout-7 [-1, 8, 197, 197] 0
Linear-8 [-1, 197, 768] 590,592
Dropout-9 [-1, 197, 768] 0
Attention-10 [-1, 197, 768] 0
Identity-11 [-1, 197, 768] 0
LayerNorm-12 [-1, 197, 768] 1,536
Linear-13 [-1, 197, 2304] 1,771,776
GELU-14 [-1, 197, 2304] 0
Dropout-15 [-1, 197, 2304] 0
Linear-16 [-1, 197, 768] 1,770,240
Dropout-17 [-1, 197, 768] 0
Mlp-18 [-1, 197, 768] 0
Identity-19 [-1, 197, 768] 0
Block-20 [-1, 197, 768] 0
LayerNorm-21 [-1, 197, 768] 1,536
Linear-22 [-1, 197, 2304] 1,769,472
Dropout-23 [-1, 8, 197, 197] 0
Linear-24 [-1, 197, 768] 590,592
Dropout-25 [-1, 197, 768] 0
Attention-26 [-1, 197, 768] 0
Identity-27 [-1, 197, 768] 0
LayerNorm-28 [-1, 197, 768] 1,536
Linear-29 [-1, 197, 2304] 1,771,776
GELU-30 [-1, 197, 2304] 0
Dropout-31 [-1, 197, 2304] 0
Linear-32 [-1, 197, 768] 1,770,240
Dropout-33 [-1, 197, 768] 0
Mlp-34 [-1, 197, 768] 0
Identity-35 [-1, 197, 768] 0
Block-36 [-1, 197, 768] 0
LayerNorm-37 [-1, 197, 768] 1,536
Linear-38 [-1, 197, 2304] 1,769,472
Dropout-39 [-1, 8, 197, 197] 0
Linear-40 [-1, 197, 768] 590,592
Dropout-41 [-1, 197, 768] 0
Attention-42 [-1, 197, 768] 0
Identity-43 [-1, 197, 768] 0
LayerNorm-44 [-1, 197, 768] 1,536
Linear-45 [-1, 197, 2304] 1,771,776
GELU-46 [-1, 197, 2304] 0
Dropout-47 [-1, 197, 2304] 0
Linear-48 [-1, 197, 768] 1,770,240
Dropout-49 [-1, 197, 768] 0
Mlp-50 [-1, 197, 768] 0
Identity-51 [-1, 197, 768] 0
Block-52 [-1, 197, 768] 0
LayerNorm-53 [-1, 197, 768] 1,536
Linear-54 [-1, 197, 2304] 1,769,472
Dropout-55 [-1, 8, 197, 197] 0
Linear-56 [-1, 197, 768] 590,592
Dropout-57 [-1, 197, 768] 0
Attention-58 [-1, 197, 768] 0
Identity-59 [-1, 197, 768] 0
LayerNorm-60 [-1, 197, 768] 1,536
Linear-61 [-1, 197, 2304] 1,771,776
GELU-62 [-1, 197, 2304] 0
Dropout-63 [-1, 197, 2304] 0
Linear-64 [-1, 197, 768] 1,770,240
Dropout-65 [-1, 197, 768] 0
Mlp-66 [-1, 197, 768] 0
Identity-67 [-1, 197, 768] 0
Block-68 [-1, 197, 768] 0
LayerNorm-69 [-1, 197, 768] 1,536
Linear-70 [-1, 197, 2304] 1,769,472
Dropout-71 [-1, 8, 197, 197] 0
Linear-72 [-1, 197, 768] 590,592
Dropout-73 [-1, 197, 768] 0
Attention-74 [-1, 197, 768] 0
Identity-75 [-1, 197, 768] 0
LayerNorm-76 [-1, 197, 768] 1,536
Linear-77 [-1, 197, 2304] 1,771,776
GELU-78 [-1, 197, 2304] 0
Dropout-79 [-1, 197, 2304] 0
Linear-80 [-1, 197, 768] 1,770,240
Dropout-81 [-1, 197, 768] 0
Mlp-82 [-1, 197, 768] 0
Identity-83 [-1, 197, 768] 0
Block-84 [-1, 197, 768] 0
LayerNorm-85 [-1, 197, 768] 1,536
Linear-86 [-1, 197, 2304] 1,769,472
Dropout-87 [-1, 8, 197, 197] 0
Linear-88 [-1, 197, 768] 590,592
Dropout-89 [-1, 197, 768] 0
Attention-90 [-1, 197, 768] 0
Identity-91 [-1, 197, 768] 0
LayerNorm-92 [-1, 197, 768] 1,536
Linear-93 [-1, 197, 2304] 1,771,776
GELU-94 [-1, 197, 2304] 0
Dropout-95 [-1, 197, 2304] 0
Linear-96 [-1, 197, 768] 1,770,240
Dropout-97 [-1, 197, 768] 0
Mlp-98 [-1, 197, 768] 0
Identity-99 [-1, 197, 768] 0
Block-100 [-1, 197, 768] 0
LayerNorm-101 [-1, 197, 768] 1,536
Linear-102 [-1, 197, 2304] 1,769,472
Dropout-103 [-1, 8, 197, 197] 0
Linear-104 [-1, 197, 768] 590,592
Dropout-105 [-1, 197, 768] 0
Attention-106 [-1, 197, 768] 0
Identity-107 [-1, 197, 768] 0
LayerNorm-108 [-1, 197, 768] 1,536
Linear-109 [-1, 197, 2304] 1,771,776
GELU-110 [-1, 197, 2304] 0
Dropout-111 [-1, 197, 2304] 0
Linear-112 [-1, 197, 768] 1,770,240
Dropout-113 [-1, 197, 768] 0
Mlp-114 [-1, 197, 768] 0
Identity-115 [-1, 197, 768] 0
Block-116 [-1, 197, 768] 0
LayerNorm-117 [-1, 197, 768] 1,536
Linear-118 [-1, 197, 2304] 1,769,472
Dropout-119 [-1, 8, 197, 197] 0
Linear-120 [-1, 197, 768] 590,592
Dropout-121 [-1, 197, 768] 0
Attention-122 [-1, 197, 768] 0
Identity-123 [-1, 197, 768] 0
LayerNorm-124 [-1, 197, 768] 1,536
Linear-125 [-1, 197, 2304] 1,771,776
GELU-126 [-1, 197, 2304] 0
Dropout-127 [-1, 197, 2304] 0
Linear-128 [-1, 197, 768] 1,770,240
Dropout-129 [-1, 197, 768] 0
Mlp-130 [-1, 197, 768] 0
Identity-131 [-1, 197, 768] 0
Block-132 [-1, 197, 768] 0
LayerNorm-133 [-1, 197, 768] 1,536
Identity-134 [-1, 768] 0
Linear-135 [-1, 1000] 769,000
================================================================
Total params: 48,602,344
Trainable params: 48,602,344
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 237.11
Params size (MB): 185.40
Estimated Total Size (MB): 423.09
----------------------------------------------------------------
2. torchstat
torchstat reports model parameters (params), floating-point operations (FLOPs), memory footprint (memory) and memory read/write traffic (MemR+W); its selling point is covering every angle.
from torchstat import stat
import torchvision.models as models

# torchstat runs on the CPU; stat() prints a per-layer table plus totals
model = models.resnet18(pretrained=True)
stat(model.to('cpu'), (3, 224, 224))
Output:
module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)
0 conv1 3 224 224 64 112 112 9408.0 3.06 235,225,088.0 118,013,952.0 639744.0 3211264.0 7.48% 3851008.0
1 bn1 64 112 112 64 112 112 128.0 3.06 3,211,264.0 1,605,632.0 3211776.0 3211264.0 6.54% 6423040.0
2 relu 64 112 112 64 112 112 0.0 3.06 802,816.0 802,816.0 3211264.0 3211264.0 0.53% 6422528.0
3 maxpool 64 112 112 64 56 56 0.0 0.77 1,605,632.0 802,816.0 3211264.0 802816.0 5.68% 4014080.0
4 layer1.0.conv1 64 56 56 64 56 56 36864.0 0.77 231,010,304.0 115,605,504.0 950272.0 802816.0 5.85% 1753088.0
5 layer1.0.bn1 64 56 56 64 56 56 128.0 0.77 802,816.0 401,408.0 803328.0 802816.0 1.96% 1606144.0
6 layer1.0.relu 64 56 56 64 56 56 0.0 0.77 200,704.0 200,704.0 802816.0 802816.0 0.08% 1605632.0
7 layer1.0.conv2 64 56 56 64 56 56 36864.0 0.77 231,010,304.0 115,605,504.0 950272.0 802816.0 2.34% 1753088.0
8 layer1.0.bn2 64 56 56 64 56 56 128.0 0.77 802,816.0 401,408.0 803328.0 802816.0 2.01% 1606144.0
9 layer1.1.conv1 64 56 56 64 56 56 36864.0 0.77 231,010,304.0 115,605,504.0 950272.0 802816.0 2.46% 1753088.0
10 layer1.1.bn1 64 56 56 64 56 56 128.0 0.77 802,816.0 401,408.0 803328.0 802816.0 2.04% 1606144.0
11 layer1.1.relu 64 56 56 64 56 56 0.0 0.77 200,704.0 200,704.0 802816.0 802816.0 0.07% 1605632.0
12 layer1.1.conv2 64 56 56 64 56 56 36864.0 0.77 231,010,304.0 115,605,504.0 950272.0 802816.0 2.45% 1753088.0
13 layer1.1.bn2 64 56 56 64 56 56 128.0 0.77 802,816.0 401,408.0 803328.0 802816.0 2.09% 1606144.0
14 layer2.0.conv1 64 56 56 128 28 28 73728.0 0.38 115,505,152.0 57,802,752.0 1097728.0 401408.0 6.13% 1499136.0
15 layer2.0.bn1 128 28 28 128 28 28 256.0 0.38 401,408.0 200,704.0 402432.0 401408.0 0.29% 803840.0
16 layer2.0.relu 128 28 28 128 28 28 0.0 0.38 100,352.0 100,352.0 401408.0 401408.0 0.11% 802816.0
17 layer2.0.conv2 128 28 28 128 28 28 147456.0 0.38 231,110,656.0 115,605,504.0 991232.0 401408.0 3.17% 1392640.0
18 layer2.0.bn2 128 28 28 128 28 28 256.0 0.38 401,408.0 200,704.0 402432.0 401408.0 0.29% 803840.0
19 layer2.0.downsample.0 64 56 56 128 28 28 8192.0 0.38 12,744,704.0 6,422,528.0 835584.0 401408.0 3.46% 1236992.0
20 layer2.0.downsample.1 128 28 28 128 28 28 256.0 0.38 401,408.0 200,704.0 402432.0 401408.0 0.31% 803840.0
21 layer2.1.conv1 128 28 28 128 28 28 147456.0 0.38 231,110,656.0 115,605,504.0 991232.0 401408.0 1.70% 1392640.0
22 layer2.1.bn1 128 28 28 128 28 28 256.0 0.38 401,408.0 200,704.0 402432.0 401408.0 0.31% 803840.0
23 layer2.1.relu 128 28 28 128 28 28 0.0 0.38 100,352.0 100,352.0 401408.0 401408.0 0.09% 802816.0
24 layer2.1.conv2 128 28 28 128 28 28 147456.0 0.38 231,110,656.0 115,605,504.0 991232.0 401408.0 1.78% 1392640.0
25 layer2.1.bn2 128 28 28 128 28 28 256.0 0.38 401,408.0 200,704.0 402432.0 401408.0 0.33% 803840.0
26 layer3.0.conv1 128 28 28 256 14 14 294912.0 0.19 115,555,328.0 57,802,752.0 1581056.0 200704.0 2.81% 1781760.0
27 layer3.0.bn1 256 14 14 256 14 14 512.0 0.19 200,704.0 100,352.0 202752.0 200704.0 0.28% 403456.0
28 layer3.0.relu 256 14 14 256 14 14 0.0 0.19 50,176.0 50,176.0 200704.0 200704.0 0.08% 401408.0
29 layer3.0.conv2 256 14 14 256 14 14 589824.0 0.19 231,160,832.0 115,605,504.0 2560000.0 200704.0 3.50% 2760704.0
30 layer3.0.bn2 256 14 14 256 14 14 512.0 0.19 200,704.0 100,352.0 202752.0 200704.0 0.27% 403456.0
31 layer3.0.downsample.0 128 28 28 256 14 14 32768.0 0.19 12,794,880.0 6,422,528.0 532480.0 200704.0 2.51% 733184.0
32 layer3.0.downsample.1 256 14 14 256 14 14 512.0 0.19 200,704.0 100,352.0 202752.0 200704.0 0.27% 403456.0
33 layer3.1.conv1 256 14 14 256 14 14 589824.0 0.19 231,160,832.0 115,605,504.0 2560000.0 200704.0 5.04% 2760704.0
34 layer3.1.bn1 256 14 14 256 14 14 512.0 0.19 200,704.0 100,352.0 202752.0 200704.0 0.30% 403456.0
35 layer3.1.relu 256 14 14 256 14 14 0.0 0.19 50,176.0 50,176.0 200704.0 200704.0 0.08% 401408.0
36 layer3.1.conv2 256 14 14 256 14 14 589824.0 0.19 231,160,832.0 115,605,504.0 2560000.0 200704.0 2.06% 2760704.0
37 layer3.1.bn2 256 14 14 256 14 14 512.0 0.19 200,704.0 100,352.0 202752.0 200704.0 0.28% 403456.0
38 layer4.0.conv1 256 14 14 512 7 7 1179648.0 0.10 115,580,416.0 57,802,752.0 4919296.0 100352.0 3.63% 5019648.0
39 layer4.0.bn1 512 7 7 512 7 7 1024.0 0.10 100,352.0 50,176.0 104448.0 100352.0 0.25% 204800.0
40 layer4.0.relu 512 7 7 512 7 7 0.0 0.10 25,088.0 25,088.0 100352.0 100352.0 0.06% 200704.0
41 layer4.0.conv2 512 7 7 512 7 7 2359296.0 0.10 231,185,920.0 115,605,504.0 9537536.0 100352.0 5.48% 9637888.0
42 layer4.0.bn2 512 7 7 512 7 7 1024.0 0.10 100,352.0 50,176.0 104448.0 100352.0 0.24% 204800.0
43 layer4.0.downsample.0 256 14 14 512 7 7 131072.0 0.10 12,819,968.0 6,422,528.0 724992.0 100352.0 2.82% 825344.0
44 layer4.0.downsample.1 512 7 7 512 7 7 1024.0 0.10 100,352.0 50,176.0 104448.0 100352.0 0.26% 204800.0
45 layer4.1.conv1 512 7 7 512 7 7 2359296.0 0.10 231,185,920.0 115,605,504.0 9537536.0 100352.0 3.18% 9637888.0
46 layer4.1.bn1 512 7 7 512 7 7 1024.0 0.10 100,352.0 50,176.0 104448.0 100352.0 0.24% 204800.0
47 layer4.1.relu 512 7 7 512 7 7 0.0 0.10 25,088.0 25,088.0 100352.0 100352.0 0.06% 200704.0
48 layer4.1.conv2 512 7 7 512 7 7 2359296.0 0.10 231,185,920.0 115,605,504.0 9537536.0 100352.0 4.86% 9637888.0
49 layer4.1.bn2 512 7 7 512 7 7 1024.0 0.10 100,352.0 50,176.0 104448.0 100352.0 0.23% 204800.0
50 avgpool 512 7 7 512 1 1 0.0 0.00 0.0 0.0 0.0 0.0 0.64% 0.0
51 fc 512 1000 513000.0 0.00 1,023,000.0 512,000.0 2054048.0 4000.0 1.02% 2058048.0
total 11689512.0 25.65 3,638,757,912.0 1,821,399,040.0 2054048.0 4000.0 100.00% 101756992.0
=================================================================================================================================================================
Total params: 11,689,512
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 25.65MB
Total MAdd: 3.64GMAdd
Total Flops: 1.82GFlops
Total MemR+W: 97.04MB
Here MAdd is the theoretical count of multiplies and adds in the network; numerically, the Flops column is roughly half of MAdd, since each multiply-accumulate is counted once in Flops but as two operations in MAdd.
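As a sanity check, the conv1 row can be reproduced by hand. The sketch below recomputes its Flops and MAdd columns for the 7x7, stride-2, 3-to-64-channel convolution (bias-free, as torchvision's ResNet convolutions are):

# conv1: 3 -> 64 channels, 7x7 kernel, 224x224 input -> 112x112 output
c_in, c_out, k, h_out, w_out = 3, 64, 7, 112, 112
macs = c_out * h_out * w_out * c_in * k * k   # one multiply-accumulate per weight per output element
madd = 2 * macs - c_out * h_out * w_out       # per output element, adds are one fewer than multiplies
print(macs)  # 118013952, matches the Flops column for conv1
print(madd)  # 235225088, matches the MAdd column for conv1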
Awkwardly, torchstat does not seem to work on Vision Transformer models: it throws an error, presumably triggered by the flattening of features, and even more awkwardly I do not know how to fix it.
3. profile
profile, from the thop package, reports model parameters (params) and floating-point operations (FLOPs).
import torch
from thop import profile
import torchvision.models as models

# profile() traces one forward pass on a dummy input and returns (flops, params)
model = models.resnet18(pretrained=True)
dummy_input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, (dummy_input,))
print('the flops is {}G, the params is {}M'.format(round(flops / (10 ** 9), 2), round(params / (10 ** 6), 2)))
Output:
the flops is 1.82G, the params is 11.69M
As expected, torchsummary, torchstat and profile all agree on the parameter count, and torchstat and profile agree on FLOPs (strictly speaking, both count one multiply-accumulate as one FLOP, so these figures are really MACs).
profile also works on Vision Transformer models; its selling point is leaving no model out:
import torch
from thop import profile
import timm

# The same measurement for a Vision Transformer from timm
model = timm.create_model('vit_small_patch16_224', pretrained=True)
dummy_input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, (dummy_input,))
print('the flops is {}G, the params is {}M'.format(round(flops / (10 ** 9), 2), round(params / (10 ** 6), 2)))
Output:
the flops is 9.42G, the params is 48.6M
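thop also ships a clever_format helper that turns the raw counts into human-readable strings; a small follow-on sketch, reusing flops and params from the snippet above:

from thop import clever_format

# Purely cosmetic: attach unit suffixes (G, M, K) to the raw counts
flops_str, params_str = clever_format([flops, params], '%.2f')
print(flops_str, params_str)  # e.g. 9.42G 48.60M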