Pytorch 张量列表转换为张量 List of Tensor to Tensor 使用 torch.stack()

比如我现在有一个 List 每个元素是一个 shape 相同的 Tensor,我想将它们连接成一个统一的 Tensor。

a=b=c= torch.rand([4,5])
list_of_tensor = [a,b,c]
print(list_of_tensor)
输出:
[tensor([[2.1911e-01, 4.8939e-01, 5.1264e-01, 4.2860e-01, 4.2832e-01],
        [5.5072e-01, 9.0650e-01, 9.4573e-01, 2.9587e-01, 3.8711e-01],
        [8.3788e-01, 1.6358e-01, 3.9210e-01, 4.1913e-01, 4.8324e-01],
        [8.8101e-01, 8.8954e-04, 9.5448e-01, 9.0539e-01, 3.2410e-01]]), tensor([[2.1911e-01, 4.8939e-01, 5.1264e-01, 4.2860e-01, 4.2832e-01],
        [5.5072e-01, 9.0650e-01, 9.4573e-01, 2.9587e-01, 3.8711e-01],
        [8.3788e-01, 1.6358e-01, 3.9210e-01, 4.1913e-01, 4.8324e-01],
        [8.8101e-01, 8.8954e-04, 9.5448e-01, 9.0539e-01, 3.2410e-01]]), tensor([[2.1911e-01, 4.8939e-01, 5.1264e-01, 4.2860e-01, 4.2832e-01],
        [5.5072e-01, 9.0650e-01, 9.4573e-01, 2.9587e-01, 3.8711e-01],
        [8.3788e-01, 1.6358e-01, 3.9210e-01, 4.1913e-01, 4.8324e-01],
        [8.8101e-01, 8.8954e-04, 9.5448e-01, 9.0539e-01, 3.2410e-01]])]

使用 torch.stack() 来将它们堆叠为一个 Tensor。

tensor_all = torch.stack(list_of_tensor)
print(tensor_all)
print(tensor_all.shape)
输出:
tensor([[[2.1911e-01, 4.8939e-01, 5.1264e-01, 4.2860e-01, 4.2832e-01],
         [5.5072e-01, 9.0650e-01, 9.4573e-01, 2.9587e-01, 3.8711e-01],
         [8.3788e-01, 1.6358e-01, 3.9210e-01, 4.1913e-01, 4.8324e-01],
         [8.8101e-01, 8.8954e-04, 9.5448e-01, 9.0539e-01, 3.2410e-01]],

        [[2.1911e-01, 4.8939e-01, 5.1264e-01, 4.2860e-01, 4.2832e-01],
         [5.5072e-01, 9.0650e-01, 9.4573e-01, 2.9587e-01, 3.8711e-01],
         [8.3788e-01, 1.6358e-01, 3.9210e-01, 4.1913e-01, 4.8324e-01],
         [8.8101e-01, 8.8954e-04, 9.5448e-01, 9.0539e-01, 3.2410e-01]],

        [[2.1911e-01, 4.8939e-01, 5.1264e-01, 4.2860e-01, 4.2832e-01],
         [5.5072e-01, 9.0650e-01, 9.4573e-01, 2.9587e-01, 3.8711e-01],
         [8.3788e-01, 1.6358e-01, 3.9210e-01, 4.1913e-01, 4.8324e-01],
         [8.8101e-01, 8.8954e-04, 9.5448e-01, 9.0539e-01, 3.2410e-01]]])
torch.Size([3, 4, 5])

参考:

torch.stack — PyTorch 1.12 documentationhttps://discuss.pytorch.org/t/how-to-turn-a-list-of-tensor-to-tensor/8868/4

来源:Haulyn5

物联沃分享整理
物联沃-IOTWORD物联网 » Pytorch 张量列表转换为张量 List of Tensor to Tensor 使用 torch.stack()

发表评论