详解pytorch之tensor的拼接

tensor经常需要进行拼接、拆分与调换维度,比如通道拼接,比如通道调至最后一个维度等,本文的目的是详细讨论一下具体是怎么拼接的。如果本来就理解这其中的原理的童鞋就不用往下看了,肯定觉得啰嗦了~~

拼接即两个tensor按某一维度进行拼接,分两种情况,一个是不新增维度,一个是新增维度。

1.torch.cat(tensors, dim=0, *, out=None)  —不新增维度

tensors即要拼接的tensor列表或元组,按dim指定的维度进行拼接。

如下分别为按第0维拼接与按第1维拼接:

import torch

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

result = torch.cat([a, b], 0)
print(result)

结果:

tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])

result = torch.cat([a, b], 1)
print(result)

结果:

tensor([[1, 2, 5, 6],
[3, 4, 7, 8]])

二维的情况其实还是蛮好理解的,不过还是先以二维的情况来讨论一下为什么是这样拼接。

注,本文全文讨论的维度都是从0维开始算的

如下图,一共是两个张量,红框表示按第0维排列的元素,a的0维有两个元素,b的0维也有两个元素。绿色表示按第1维排列的元素(自然是在第0维的某个元素里面来数了,比如a[0]),a[0],a[1],b[0],b[1]都有两个元素

接下来就是讨论怎么拼接,如果是按第0维拼接,即按照第0个维度把a、b的数据拼接起来,怎么拼呢,就是b的两个红框依次移到a的末尾就行了,可以理解为直接将两个黑框合成一个。

如果是按第1维拼接呢。那就是把a的第1个红框与b的第1个红框拼成一个,把a的第2个红框与b的第2个红框拼成一个。如下图:

按照如上的拼法,自然会有一个要求,a有两个红框,b就必须有两个红框,不然的话就不能按第1维拼接了。

接着来个稍微复杂一点的,下面这个结果是什么呢?

import torch

a = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
b = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])

result = torch.cat([a, b], 1)
print(result)

下面还是按照上面的分析方法来看。首先画图如下,红框表示第0维的元素,绿框是第1维的元素,蓝框是第2维的元素。

接下来按照第1维拼接,突然发现,这跟之前那个按照第1维拼接是一模一样的啊,还是把a的第1个红框与b的第1个红框拼成一个,把a的第2个红框与b的第2个红框拼成一个。

运行结果即:

tensor([[[ 1,  2],
[ 3,  4],
[ 9, 10],
[11, 12]],

[[ 5,  6],
[ 7,  8],
[13, 14],
[15, 16]]])

这里说明一下,上面1,2,3,4的红框与9,10,11,12的红框拼接,为啥不是下面这样的拼接结果呢?

因为上面这个就不是把两个红的合成一个,而是把两个绿的合成一个了,它实际上就是按第2维进行拼接了,接下来我们要讨论的就是:那如果是按第2维拼接呢

import torch

a = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
b = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])

result = torch.cat([a, b], 2)
print(result)

还是一样的道理,这次就是把a的第1个红框的第1个绿框与b的第1个红框的第1个绿框合成一个,把a的第1个红框的第2个绿框与b的第1个红框的第2个绿框合成一个,以此类推,如下图。

拼接的要求也变成了:a有两个红框,b也得有两个红框,a的红框里有两个绿框,b的红框里也得有两个绿框。

运行结果为:

tensor([[[ 1,  2,  9, 10],
[ 3,  4, 11, 12]],

[[ 5,  6, 13, 14],
[ 7,  8, 15, 16]]])

那可能还有一个问题,如果我想把1和9拼起来,把2和10拼起来呢?cat是没法完成这个操作了,这里先放个答案:运行如下代码即可做到

import torch

a = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
b = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])

result = torch.stack([a, b], 3)
print(result)

运行结果为:

tensor([[[[ 1,  9],
[ 2, 10]],

[[ 3, 11],
[ 4, 12]]],

[[[ 5, 13],
[ 6, 14]],

[[ 7, 15],
[ 8, 16]]]])

接下来具体说明一个stack。

2.torch.stack(tensors, dim=0, *, out=None)  —–新增维度

tensors即要拼接的tensor列表或元组,按dim指定的维度进行拼接。参数看起来跟cat一样,但是这里的维度含义并不一样,cat的拼接即按指定维度把数据拼起来,并不会新增维度。而stack是什么呢,下面结合具体的例子来说明,为了简化,这里只用下面这个例子来进行说明。

(1)在第0维新增维度进行拼接

import torch

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

result = torch.stack([a, b], 0)
print(result)

cat是把两个黑框合成一个,而stack呢,并不会合成一个,它是创建了一个两个黑框,第一个给a,第二个给b,结果就不是2维了,而是3维了,这次黑框就留下来了哦,它表示新的第0维的元素了,红框变成了第1维的元素,绿框变成了第2维的元素。

运行结果如下:

tensor([[[1, 2],
[3, 4]],

[[5, 6],
[7, 8]]])

(2)在第1维新增维度进行拼接

import torch

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

result = torch.stack([a, b], 1)
print(result)

cat是把两个红框合成一个红框,而stack是在红框中新增两个黑框,a的第1个红框里的元素放第1个黑框,b的第1个红框里的元素放第2个黑框,以此类推,如下图。

运行结果如下:

tensor([[[1, 2],
[5, 6]],

[[3, 4],
[7, 8]]])

(3)第2维新增维度进行拼接

import torch

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

result = torch.stack([a, b], 2)
print(result)

cat是没法按第2维拼接的,因为没有第2维。stack就是在第2维新增一个维度,即在每个绿框里新增两个黑框,分别把a、b对应的绿框里的数据填进去,如下图。

3.补充一下,上述的示例为了方便展示,都是合并两个张量,实际上自然可以合并超过两个张量,只不过上面是把两个框合成一个,那3个张量就是把3个框合成一个了,不详述,可以自行试验。

来源:扫地僧1234

物联沃分享整理
物联沃-IOTWORD物联网 » 详解pytorch之tensor的拼接

发表评论