【python】【PyTorch】详细中文解释unsqueeze,代码和代码解读
目录
【python】【PyTorch】详细中文解释unsqueeze,代码和代码解读
unsqueeze() 函数的作用:
语法:
unsqueeze() 操作示例:
示例 1:将一个一维张量转换为二维张量
示例 2:在最后一维插入一个新维度
示例 3:负索引插入维度
示例 4:将二维张量转为三维张量
总结:
【python】【PyTorch】详细中文解释unsqueeze,代码和代码解读
在 PyTorch 中,
unsqueeze()是一个非常实用的函数,用于在张量的指定位置插入一个维度。
简而言之,
unsqueeze()通过增加一个长度为1的维度来扩展张量的维度。
unsqueeze() 函数的作用:
unsqueeze()函数将一个张量的维度增加 1。
这个函数常用于调整张量的形状,特别是在需要将一个二维或一维张量转换为更高维度的张量时。
语法:
torch.unsqueeze(input, dim)
input:输入张量。dim:指定要插入新维度的位置。dim 是一个整数,表示新维度的位置,取值范围是 [-input.dim() - 1, input.dim()]。如果 dim 为负数,它表示从最后一个维度开始计数。unsqueeze() 操作示例:
示例 1:将一个一维张量转换为二维张量
假设我们有一个一维张量 [1, 2, 3],我们希望通过 unsqueeze() 将其转换为一个二维张量,并在第 0 维度(最前面)插入一个新的维度。
import torch
# 创建一个一维张量
x = torch.tensor([1, 2, 3])
# 在第0维插入一个新的维度
y = torch.unsqueeze(x, 0)
print("Original shape:", x.shape) # 原始张量形状
print("New shape:", y.shape) # 新张量形状
print(y)
输出:
Original shape: torch.Size([3])
New shape: torch.Size([1, 3])
tensor([[1, 2, 3]])
x 的形状是 (3),表示这是一个包含 3 个元素的一维张量。torch.unsqueeze(x, 0) 后,在张量的第 0 维插入了一个新的维度。结果是一个形状为 (1, 3) 的二维张量。unsqueeze(0) 会在第一个维度(最前面)插入新的维度,表示这个张量现在有 1 行,3 列。示例 2:在最后一维插入一个新维度
假设我们希望将张量 [1, 2, 3] 变成形状为 (3, 1) 的二维张量,我们可以在第 1 维(最后一维)插入一个新的维度。
# 在第1维插入一个新的维度
z = torch.unsqueeze(x, 1)
print("Original shape:", x.shape)
print("New shape:", z.shape)
print(z)
输出:
Original shape: torch.Size([3])
New shape: torch.Size([3, 1])
tensor([[1],
[2],
[3]])
x 的形状是 (3),是一个一维张量。torch.unsqueeze(x, 1) 后,在第 1 维(即最后一个维度)插入了一个新的维度。结果是一个形状为 (3, 1) 的二维张量,表示这个张量现在有 3 行,1 列。示例 3:负索引插入维度
我们可以使用负数索引来指定维度的位置。负数表示从最后一个维度开始计数。
# 在倒数第一维(最后一维)插入一个新的维度
w = torch.unsqueeze(x, -1)
print("Original shape:", x.shape)
print("New shape:", w.shape)
print(w)
输出:
Original shape: torch.Size([3])
New shape: torch.Size([3, 1])
tensor([[1],
[2],
[3]])
torch.unsqueeze(x, -1) 等同于使用 torch.unsqueeze(x, 1),在张量的最后一个维度插入了一个新的维度。(3, 1) 的二维张量,表示张量现在有 3 行,1 列。示例 4:将二维张量转为三维张量
如果我们有一个形状为 (2, 3) 的二维张量,并希望将其转换为三维张量(例如,插入一个维度表示批次大小),我们可以使用 unsqueeze()。
# 创建一个二维张量
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 在第0维插入新维度
b = torch.unsqueeze(a, 0)
print("Original shape:", a.shape)
print("New shape:", b.shape)
print(b)
输出:
Original shape: torch.Size([2, 3])
New shape: torch.Size([1, 2, 3])
tensor([[[1, 2, 3],
[4, 5, 6]]])
a 的形状是 (2, 3),表示它有 2 行,3 列。torch.unsqueeze(a, 0) 后,在第 0 维(最前面)插入了一个新的维度,结果是一个形状为 (1, 2, 3) 的三维张量,表示这个张量现在有 1 个批次,2 行,3 列。总结:
unsqueeze() 函数用于增加张量的维度,可以通过指定维度位置插入一个新的维度(长度为 1)。作者:资源存储库