Pytorch中torch.stack() 函数解析

一. torch.stack()函数解析

1. 函数说明:

1.1 官网torch.stack(),函数定义及参数说明如下图所示:

函数定义及参数说明

1.2 函数功能

沿一个新维度对输入一系列张量进行连接,序列中所有张量应为相同形状,stack 函数返回的结果会新增一个维度。也即是把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度上面进行堆叠。

1.3 参数列表

  • tensors :为一系列输入张量,类型为turple和List
  • dim :新增维度的(下标)位置,当dim = -1时默认最后一个维度;范围必须介于 0 到输入张量的维数之间,默认是dim=0,在第0维进行连接
  • 返回值:输出新增维度后的张量
  • 2. 代码举例

    2.1 dim = 0 : 在第0维进行连接,相当于在行上进行组合(输入张量为一维,输出张量为两维)

    import torch
    #二维输入张量a,b
    a = torch.tensor([1, 2, 3])
    b = torch.tensor([11, 22, 33])
    c = torch.stack([a, b],dim=0)#在第0维进行连接,相当于在行上进行组合(输入张量为一维,输出张量为两维)
    print(a)
    print(b)
    print(c)
    
    输出结果如下:
    tensor([1, 2, 3])
    tensor([11, 22, 33])
    tensor([[ 1,  2,  3],
            [11, 22, 33]])
    

    2.2 dim = 1 :在第1维进行连接,相当于在对应行上面对列元素进行组合(输入张量为一维,输出张量为两维)

    import torch
    #二维输入张量a,b
    a = torch.tensor([1, 2, 3])
    b = torch.tensor([11, 22, 33])
    c = torch.stack([a, b],dim=1)#在第1维进行连接,相当于在对应行上面对列元素进行组合(输入张量为一维,输出张量为两维)
    print(a)
    print(b)
    print(c)
    
    输出结果如下:
    tensor([1, 2, 3])
    tensor([11, 22, 33])
    tensor([[ 1, 11],
            [ 2, 22],
            [ 3, 33]])
    

    2.3 dim=0:表示在第0维进行连接,相当于在通道维度上进行组合(输入张量为两维,输出张量为三维),注意:此处输入张量维度为二维,因此dim最大只能为2。

    import torch
    #二维输入张量a,b
    a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
    c = torch.stack([a, b],dim=0)#在第0维进行连接,相当于在通道维度上进行组合(输入张量为两维,输出张量为三维)
    print(a)
    print(b)
    print(c)
    
    输出结果如下所示:
    tensor([[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]])
    tensor([[11, 22, 33],
            [44, 55, 66],
            [77, 88, 99]])
    tensor([[[ 1,  2,  3],
             [ 4,  5,  6],
             [ 7,  8,  9]],
    
            [[11, 22, 33],
             [44, 55, 66],
             [77, 88, 99]]])
    

    2.4 dim=1:表示在第1维进行连接,相当于对相应通道中每个行进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。

    import torch
    #二维输入张量a,b
    a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
    c = torch.stack([a, b], 1)#在第1维进行连接,相当于对相应通道中每个行进行组合
    print(a)
    print(b)
    print(c)
    
    输出结果如下所示:
    tensor([[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]])
    tensor([[11, 22, 33],
            [44, 55, 66],
            [77, 88, 99]])
    tensor([[[ 1,  2,  3],
             [11, 22, 33]],
    
            [[ 4,  5,  6],
             [44, 55, 66]],
    
            [[ 7,  8,  9],
             [77, 88, 99]]])
    

    2.5 dim=2:表示在第2维进行连接,相当于对相应行中每个列元素进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。

    import torch
    #二维输入张量a,b
    a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
    c = torch.stack([a, b], 2)#在第2维进行连接,相当于对相应行中每个列元素进行组合
    print(a)
    print(b)
    print(c)
    
    输出结果如下所示:
    tensor([[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]])
    tensor([[11, 22, 33],
            [44, 55, 66],
            [77, 88, 99]])
    tensor([[[ 1, 11],
             [ 2, 22],
             [ 3, 33]],
    
            [[ 4, 44],
             [ 5, 55],
             [ 6, 66]],
    
            [[ 7, 77],
             [ 8, 88],
             [ 9, 99]]])
    

    2.6 dim=3:表示在第3维进行连接,相当于对相应行中每个列元素进行组合(输入维度大小为3维,因此dim=3最后一维始终代表为列),注意:此处输入张量维度为三维,因此dim最大只能为3。

    import torch
    #三维输入张量a,b
    a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])
    b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])
    c = torch.stack([a, b], 3)#表示在第3维进行连接,相当于对相应行中每个列元素进行组合(最后一维是第三维,始终代表为列)
    print(a)
    print(b)
    print(c)
    
    输出结果如下所示:
    tensor([[[ 1,  2,  3],
             [ 4,  5,  6],
             [ 7,  8,  9]],
    
            [[10, 20, 30],
             [40, 50, 60],
             [70, 80, 90]]])
    tensor([[[ 11,  22,  33],
             [ 44,  55,  66],
             [ 77,  88,  99]],
    
            [[110, 220, 330],
             [440, 550, 660],
             [770, 880, 990]]])
    tensor([[[[  1,  11],
              [  2,  22],
              [  3,  33]],
    
             [[  4,  44],
              [  5,  55],
              [  6,  66]],
    
             [[  7,  77],
              [  8,  88],
              [  9,  99]]],
    
    
            [[[ 10, 110],
              [ 20, 220],
              [ 30, 330]],
    
             [[ 40, 440],
              [ 50, 550],
              [ 60, 660]],
    
             [[ 70, 770],
              [ 80, 880],
              [ 90, 990]]]])
    

    2.7 dim=4 (错误维度:因为此处输入张量维度为三维,所以dim最大只能为3,此处维度为4,因此会报错)

    import torch
    #三维输入张量a,b
    a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])
    b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])
    c = torch.stack([a, b], 4)
    print(a)
    print(b)
    print(c)
    
    输出错误:
    IndexError: Dimension out of range (expected to be in range of [-4, 3], but got 4)
    
    

    来源:cv_lhp

    物联沃分享整理
    物联沃-IOTWORD物联网 » Pytorch中torch.stack() 函数解析

    发表评论