Pytorch Dataset类的使用(个人学习笔记)

训练模型一般都是先处理 数据的输入问题 和 预处理问题。

Pytorch提供了几个有用的工具:torch.utils.data.Dataset类 和 torch.utils.data.DataLoader类。

流程是先把 原始数据 转变成 torch.utils.data.Dataset类 ,

随后再把得到torch.utils.data.Dataset类 当作一个参数传递给 torch.utils.data.DataLoader类,

得到一个数据加载器,这个数据加载器每次可以返回一个 Batch 的数据供模型训练使用。

这一过程通常可以让我们把一张 生图 通过标准化、resize等操作转变成我们需要的 [B,C,H,W] 形状的 Tensor。

用原始数据都造出来的 Dataset子类 其实就是一个静态的数据池,这个数据池支持我们用 索引 得到某个数据,想要让这个数据池流动起来,源源不断地输出 Batch 还需要下一个工具 DataLoader类 。所以我们把创建的 Dataset子类 当参数传入 即将构建的DataLoader类才是使用Dataset子类最终目。

Dataset类的作用:提供一种方式去获取数据及其对应的真实Label

pycharm提供了三种利用python的方式

jupyter notebook,在pycharm中的python控制台,以及python文件

我觉得利用控制台和比较方便

图片库可以直接拖进来

help(Dataset)或者Dataset??查看帮助

dataset类是一个抽象类,所有的数据集想要在数据与标签之间建立映射,都需要继承这个类,所有的子类都需要重写__getitem__方法,该方法根据索引值获取每一个数据并且获取其对应的Label,子类也可以重写__len__方法,返回数据集的size大小。

一般__init__负责加载全部原始数据,初始化之类的。

__getitem__负责按索引取出某个数据,并对该数据做预处理。

但是对于如何加载原始数据以及如何预处理数据完全是由自己定义的,包括我们用 dataset[index] 取出的数据的组织形式都是完全自行定义的。

对于def __init__(self, root_dir, label_dir)

很简单,就是接收实例化时传入的参数:获取根目录路径、子目录路径

然后将两个路径进行组合,就得到了目标数据集的路径

我们将这个路径作为参数传入listdir函数,从而让img_path_list中存储该目录下所有文件名

此时通过索引就可以轻松获取每个文件名

接下来,我们要使用这些初始化的信息去获取其中的每一个图片的JpegImageFile对象

from torch.utils.data import Dataset  ##导入Dataset类
from PIL import Image  ## 读取图片的库,可以对图片进行可视化
##使用PIL来读取数据,它提供一个Image模块,可以让我们提取图像数据,我们先导入这个模块
import os  ## 关于系统操作的库,主要用来对文件路径操作,对数据所在文件路径进行字符串操作

class MyData(Dataset): ##首先我们创建一个类,类名为MyData,这个类要继承Dataset类

    def __init__(self, root_dir, label_dir):
##首先需要写的是__init__方法,此方法用于对象实例化,通常用来提供类中需要使用的变量
        self.root_dir = root_dir##根目录路径
        self.label_dir = label_dir##标签路径
        self.path = os.path.join(self.root_dir, self.label_dir)
##目标路径文件夹是两个路径拼在一起,os方法对字符串进行拼接
        self.img_path_list = os.listdir(self.path)
##从数据集目标路径中,获取所有文件的名字,存储到一个列表中,listdir方法会将路径下的所有文件名(包括后缀名)组成一个列表



    def __getitem__(self, idx):
##__getitem__方法,该方法根据索引值获取每一个数据并且获取其对应的Label
## 默认是item,但常改为idx,是index的缩写,以便以后数据集获取后我们使用索引编号访问每个数据
        img_name = self.img_path_list[idx]## 从文件名列表中获取了文件名
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
## 组装路径,获得了图片具体的路径
        img =Image.open(img_item_path)##使用PIL读取这个图像
        label = self.label_dir
        return img, label ##此处img是一个JpegImageFile对象,label是一个字符串。
##使用这个类进行实例化时,传入的参数是根目录路径,以及对应的label名,我们就可以得到一个MyData对象


    def __len__(self): ##__len__方法,返回数据集的size大小
        return len(self.img_path_list)
##__len__主要功能是获取数据集的长度,由于我们在初始化中已经获取了所有文件名的列表,所以只需要知道这个列表的长度,就知道了有多少个文件,也就是知道了有多少个具体的数据




##有了这个MyData对象后,我们可以直接使用索引来获取具体的图像对象,索引即可调用__getitem__方法,会返回我们根据索引提取到的对应数据的图像对象以及其label

root_dir = "hymenoptera_data/train" ##在这里我们用的是相对路径
ants_label_dir = "ants" ##指向存放train数据集里的ants数据的文件夹
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
img1, label1 = ants_dataset[0]  # 返回一个元组,返回值是__getitem__方法的返回值 
img2, label2 = bees_dataset[0]
from PIL import Image 
img_path = "D:\\DeepLearning\\dataset\\train\\ants\\0013035.jpg"##图片的绝对路径
img = Image.open(img_path) ##使用Image的open方法读取图片
img.size##读取图片大小
img.show()##查看图片
##控制台右侧可以看到一些属性

其他具体使用可以参考【深度学习】PyTorch Dataset类的使用与实例分析 – 知乎 (zhihu.com)构建数据集路径部分

以及B站https://www.bilibili.com/video/BV1hE411t7RN?p=7

物联沃分享整理
物联沃-IOTWORD物联网 » Pytorch Dataset类的使用(个人学习笔记)

发表评论