1、Conv1d 定义
class torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

自然语言处理中一个句子序列,一维的,所以使用Conv1d,此时卷积核(没有batch_size,参数是共享的)除去chanel,也是一维的。 

2、参数 

in_channels(int) – 输入信号的通道。在文本分类中,即为词向量的维度
out_channels(int) – 卷积产生的通道。有多少个out_channels,就需要多少个1维卷积
kernel_size(int or tuple) – 卷积核的尺寸,卷积核的大小为(k,),第二个维度是由in_channels来决定的,所以实际上卷积大小为kernel_size*in_channels
stride(int or tuple, optional) – 卷积步长
padding (int or tuple, optional)- 输入的每一条边补充0的层数
dilation(int or tuple, `optional``) – 卷积核元素之间的间距
groups(int, optional) – 从输入通道到输出通道的阻塞连接数
bias(bool, optional) – 如果bias=True,添加偏置

3、shape:


4、代码例子:

官方例子1: 

input1 = torch.randn(20, 16, 50) # torch.Size([20, 16, 50])
m = nn.Conv1d(16, 33, 3, stride=2) # Conv1d(16, 33, kernel_size=(3,), stride=(2,))
output = m(input1) # torch.Size([20, 33, 24])

验证Shape中conv1d 关于输出Lout的公式:Lout =⌊50+2*0 – 1*(3-1) -1⌋/2 + 1 = 24

例子2:

import torch
import torch.nn as nn

# 卷积大小为kernel_size*in_channels, 此处也即 3 * 4, 每个卷积核产生一维的输出数据,长度与输入数据的长度和stride有关,根据ouotput可知是3,第二个参数2也就卷积核的数量
m = nn.Conv1d(4, 2, 3, stride=2)

# 第一个参数理解为batch的大小,输入是4 * 9格式 
input = torch.randn(1, 4, 9)
print(input)
output = m(input)
print(output)
print(output.size())

 输出如下:

tensor([[[-0.2105, -1.0958,  0.7299,  1.1003,  2.3175,  0.8186, -1.7510,  -0.1925,  0.8591],
         [ 1.0991, -0.3016,  1.5633,  0.6162,  0.3150,  1.0413,  1.0571,  -0.7014,  0.2239],
         [-0.0658,  0.4755, -0.6653, -0.0696,  0.3483, -0.0360, -0.4665,   1.2606,  1.3365],
         [-0.0186, -1.1802, -0.8835, -1.1813, -0.5145, -0.0534, -1.2568,   0.3211, -2.4793]]])

tensor([[[-0.8012,  0.0589,  0.1576, -0.8222],
         [-0.8231, -0.4233,  0.7178, -0.6621]]], grad_fn=<SqueezeBackward1>)

torch.Size([1, 2, 4])
 

第一个卷积核进行如下操作:

 得到输出1*4的输出:

[-0.8012, 0.0589, 0.1576, -0.8222]

第二个卷积核进行类似操作:

得到输出1*4的输出:

[-0.8231, -0.4233, 0.7178, -0.6621]

合并得到最后的2*4的结果:

 

输入的input为 4 * 9 ,输出为 2 * 4。

验证Shape中conv1d 关于输出Lout的公式:Lout =⌊ 9+2*0 – 1*(3-1) -1⌋/2 + 1 = 4

 参考:

1、pytorch之nn.Conv1d详解_若之辰的博客-CSDN博客_conv1d

2、简要解释什么是Conv1d,Conv2d,Conv3d_音程的博客-CSDN博客_conv1d

3、torch.nn.Conv1d及一维卷积举例说明_拉轰小郑郑的博客-CSDN博客_torch一维卷积

物联沃分享整理
物联沃-IOTWORD物联网 » pytorch之nn.Conv1d详解

发表评论