PyTorch框架中conv1d和conv2d的输入数据维度详解

文章目录

  • Conv1d
  • Conv2d
  • Conv1d

    Conv1d 的输入数据维度通常是一个三维张量,形状为 (batch_size, in_channels, sequence_length),其中:

    batch_size 表示当前输入数据的批次大小;
    in_channels 表示当前输入数据的通道数,对于文本分类任务通常为 1,对于图像分类任务通常为 3(RGB)、1(灰度)等;
    sequence_length 表示当前输入数据的序列长度,对于文本分类任务通常为词向量的长度,对于时序信号处理任务通常为时间序列的长度,对于图像分类任务通常为图像的高或宽。
    具体来说,Conv1d 模块会对第二维和第三维分别进行一维卷积操作,保留第一维(即批次大小)不变,输出一个新的三维张量,形状为 (batch_size, out_channels, new_sequence_length),其中 out_channels 表示卷积核的数量,new_sequence_length 表示卷积后的序列长度。

    示例:

    import torch
    import torch.nn as nn
    
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Sequential(
                nn.Conv1d(in_channels=1, out_channels=16, kernel_size=2),
                nn.ReLU(),
                # nn.MaxPool1d(kernel_size=2),
                nn.Conv1d(in_channels=16, out_channels=32, kernel_size=2),
                nn.ReLU(),
                # nn.MaxPool1d(kernel_size=2)
            )
            self.fc = nn.Linear(128, 2)
    
        def forward(self, x):
            x = x.unsqueeze(1)
            x = self.conv(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return x
    x = torch.randn(200,6)
    # x = x.unsqueeze(1)
    net = Net()
    output = net(x)
    print(x.shape)
    

    Conv2d

    在 PyTorch 中,使用 nn.Conv2d 创建卷积层时,输入数据的维度应该是 (batch_size, input_channels, height, width)。其中,

    batch_size 表示当前输入数据的批次大小;
    input_channels 表示当前输入数据的通道数,对于彩色图像通常为 3(RGB),对于灰度图像通常为 1;
    height 和 width 分别表示输入数据的高和宽。因此,在 PyTorch 框架中,Conv2d 的输入数据维度应该是一个四维张量,形状为 (batch_size, input_channels, height, width)。

    物联沃分享整理
    物联沃-IOTWORD物联网 » PyTorch框架中conv1d和conv2d的输入数据维度详解

    发表评论