您现在的位置是:首页 >技术教程 >深度学习上采样下采样概念以及实现网站首页技术教程
深度学习上采样下采样概念以及实现
#pic_center =400x
系列文章:
参考博客
torch.nn.functional.interpolate函数
概念
上采样
简单说将图片放大,通过在像素键插入数据
1.插值,一般使用的是双线性插值,因为效果最好,虽然计算上比其他插值方式复杂,但是相对于卷积计算可以说不值一提,其他插值方式还有最近邻插值、三线性插值等;
2.转置卷积又或是说反卷积(Transpose Conv),通过对输入feature map间隔填充0,再进行标准的卷积计算,可以使得输出feature map的尺寸比输入更大;相比上池化,使用反卷积进行图像的“上采样”是可以被学习的(会用到卷积操作,其参数是可学习的)。
下采样
简单说是将图片缩小
主要目的有两个:1、使得图像符合显示区域的大小;2、生成对应图像的缩略图;
1.用stride为2的卷积层实现:卷积过程导致的图像变小是为了提取特征。下采样的过程是一个信息损失的过程,而池化层是不可学习的,用stride为2的可学习卷积层来代替pooling可以得到更好的效果,当然同时也增加了一定的计算量。
2.用stride为2的池化层实现:池化下采样是为了降低特征的维度。如Max-pooling和Average-pooling,目前通常使用Max-pooling,因为他计算简单而且能够更好的保留纹理特征。
实现
完整的代码在DDIM/models/diffusion.py
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(
x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=0)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
上采样
interpolate是插值的意思
x = torch.nn.functional.interpolate(
x, scale_factor=2.0, mode="nearest")
def interpolate(
input: Any,
size: Any | None = …,
scale_factor: Any | None = …,
mode: str = …,
align_corners: Any | None = …,
recompute_scale_factor: Any | None = …,
antialias: bool = …) -> None
参数:
-
input (Tensor) – 输入张量。
-
size (int or Tuple*[int] or* Tuple*[int,* int] or Tuple*[int,* int, int]) –输出大小。
-
scale_factor (float or Tuple*[float]*) – 指定输出为输入的多少倍数。如果输入为tuple,其也要制定为tuple类型。
- 注: size 和scale_factor指定一个即可
-
mode (str) –
可使用的上采样算法,有'nearest', 'linear', 'bilinear', 'bicubic' , 'trilinear'和'area'. 默认使用'nearest'。
-
align_corners (bool, optional) –
几何上,我们认为输入和输出的像素是正方形,而不是点。
如果设置为True,则输入和输出张量由其角像素的中心点对齐,从而保留角像素处的值。如果设置为False,则输入和输出张量由它们的角像素的角点对齐,插值使用边界外值的边值填充;当scale_factor保持不变时
,使该操作独立于输入大小
下采样
# 池化方式
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
#卷积形式
x = np.random.randint(1,10, [1,5,5])
x = torch.FloatTensor(x)
print(x)
in_channels = 1
conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
print(conv)
x = conv(x)
print(x.shape)