Pytorch数据读取机制(DataLoader)

小时候,乡愁是一枚小小的邮票,你在这头,我在那头;

长大后,乡愁是一张核酸证明,你在家里,我在隔离!

一、python读取机制

在学习Pytorch的数据读取之前,我们得先回顾一下这个数据读取到底是以什么样的逻辑存在的, 我们知道机器模型学习的五大模块,分别是数据,模型,损失函数,优化器,迭代训练。而这里的数据读取机制,很显然是位于数据模块的一个小分支,下面看一下数据模块的详细内容:

数据模块中,又可以大致分为上面不同的子模块, 而今天学习的DataLoader和DataSet就是数据读取子模块中的核心机制。 了解了上面这些框架,有利于把知识进行整合起来,到底学习的内容属于哪一块。下面正式开始DataLoader和Dataset的学习。

二、Dataloader

torch.utils.data.DataLoader(): 构建可迭代的数据装载器, 我们在训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。

DataLoader的参数很多,但我们常用的主要有5个:

dataset: Dataset类, 决定数据从哪读取以及如何读取
bathsize: 批大小
num_works: 是否多进程读取机制
shuffle: 每个epoch是否乱序
drop_last: 当样本数不能被batchsize整除时, 是否舍弃最后一批数据

三、Dataset

torch.utils.data.Dataset(): Dataset抽象类, 所有自定义的Dataset都需要继承它,并且必须复写__getitem__()这个类方法(或__get_sample__())。

__getitem__方法的是Dataset的核心,作用是接收一个索引, 返回一个样本, 看上面的函数,参数里面接收index,然后我们需要编写究竟如何根据这个索引去读取我们的数据部分。

train函数是模型训练的入口。首先一些变量的更新采用自定义的AverageMeter类来管理,然后model.train()是设置为训练模式。 for i, (input, target) in enumerate(train_loader) 是数据迭代读取的循环函数,具体而言,当执行enumerate(train_loader)的时候,是先调用DataLoader类的__iter__方法,该方法里面再调用DataLoaderIter类的初始化操作__init__。而当执行for循环操作时,调用DataLoaderIter类的__next__方法,在该方法中通过self.collate_fn接口读取self.dataset数据时,就会调用TSNDataSet类的__getitem__方法,从而完成数据的迭代读取。读取到数据后就将数据从Tensor转换成Variable格式,然后执行模型的前向计算:output = model(input_var);损失函数计算: loss = criterion(output, target_var);准确率计算: prec1, prec5 = accuracy(output.data, target, topk=(1,5));模型参数更新等等。其中loss.backward()是损失回传, optimizer.step()是模型参数更新。

参考:https://blog.csdn.net/wuzhongqiang/article/details/105499476

https://blog.csdn.net/rytyy/article/details/105944813

来源:深度科研

物联沃分享整理
物联沃-IOTWORD物联网 » Pytorch数据读取机制(DataLoader)

发表评论