pytorch之深入理解collate_fn

文章目录

  • 前言
  • dataset
  • dataloader之collate_fn
  • 应用情形
  • 前言

    import torch.utils.data as tud
    

    collate_fn:即用于collate的function,用于整理数据的函数。
    说到整理数据,你当然也要会用tud.Dataset,因为这个你定义好后,才会产生数据嘛,产生了数据我们才能整理数据嘛,而整理数据我们使用collate_fn

    dataset

    我们必须先看看tud.Dataset如何使用,以一个例子为例:

    class mydataset(tud.Dataset):
        def __init__(self,data):
            self.data=data
        def __len__(self):#必须重写
            return len(self.data)
        def __getitem__(self,idx):#必须重写
            return self.data[idx]
    
    #构造训练数据
    a=np.random.rand(4,3)#4个数据,每一个数据是一个向量。
    print(a)
    

    #制作dataset
    dataset=mydataset(a)
    
    len(dataset)#调用了你上面定义的def __len__()那个函数
    #4
    
    dataset[0]#调用了你上面定义的def __getitem__()那个函数,传入的idx=0,也就是取第0个数据。
    #array([0.56998216, 0.72663738, 0.3706266 ])
    

    dataloader之collate_fn

    dataloader=tud.DataLoader(dataset,batch_size=2)
    

    batch_size=2即一个batch里面会有2个数据。我们以第1个batch为例,tud.DataLoader会根据dataset取出前2个数据,然后弄成一个列表,如下:

    batch=[dataset[0],dataset[1]]
    batch
    

    [array([0.56998216, 0.72663738, 0.3706266 ]),
    array([0.3403586 , 0.13931333, 0.71030221])]

    然后将上面这个batch作为参数交给collate_fn这个函数进行进一步整理数据,然后得到real_batch,作为返回值。如果你不指定这个函数是什么,那么会调用pytorch内部的collate_fn

    也就是说,我们如果自己要指定这个函数,collate_fn应该定义成下面这个样子。

    def my_collate(batch):#batch上面说过,是dataloader传进来的。
    	***#你自己定义怎么整理数据。下面会说。
    	real_batch=***
    	return real_batch
    

    那么pytorch内部默认的collate_fn函数长什么样呢?我们先观察下面的例子:

    it=iter(dataloader)
    nex=next(it)#我们展示第一个batch经过collate_fn之后的输出结果
    print(nex)
    

    tensor([[0.5700, 0.7266, 0.3706],
    [0.3404, 0.1393, 0.7103]], dtype=torch.float64)

    上面这个返回的结果就是real_batch。也就是collate_fn函数的返回值!!也就是说collate_fn将batch变成了上面的real_batch。

    我们重新写一遍

    batch:
    [array([0.56998216, 0.72663738, 0.3706266 ]),
    array([0.3403586 , 0.13931333, 0.71030221])]
    real_batch:
    tensor([[0.5700, 0.7266, 0.3706],
    [0.3404, 0.1393, 0.7103]], dtype=torch.float64)

    将batch变成上述real_batch很容易呀,就是把一个列表,变成了矩阵,我们也会!我们下面就来自己写一个collate_fn实现这个功能。

    def my_collate(batch):
        real_batch=np.array(batch)
        real_batch=torch.from_numpy(real_batch)
        return real_batch
    
    dataloader2=tud.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
    
    it=iter(dataloader2)
    nex=next(it)#我们展示第一个batch经过collate_fn之后的输出结果
    print(nex)
    

    tensor([[0.5700, 0.7266, 0.3706],
    [0.3404, 0.1393, 0.7103]], dtype=torch.float64)

    这不就和默认的collate_fn的输出结果一样了嘛!

    应用情形

    通常,我们并不需要使用这个函数,因为pytorch内部有一个默认的。但是,如果你的数据不规整,使用默认的会报错。例如下面的数据。
    假设我们还是4个输入,但是维度不固定的。之前我们是每一个数据的维度都为3。

    a=[[1,2],[3,4,5],[1],[3,4,9]]
    dataset=mydataset(a,b)
    dataloader=tud.DataLoader(dataset,batch_size=2)
    it=iter(dataloader)
    nex=next(it)
    nex
    

    使用默认的collate_fn,直接报错,要求相同维度。

    这个时候,我们可以使用自己的collate_fn,避免报错。

    不过话说回来,我个人感受是:

    在这里避免报错好像也没有什么用,因为大多数的神经网络都是定长输入的,而且很多的操作也要求相同维度才能相加或相乘,所以:这里不报错,后面还是报错。如果后面解决这个问题的方法是:在不足维度上进行补0操作,那么我们为什么不在建立dataset之前先补好呢?所以,collate_fn这个东西的应用场景还是有限的。不过,明白其原理总是好事。


    完结撒花


    来源:音程

    物联沃分享整理
    物联沃-IOTWORD物联网 » pytorch之深入理解collate_fn

    发表评论