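A hands-on guide to making your own CIFAR dataset (with project source code)
Train your own classification model, starting from building a CIFAR-format dataset
Contents
Making your own dataset in the CIFAR format
Training a model with your own dataset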
Making your own dataset in the CIFAR format
The code is public on my GitHub (please leave it a star). Below is a detailed walkthrough of how to use it.
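First, generate a list of every image with its class index, plus a table that maps each class folder to its index. The script expects one sub-folder per class under data/:

import os


def getFlist(path):
    # Collect the directories under `path`; each class folder is one class.
    root_dirs = []
    for root, dirs, files in os.walk(path):
        print('root_dir:', root)
        print('sub_dirs:', dirs)
        print('files:', files)
        root_dirs.append(root)
    # The first entry is `path` itself, so keep only the class folders.
    print('root_dirs:', root_dirs[1:])
    root_dirs = root_dirs[1:]
    return root_dirs


def getChildList(root_dirs):
    # The original body of this function is incomplete; this loop is a
    # reconstruction that writes "<image path> <class index>" per line
    # (the exact line format is an assumption).
    j = 0
    f = open('data/cow_jpg.lst', 'w')  # image paths and their class indices
    for root_dir in root_dirs:
        for name in sorted(os.listdir(root_dir)):
            fpath = os.path.join(root_dir, name)
            if os.path.isfile(fpath):
                f.write('%s %s\n' % (fpath, j))
        j = j + 1
    f.close()


if __name__ == '__main__':
    resDir = 'data'
    f2 = open('data/object_list.txt', 'w')  # class folder -> class index table
    root_dirs = getFlist(resDir)
    k = 0
    for root_dir in root_dirs:
        f2.write('%s %s\n' % (root_dir, k))
        k = k + 1
    f2.close()
    getChildList(root_dirs)
    print(root_dirs)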
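The script above assigns class indices in whatever order os.walk happens to list the class folders, and that order is arbitrary (it depends on the file system). The following is a minimal pathlib sketch of the same step. It assumes one sub-folder per class directly under the dataset root and the same "<path> <index>" line format, and it sorts the class names so that the indices are reproducible. The helper name build_lists is hypothetical.

```python
from pathlib import Path


def build_lists(root, list_path, table_path):
    """Write "<image path> <class index>" lines plus a class-index table.

    Hypothetical helper: assumes every sub-folder of `root` is one class
    and every file inside it is an image of that class.
    """
    root = Path(root)
    # Sort class folders by name so the same data always gets the same indices.
    class_dirs = sorted(p for p in root.iterdir() if p.is_dir())
    with open(list_path, "w") as lst, open(table_path, "w") as table:
        for index, class_dir in enumerate(class_dirs):
            table.write(f"{class_dir} {index}\n")
            for image in sorted(class_dir.iterdir()):
                if image.is_file():
                    lst.write(f"{image} {index}\n")
    return [d.name for d in class_dirs]


if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        # Toy dataset: two classes with two (empty) images each.
        for cls in ("dog", "cat"):
            (Path(tmp) / cls).mkdir()
            for i in range(2):
                (Path(tmp) / cls / f"{i}.jpg").write_bytes(b"")
        names = build_lists(tmp, Path(tmp) / "all.lst", Path(tmp) / "classes.txt")
        print(names)  # ['cat', 'dog']
        print((Path(tmp) / "all.lst").read_text())
```

Sorting costs almost nothing and removes a subtle source of mismatched labels when the dataset is rebuilt on another machine.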
Next, shuffle the list and split it into a training list and a test list:

import random

with open('data/cow_jpg.lst') as f:  # image list produced by the previous step
    lines = f.readlines()
print(len(lines))
random.shuffle(lines)
print(lines)
# 0.2 is the split ratio: the first 20% of the shuffled lines become the
# test set and the remaining 80% the training set.
set_num = int(len(lines) * 0.2)
test_list = lines[:set_num]
train_list = lines[set_num:]
print('================')
print(len(test_list))
print(len(train_list))
with open('data/cow_jpg_train.lst', 'w') as f2:
    f2.writelines(train_list)
with open('data/cow_jpg_test.lst', 'w') as f3:
    f3.writelines(test_list)
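Because random.shuffle is unseeded here, every run produces a different train/test split. The following is a minimal sketch of the same 80/20 split as a pure function, for when you want the split to be reproducible. It shuffles a copy of the lines with a fixed-seed random.Random instance. The name split_lines and the default seed are hypothetical.

```python
import random


def split_lines(lines, test_ratio=0.2, seed=0):
    """Return (train, test) after a reproducible shuffle.

    Hypothetical helper: test_ratio plays the role of the 0.2 threshold in
    the script above; seed fixes the shuffle so reruns give the same split.
    """
    shuffled = list(lines)  # leave the caller's list untouched
    random.Random(seed).shuffle(shuffled)
    set_num = int(len(shuffled) * test_ratio)
    return shuffled[set_num:], shuffled[:set_num]


if __name__ == "__main__":
    lines = [f"img_{i}.jpg {i % 2}\n" for i in range(10)]
    train, test = split_lines(lines)
    print(len(train), len(test))  # 8 2
```

Using a private random.Random instance keeps the split independent of any other code that touches the global random state.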
… and train_batch (list file: data/cow_jpg_train.lst)
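This gives you three kinds of files (data_batch_0, …, test_batch and batches.meta) in the same format as the official CIFAR dataset. Any model that trains on CIFAR can be used to test them; the example below uses ViT-pytorch.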
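The batch-conversion code is not shown in this article. For reference, here is a minimal numpy sketch of the on-disk format. It assumes the images have already been resized to 32×32 RGB (for example with Pillow). torchvision's CIFAR10 class unpickles each batch with encoding="latin1", stacks the "data" arrays, reshapes them to (N, 3, 32, 32), and reads the "labels" and "label_names" keys. The helpers write_batch and write_meta are hypothetical.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np


def write_batch(path, images, labels, filenames, batch_label):
    """Pickle one CIFAR-10-style batch (hypothetical helper).

    images: sequence of 32x32x3 uint8 arrays in (height, width, RGB) order.
    Each image becomes one 3072-value row in channel-major order (all red
    values, then green, then blue), the layout torchvision undoes with
    reshape(-1, 3, 32, 32).
    """
    rows = []
    for img in images:
        img = np.asarray(img, dtype=np.uint8)
        if img.shape != (32, 32, 3):
            raise ValueError(f"expected a 32x32x3 image, got {img.shape}")
        rows.append(img.transpose(2, 0, 1).reshape(-1))  # HWC -> CHW -> flat
    batch = {
        "batch_label": batch_label,
        "data": np.stack(rows),  # shape (N, 3072), dtype uint8
        "labels": list(labels),  # one class index per image
        "filenames": list(filenames),
    }
    with open(path, "wb") as f:
        pickle.dump(batch, f)


def write_meta(path, label_names):
    """Pickle a minimal batches.meta holding the class names."""
    with open(path, "wb") as f:
        pickle.dump({"label_names": list(label_names)}, f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        rng = np.random.default_rng(0)
        images = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
        write_batch(Path(tmp) / "test_batch", images, [0, 1, 0, 1],
                    [f"{i}.jpg" for i in range(4)], "testing batch 1 of 1")
        write_meta(Path(tmp) / "batches.meta", ["cow", "horse"])
        with open(Path(tmp) / "test_batch", "rb") as f:
            print(pickle.load(f, encoding="latin1")["data"].shape)  # (4, 3072)
```

The sketch uses str keys ("data", "labels") rather than bytes keys, because torchvision indexes the loaded dict with str keys.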
Training a model with your own dataset
Open data_utils.py, find the code below and set download=False. If your project has no such code, skip this step.
# Excerpt from get_loader() in utils/data_utils.py (ViT-pytorch)
if args.dataset == "cifar10":
    trainset = datasets.CIFAR10(root="./data",
                                train=True,
                                download=False,
                                transform=transform_train)
    testset = datasets.CIFAR10(root="./data",
                               train=False,
                               download=False,
                               transform=transform_test) if args.local_rank in [-1, 0] else None
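With download=False nothing is fetched, so the files must already be in place. torchvision.datasets.CIFAR10 looks in root (./data here) for a cifar-10-batches-py folder containing data_batch_1 to data_batch_5, test_batch and batches.meta. It does not pick up batch files with other names, such as data_batch_0. The following small stdlib sketch lists anything that is missing before you start training. The helper name missing_cifar10_files is hypothetical.

```python
import os

# File names torchvision's CIFAR10 class looks for (train_list, test_list, meta).
CIFAR10_FILES = [f"data_batch_{i}" for i in range(1, 6)] + ["test_batch", "batches.meta"]


def missing_cifar10_files(root="./data"):
    """Return the expected CIFAR-10 files that are absent under root."""
    folder = os.path.join(root, "cifar-10-batches-py")
    return [name for name in CIFAR10_FILES
            if not os.path.isfile(os.path.join(folder, name))]


if __name__ == "__main__":
    missing = missing_cifar10_files("./data")
    print("missing:", missing if missing else "none")
```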
Running train.py now fails with:
Traceback (most recent call last):
  File "/workspace/ViT-pytorch-main/train.py", line 347, in <module>
    main()
  File "/workspace/ViT-pytorch-main/train.py", line 342, in main
    train(args, model)
  File "/workspace/ViT-pytorch-main/train.py", line 158, in train
    train_loader, test_loader = get_loader(args)
  File "/workspace/ViT-pytorch-main/utils/data_utils.py", line 31, in get_loader
    transform=transform_train)
  File "/opt/conda/envs/ViT/lib/python3.6/site-packages/torchvision/datasets/cifar.py", line 93, in __init__
    self._load_meta()
  File "/opt/conda/envs/ViT/lib/python3.6/site-packages/torchvision/datasets/cifar.py", line 99, in _load_meta
    ' You can use download=True to download it'
RuntimeError: Dataset metadata file not found or corrupted. You can use download=True to download it
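Disable the file-integrity checks in torchvision's CIFAR source code. torchvision compares each file's MD5 checksum with that of the official CIFAR release, so self-made batch files fail the check and raise the error above. Open cifar.py in your Python environment (in this environment it is at /opt/conda/envs/ViT/lib/python3.6/site-packages/torchvision/datasets/cifar.py) and comment out the following checks: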
# In _load_meta():
if not check_integrity(path, self.meta['md5']):
    raise RuntimeError('Dataset metadata file not found or corrupted.' +
                       ' You can use download=True to download it')
# In __init__():
if not self._check_integrity():
    raise RuntimeError('Dataset not found or corrupted.' +
                       ' You can use download=True to download it')
# In _check_integrity():
if not check_integrity(fpath, md5):
    return False
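Why the checks fail: torchvision's check_integrity(fpath, md5) returns False when the file is missing and True when md5 is None. Otherwise, it compares the file's MD5 digest with the hard-coded checksum of the official CIFAR file, and a self-made batches.meta cannot match that checksum. The sketch below re-creates this logic with hashlib for illustration only. It is a simplified stand-in, not torchvision's own code.

```python
import hashlib
import os


def file_md5(fpath, chunk_size=1024 * 1024):
    """MD5 hex digest of a file, read in chunks to keep memory use low."""
    md5 = hashlib.md5()
    with open(fpath, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def check_integrity(fpath, md5=None):
    """Simplified re-creation of the check torchvision performs."""
    if not os.path.isfile(fpath):
        return False
    if md5 is None:  # no checksum given: existence is enough
        return True
    return file_md5(fpath) == md5
```

Because of the md5-is-None branch, editing site-packages is not the only option. A subclass of CIFAR10 that sets the checksums in its train_list, test_list and meta class attributes to None would also pass. Those attributes are private implementation details, so check them against your installed torchvision version before relying on this.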
Find num_classes = 10 in your code and replace 10 with the number of classes in your dataset.
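Instead of typing the number by hand, you can read it from the batches.meta file you generated. This minimal sketch assumes that batches.meta stores the class names under the "label_names" key, which is the key torchvision reads. The function name num_classes_from_meta is hypothetical.

```python
import pickle


def num_classes_from_meta(meta_path):
    """Count the class names stored in a CIFAR-style batches.meta file."""
    with open(meta_path, "rb") as f:
        meta = pickle.load(f, encoding="latin1")  # same encoding torchvision uses
    return len(meta["label_names"])
```

For example: num_classes = num_classes_from_meta("./data/cifar-10-batches-py/batches.meta").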
Project source: cifar10Dataset-master
Source: 准备深度学习一下