Python中torch.nn.Softmax()的用法和示例(dim=1和dim=2)

用法

torch.nn.Softmax() 是 PyTorch 中的一个类,用于实现 softmax 函数。softmax 函数是一种常用的激活函数,它可以将一个向量转换成一个概率分布,使得每个元素都是非负数且和为 1。softmax 函数通常在分类问题中使用,可以将一个多分类问题转换成多个二分类问题,从而得到每个类别的概率分布。

语法格式

torch.nn.Softmax(dim=None)

其中,dim 是要进行 softmax 的维度,缺省值为 None,表示对最后一维进行 softmax。

例子dim=1

import torch

x = torch.randn(2, 3)
print('x:', x)

softmax = torch.nn.Softmax(dim=1)
y = softmax(x)
print('y:', y)

输出

x: tensor([[ 1.3551,  0.3739,  0.5962],
            [-0.3465,  1.4536,  0.4576]])
y: tensor([[0.4989, 0.2238, 0.2773],
            [0.1018, 0.7325, 0.1656]])

在这个例子中,我们先使用 torch.randn() 生成一个大小为 (2, 3) 的张量 x。然后,我们定义一个 torch.nn.Softmax() 对象 softmax,将维度 dim=1 作为参数传入。接着,我们将张量 x 作为输入,调用 softmax() 方法,得到一个大小为 (2, 3) 的张量 y,表示经过 softmax 函数处理后的结果。可以看到,每行元素都是非负数且和为 1。

需要注意的是,torch.nn.Softmax() 在实际使用中通常与交叉熵损失函数一起使用,用于多分类问题的训练。

例子dim=2

dim=2 表示在第二个维度上进行 softmax 计算。

import torch

# 创建一个3D张量,形状为(2, 3, 4)
x = torch.randn(2, 3, 4)

# 使用dim=2进行softmax计算
softmax = torch.nn.Softmax(dim=2)
y = softmax(x)

print("Original tensor:")
print(x)
print("\nSoftmax tensor:")
print(y)

输出

Original tensor:
tensor([[[ 0.4769, -0.1835, -0.3167, -1.1385],
         [-0.5912,  0.4781, -0.6784, -0.4377],
         [-0.9624, -0.0528, -1.4899, -1.5107]],

        [[ 0.1033, -0.0107, -0.4888, -1.5489],
         [ 0.4071,  0.2163, -0.3167, -0.1252],
         [-1.7984, -1.1394, -1.5384, -0.3176]]])

Softmax tensor:
tensor([[[0.4669, 0.1745, 0.1527, 0.2060],
         [0.1668, 0.4647, 0.1311, 0.2374],
         [0.3005, 0.5028, 0.1452, 0.0515]],

        [[0.4474, 0.2594, 0.1248, 0.1684],
         [0.3616, 0.2983, 0.1426, 0.1975],
         [0.1055, 0.1555, 0.1084, 0.6307]]])

可以看到,原始张量中的每个值都经过了 softmax 计算,第二个维度上的值都被归一化到了 0 到 1 之间,并且在每个样本上的值之和都为 1。

总结

当张量的形状为二维时,dim=1 和 dim=2 的效果类似,因为此时张量的行数等于时间步数,列数等于特征数。在这种情况下,dim=1 和 dim=2 都将每一行的值进行归一化,输出的结果相同。

但是当张量的形状为三维及以上时,dim=1 和 dim=2 的效果就不同了。在序列到序列的任务中,通常需要对每个时间步上的输出进行归一化,因此需要使用 torch.nn.Softmax(dim=2)。在分类任务中,通常需要对每个样本的输出进行归一化,因此需要使用 torch.nn.Softmax(dim=1)。

总之,dim 参数的选择应该根据具体的任务需求来进行选择,而不是根据形状的维数来确定。

物联沃分享整理
物联沃-IOTWORD物联网 » Python中torch.nn.Softmax()的用法和示例(dim=1和dim=2)

发表评论