您现在的位置是:首页 >其他 >Pytorch 模型集成(Model Ensembling)网站首页其他

Pytorch 模型集成(Model Ensembling)

Yuetianw 2023-07-02 12:00:02
简介Pytorch 模型集成(Model Ensembling)

Pytorch 模型集成(Model Ensembling)

这篇文章介绍如何使用torch.vmap对模型集成进行向量化。

模型集成将多个模型的预测结果组合在一起。传统上,这是通过分别在某些输入上运行每个模型,然后组合预测结果来完成的。但是,如果您正在运行具有相同架构的模型,则可以使用torch.vmap将它们组合在一起。vmap是一个函数变换,它将函数映射到输入张量的维度上。其中一个用例是通过向量化消除for循环并加速它们。

让我们使用简单MLP的集合来演示如何做到这一点。

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)

# Here's a simple MLP
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.flatten(1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

让我们生成一批虚拟数据,并假装我们正在使用MNIST数据集。因此,虚拟图像为28x28,并且我们有一个大小为64的小批量。此外,假设我们要组合来自10个不同模型的预测结果。

device = 'cuda'
num_models = 10

data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)

models = [SimpleMLP().to(device) for _ in range(num_models)]

我们有几个选项可用于生成预测结果。也许我们想为每个模型提供不同的随机小批量数据。或者,也许我们想通过每个模型运行相同的小批量数据来组合预测结果(例如,如果我们正在测试不同模型初始化的效果)。

#1
minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]

#2
minibatch = data[0]
predictions2 = [model(minibatch) for model in models]

让我们使用vmap加速for循环。我们必须首先准备好使用vmap的模型。

首先,让我们通过堆叠每个参数来将模型的状态组合在一起。例如,model [i] .fc1.weight具有形状[784,128];我们将堆叠10个模型的.fc1.weight以产生形状为[10,784,128]的大权重。

PyTorch提供了torch.func.stack_module_state方便函数来执行此操作。

from torch.func import stack_module_state

params, buffers = stack_module_state(models)

接下来,我们需要定义一个要vmap的函数。该函数应该在给定参数、缓冲区和输入的情况下使用这些参数、缓冲区和输入运行模型。我们将使用torch.func.functional_call来帮助:

from torch.func import functional_call
import copy

# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

选项1:使用每个模型的不同小批量获取预测结果。

默认情况下,vmap将函数映射到传递给该函数的所有输入的第一个维度上。在使用stack_module_state之后,每个params和buffers都具有额外的大小为“num_models”的维度,并且minibatches具有大小为“num_models”的维度。

print([p.size(0) for p in params.values()]) # show the leading 'num_models' dimension

assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'

from torch import vmap

predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)

# verify the vmap predictions match the
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)

选项2:使用相同的小批量数据获取预测结果。

vmap具有in_dims arg,指定要映射的维度。通过使用None,我们告诉vmap希望应用于所有10个模型的相同小批量。

predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)

快速说明:关于哪些类型的函数可以被vmap转换存在限制。最好转换的功能是纯函数:输出仅由具有没有副作用(例如突变)的输入确定的函数。vmap无法处理任意Python数据结构的突变,但它能够处理许多原地PyTorch操作。

实际表现

from torch.utils.benchmark import Timer
without_vmap = Timer(
    stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
    globals=globals())
with_vmap = Timer(
    stmt="vmap(fmodel)(params, buffers, minibatches)",
    globals=globals())
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')
Predictions without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f655e970970>
[model(minibatch) for model, minibatch in zip(models, minibatches)]
  1.53 ms
  1 measurement, 100 runs , 1 thread
Predictions with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f655e971ae0>
vmap(fmodel)(params, buffers, minibatches)
  631.15 us
  1 measurement, 100 runs , 1 thread

使用vmap会有很大的速度提升!

一般来说,用vmap进行的矢量化应该比在for-loop中运行一个函数要快,而且与手动批处理有竞争力。不过也有一些例外,比如我们没有为某个特定的操作实现vmap规则,或者底层内核没有针对旧硬件(GPU)进行优化。

风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。