Python中repeat()和repeat_interleave()函数的详细比较和分析

最近在学习沐神的d2l的时候,深受其中代码的折磨,有些函数真的是从来没见过,组合起来更是让人头皮发麻,根本看不懂代码在写些什么。

写这篇文章,主要是为了总结一下Python当中的repeat()函数和repeat_interleave()函数,这两个函数在应用于Pytorch和Numpy数组的时候得到的结果也是不一样的,所以有很大的槽点需要注意!

首先是总结应用于Pytorch领域的repeat()函数和repeat_interleave()函数:

1.repeat()

话不多说,直接上代码:

import torch

# 创建一个张量
original_tensor = torch.tensor([[1, 2], [3, 4]])

# 沿着行和列方向分别重复张量
repeated_tensor = original_tensor.repeat(2, 3)
print(repeated_tensor)

输出为:

tensor([[1, 2, 1, 2, 1, 2],
        [3, 4, 3, 4, 3, 4],
        [1, 2, 1, 2, 1, 2],
        [3, 4, 3, 4, 3, 4]])

不难从输出当中得出结论:.repeat(2, 3)就是沿着第一个维度(行)重复 2 次,沿着第二个维度(列)重复 3 次,最终生成了一个 4×6 的张量。注意repeat是一组元素一组元素地重复,这与下面的repeat_interleave()函数是不相同的。

2.repeat_interleave()

该函数与repeat()函数的区别在于,它是沿着指定的维度复制张量元素

①不指定dim,重复次数为2次,表示将把给定的输入张量展平(flatten)为向量,然后将每个元素重复2次,并返回重复后的张量。

a = torch.randn(3,2)
a,a.repeat_interleave(2)

输出为:

(tensor([[-1.03, -0.32],
         [ 0.43,  0.78],
         [ 0.91, -0.11]]),
 tensor([-1.03, -1.03, -0.32, -0.32,  0.43,  0.43,  0.78,  0.78,  0.91,  0.91,
         -0.11, -0.11]))

②输入二维张量,指定dim=0,重复次数为3次,表示把输入张量每行元素重复3次

a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=0)

输出为:

(tensor([[ 0.14,  1.47],
         [-1.52, -0.62],
         [-0.24, -0.27]]),
 tensor([[ 0.14,  1.47],
         [ 0.14,  1.47],
         [ 0.14,  1.47],
         [-1.52, -0.62],
         [-1.52, -0.62],
         [-1.52, -0.62],
         [-0.24, -0.27],
         [-0.24, -0.27],
         [-0.24, -0.27]]))

③输入二维张量,指定dim=1,重复次数为3次,表示把输入张量每列元素重复3次

a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=1)

输出为:

(tensor([[-0.81,  0.56],
         [-2.41, -0.56],
         [ 0.38, -0.90]]),
 tensor([[-0.81, -0.81, -0.81,  0.56,  0.56,  0.56],
         [-2.41, -2.41, -2.41, -0.56, -0.56, -0.56],
         [ 0.38,  0.38,  0.38, -0.90, -0.90, -0.90]]))

④输入二维张量,指定dim=0,重复次数为一个张量列表[n1,n2,n3],表示在(dim=0)对应行上面重复n1,n2,n3遍,张量列表的长度必须与dim=0的维度的长度一样,否则会报错

a = torch.randn(3,2)
a,torch.repeat_interleave(a,torch.tensor([2,3,4]),dim=0)
#表示第一行重复2遍,第二行重复3遍,第三行重复4遍

输出为:

(tensor([[-0.79,  0.54],
         [-0.47, -0.25],
         [-0.13,  1.03]]),
 tensor([[-0.79,  0.54],
         [-0.79,  0.54],
         [-0.47, -0.25],
         [-0.47, -0.25],
         [-0.47, -0.25],
         [-0.13,  1.03],
         [-0.13,  1.03],
         [-0.13,  1.03],
         [-0.13,  1.03]]))

总结:可以看出,两个函数方法最大的区别就是repeat_interleave是一个元素一个元素地重复,而repeat是一组元素一组元素地重复

那到这里就完了吗?完全没有!经过测试发现,以上都是repeat()函数和repeat_interleave()函数应用于pytorch的tensor张量,但当它们应用于numpy数组时,结果又是不一样的!

例如:

test_array = torch.arange(9).reshape(3, 3)
print('采用torch tensor原始:\n', test_array)
print('采用torch tensor的repeat函数:\n', test_array.repeat(2, 1))
print('采用torch tensor的repeat_interleave函数:\n', test_array.repeat_interleave(2, dim=0))
test_array2 = np.arange(9).reshape(3, 3)
print('采用numpy array原始:\n', test_array2)
print('采用numpy array的repeat函数:\n', test_array2.repeat(2, 1))
print('采用numpy array的repeat_interleave函数:\n', test_array2.repeat_interleave(2, dim=0))

我们运行上述代码,看看结果怎么样:

采用torch tensor原始:
 tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
采用torch tensor的repeat函数:
 tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8],
        [0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
采用torch tensor的repeat_interleave函数:
 tensor([[0, 1, 2],
        [0, 1, 2],
        [3, 4, 5],
        [3, 4, 5],
        [6, 7, 8],
        [6, 7, 8]])
采用numpy array原始:
 [[0 1 2]
 [3 4 5]
 [6 7 8]]
采用numpy array的repeat函数:
 [[0 0 1 1 2 2]
 [3 3 4 4 5 5]
 [6 6 7 7 8 8]]
Traceback (most recent call last):
  File "D:/PythonProject/DiveIntoDeepLearning(LiMu)/main.py", line 82, in <module>
    print('采用numpy array的repeat_interleave函数:\n', test_array2.repeat_interleave(2, dim=0))
AttributeError: 'numpy.ndarray' object has no attribute 'repeat_interleave'

Process finished with exit code 1

从输出结果可以得出以下结论:

①pytorch当中的numpy.repeat(2, 1)是指在第一个维度(行)上复制两次,在第二个维度(列)上复制1次,并且是一组元素一组元素地复制;Numpy当中的.repeat(2, 1)是指在第二个维度上(列,对应dim值为1)复制两次,并且是一个一个元素的复制

②numpy没有repeat_interleave函数

物联沃分享整理
物联沃-IOTWORD物联网 » Python中repeat()和repeat_interleave()函数的详细比较和分析

发表评论