您现在的位置是:首页 >技术杂谈 >nn.conv1d的输入问题网站首页技术杂谈

nn.conv1d的输入问题

CG大魔王 2023-07-02 12:00:02
简介nn.conv1d的输入问题

Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

  • in_channels(int) – 输入信号的通道。在文本分类中,即为词向量的维度
  • out_channels(int) – 卷积产生的通道。有多少个out_channels,就需要多少个1维卷积
  • kernel_size(int or tuple) - 卷积核的尺寸,卷积核的大小为(k,),第二个维度是由in_channels来决定的,所以实际上卷积大小为kernel_size*in_channels
  • stride(int or tuple, optional) - 卷积步长
  • padding (int or tuple, optional)- 输入的每一条边补充0的层数
  • dilation(int or tuple, `optional``) – 卷积核元素之间的间距
  • groups(int, optional) – 从输入通道到输出通道的阻塞连接数
  • bias(bool, optional) - 如果bias=True,添加偏置

在一维卷积中,卷积核的尺寸实际上等于:

kernel_size* in_channels

这里可以认为in_channels即词向量的长度(固定长度),而kernel_size决定卷积核沿着语句长度(或者说时间序列长度)移动的大小。

一般来说我们自己的数据格式为(b,s,e),b是batch_size,s是时间序列sequence,e是词向量embeding。

但根据实验,conv1d貌似是在最后一维扫描我们的时间序列,所以需要对我们的数据格式进行修改。

# data[b,s,e]
data = data.permute(0,2,1)  # data[b,s,e] -> data[b,e,s]

风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。