yolov5只训练数据集中的某几个类别

文章目录

  • 前言
  • 一、直接修改数据集标签
  • 二、修改加载labels的代码
  • 1.train
  • 2.create_dataloader
  • 3.LoadImagesAndLabels
  • 4.cache_labels
  • 5.verify_image_label
  • 总结

  • 前言

    提示:在训练网络过程中,我们找到的公开数据集可能有很多分类,但是我们的检测任务又不需要那么多,或者说是对自己的训练集做一个取舍:

    例如:一个训练集有猫和狗,但是我不想训练猫了,只想训练狗,所以就只加载狗的标签。


    基本思路:只训练某几类标签的话,那就需要修改dataset中的labels,本文提供两种思路

    一、直接修改数据集标签

    通过直接修改数据集标签(*.txt)来删去某种类别的数据。

    这种方法很直接,但是也意味着你多了一个整个数据集文件,虽然内存不大,但是感觉比较呆。

    二、修改加载labels的代码

    数据集labels在加载进dataloader过程中本身就有某些处理过程(如检验是否为空),我们可以在上面加些筛选条件就可以做到过滤效果。

    1.train

    在train.py文件下找到加载数据集的代码,如:

    # Trainloader
    train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
                                              hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
                                              workers=workers, image_weights=opt.image_weights, quad=opt.quad,
                                              prefix=colorstr('train: '))
    

    然后我们进入create_dataloader继续跟踪

    2.create_dataloader

    找到加载数据集LoadImagesAndLabels:

    dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                          augment=augment,  # augment images
                                          hyp=hyp,  # augmentation hyperparameters
                                          rect=rect,  # rectangular training
                                          cache_images=cache,
                                          single_cls=single_cls,
                                          stride=int(stride),
                                          pad=pad,
                                          image_weights=image_weights,
                                          prefix=prefix)
    

    3.LoadImagesAndLabels

    其中,下面这一段代码是加载cache缓存文件,这里不细说,就把它简单看成数据集文件。如果cache已存在,就直接加载,不存在才创建,我们需要进入创建部分cache_labels

    try:
       cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
       assert cache['version'] == self.cache_version  # same version
       assert cache['hash'] == get_hash(self.label_files + self.img_files)  # same hash
    except:
        cache, exists = self.cache_labels(cache_path, prefix), False  # cache
    

    4.cache_labels

    这个部分就是处理数据集的信息统计(如是否为空等),其中一段遍历整个数据集的代码

    pbar = tqdm(pool.imap(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
                            desc=desc, total=len(self.img_files))
    

    这段代码含义大致就是将img_files, label_files, prefix打包丢进verify_image_label函数中处理后返回

    5.verify_image_label

    这段函数就是我们的最终目标了,这里面有加载图片,标签的功能,还可以进行一定筛选,我们就从这里修改。找到加载labels的代码:

    withopen(lb_file) as f:
    	l = [x.split() for x in f.read().strip().splitlines() if len(x)]
    

    这段代码就是将labels的内容加载进列表l中,如这里有个label文件

    有类别6、7,通过代码加载进去就是

    list L 中有两个list,代表两个目标,每个list第一位就是类别。这个时候效果就很明显了,如果我们不想要类别6,我们只需要修改成

    withopen(lb_file) as f:
    	l = [x.split() for x in f.read().strip().splitlines() if len(x) and x[0]!='6']
    

    就行了,最后效果为

    思路就是这样,还有些其他的修改方法根据自己的需要再操作,内核就是对list的处理而已,基本功。

    总结

    上面都是我在做项目过程中遇到的问题,而且在csdn上没找到详细的解答,于是自己动手解决并分享。

    来源:Starkiron

    物联沃分享整理
    物联沃-IOTWORD物联网 » yolov5只训练数据集中的某几个类别

    发表评论