手把手教你用tensorflow2.3训练自己的分类数据集

配合视频一起食用这篇教程效果更佳:手把手教你用tensorflow2训练自己的数据集

tensorflow2.x版本对小白非常友好,2.x的api中对keras进行了合并,大家只需要安装tensorflow就可以使用里面封装好的keras,利用keras可以快速地加载数据集和构建模型,下面我们直接来看以下通过tensorflow2.3训练自己的分类数据集吧。

注:本文主要针对图片形式的数据集构建分类模型,文本数据、目标检测等任务暂不涉及。

本文使用到的代码已在码云上开源,请大家自行下载,star不迷路:

vegetables_tf2.3: 基于tensorflow2.3开发的水果蔬菜识别系统 (gitee.com)

另外我这边整理了一些物体分类的数据集,大家根据需要下载:

计算机视觉数据集清单-附赠tensorflow模型训练和使用教程_dejavu的博客-CSDN博客

数据集收集

数据集收集主要有3种方式,一种是使用某些机构或者组织开源出来的数据集,另一种是自己通过拍照或者爬虫的方式来自行获取数据集,还有一种是热心网友自己采集整理之后的数据集,下面的csdn链接中我给出了一些我整理的数据集,大家可以根据自己的需要下载使用。

计算机视觉数据集清单-附赠tensorflow模型训练和使用教程_dejavu的博客-CSDN博客

开源数据集

开源的分类数据集一般质量相对较好,数据集的所有者在发布前对数据集做了整理和清洗,直接使用开源的数据集可以帮助我们节省大量的时间,比较有名的有mnist数据集、cifar数据集等,另外大家可以在一些网站中寻找数据集,比如下列的几个网站:

和鲸社区 – Heywhale.com

UCI Machine Learning Repository

CSDN – 专业开发者社区

另外你也可以直接在搜索引擎中输入关键字来寻找数据集,比如你想要寻找垃圾分类的数据集,你可以在搜索栏中输入垃圾 分类 数据集等关键字来直接查找,一般会有热心的网友给出数据集的链接,下载即可。

image-20210616112527938

自行采集数据集

如果找不到相应的开源数据集,你也可以通过自己采集的方式来获取数据集,比如你可以通过拍照的方式来搜集你自己所需的数据集,或者是通过爬虫的方式来搜集数据集,这里有段爬虫爬取百度图片的代码,大家直接执行,输入自己想要爬取的图片名称和图片数量,即可爬取相应的图片,代码如下:

import requests
import re
import os

headers = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.125 Safari/537.36'}
name = input('请输入要爬取的图片类别:')
num = 0
num_1 = 0
num_2 = 0
x = input('请输入要爬取的图片数量?(1等于60张图片,2等于120张图片):')
list_1 = []
for i in range(int(x)):
    name_1 = os.getcwd()
    name_2 = os.path.join(name_1, 'data/' + name)
    url = 'https://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&word=' + name + '&pn=' + str(i * 30)
    res = requests.get(url, headers=headers)
    htlm_1 = res.content.decode()
    a = re.findall('"objURL":"(.*?)",', htlm_1)
    if not os.path.exists(name_2):
        os.makedirs(name_2)
    for b in a:
        try:
            b_1 = re.findall('https:(.*?)&', b)
            b_2 = ''.join(b_1)
            if b_2 not in list_1:
                num = num + 1
                img = requests.get(b)
                f = open(os.path.join(name_1, 'data/' + name, name + str(num) + '.jpg'), 'ab')
                print('---------正在下载第' + str(num) + '张图片----------')
                f.write(img.content)
                f.close()
                list_1.append(b_2)
            elif b_2 in list_1:
                num_1 = num_1 + 1
                continue
        except Exception as e:
            print('---------第' + str(num) + '张图片无法下载----------')
            num_2 = num_2 + 1
            continue

print('下载完成,总共下载{}张,成功下载:{}张,重复下载:{}张,下载失败:{}张'.format(num + num_1 + num_2, num, num_1, num_2))

数据集整理

放置到相应的子文件夹

数据集收集完成之后,我们还需要对数据集进行整理,如果是爬虫爬取的图片可能会有一些质量比较差的图片,那么整理之前还需要进行数据的清洗,删除质量不好的图片,数据集整理其实很简单,我们只需要将数据集进行归类即可,即相同类别的图片放在一个文件夹下,比如下面的这个数据集,白菜的文件夹下放的全是白菜的图片,土豆的文件夹下则放的全是土豆的图片。

image-20210616122055853

image-20210616122124060

划分训练集和测试集

注:如果是使用的开源数据集,开源数据集可能已经进行了数据集的划分,直接使用即可,不需要再次进行划分,比如这里是我下载到的农作物病虫害的数据集,已经分别提供了训练集、测试集和验证集,就不需要再次进行数据集的划分。

image-20210616122542476

为了方便我们进行数据集的加载,我们还需要将图片划分为训练集和测试集,如果需要的话你还需要划分出验证集,验证集在一般的任务中是可选的,因为是自己收集的数据集的话,数据量比较少,如果再划分验证集的话可能会导致训练量不够,这里我写了一段数据集划分的代码逻辑,大家输入原始的数据集位置和划分之后的数据集位置,指定数据集划分的比例,即可完成数据集的划分。

# 作者: 宋老狗
import os
import random
from shutil import copy2


def data_set_split(src_data_folder, target_data_folder, train_scale=0.8, val_scale=0.0, test_scale=0.2):
    '''
    读取源数据文件夹,生成划分好的文件夹,分为trian、val、test三个文件夹进行
    :param src_data_folder: 源文件夹 E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/src_data
    :param target_data_folder: 目标文件夹 E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/target_data
    :param train_scale: 训练集比例
    :param val_scale: 验证集比例
    :param test_scale: 测试集比例
    :return:
    '''
    print("开始数据集划分")
    class_names = os.listdir(src_data_folder)
    # 在目标目录下创建文件夹
    split_names = ['train', 'val', 'test']
    for split_name in split_names:
        split_path = os.path.join(target_data_folder, split_name)
        if os.path.isdir(split_path):
            pass
        else:
            os.mkdir(split_path)
        # 然后在split_path的目录下创建类别文件夹
        for class_name in class_names:
            class_split_path = os.path.join(split_path, class_name)
            if os.path.isdir(class_split_path):
                pass
            else:
                os.mkdir(class_split_path)

    # 按照比例划分数据集,并进行数据图片的复制
    # 首先进行分类遍历
    for class_name in class_names:
        current_class_data_path = os.path.join(src_data_folder, class_name)
        current_all_data = os.listdir(current_class_data_path)
        current_data_length = len(current_all_data)
        current_data_index_list = list(range(current_data_length))
        random.shuffle(current_data_index_list)

        train_folder = os.path.join(os.path.join(target_data_folder, 'train'), class_name)
        val_folder = os.path.join(os.path.join(target_data_folder, 'val'), class_name)
        test_folder = os.path.join(os.path.join(target_data_folder, 'test'), class_name)
        train_stop_flag = current_data_length * train_scale
        val_stop_flag = current_data_length * (train_scale + val_scale)
        current_idx = 0
        train_num = 0
        val_num = 0
        test_num = 0
        for i in current_data_index_list:
            src_img_path = os.path.join(current_class_data_path, current_all_data[i])
            if current_idx <= train_stop_flag:
                copy2(src_img_path, train_folder)
                # print("{}复制到了{}".format(src_img_path, train_folder))
                train_num = train_num + 1
            elif (current_idx > train_stop_flag) and (current_idx <= val_stop_flag):
                copy2(src_img_path, val_folder)
                # print("{}复制到了{}".format(src_img_path, val_folder))
                val_num = val_num + 1
            else:
                copy2(src_img_path, test_folder)
                # print("{}复制到了{}".format(src_img_path, test_folder))
                test_num = test_num + 1

            current_idx = current_idx + 1

        print("*********************************{}*************************************".format(class_name))
        print(
            "{}类按照{}:{}:{}的比例划分完成,一共{}张图片".format(class_name, train_scale, val_scale, test_scale, current_data_length))
        print("训练集{}:{}张".format(train_folder, train_num))
        print("验证集{}:{}张".format(val_folder, val_num))
        print("测试集{}:{}张".format(test_folder, test_num))


if __name__ == '__main__':
    src_data_folder = "C:/Users/Scm97/Desktop/dejahu/data"  # todo 原始数据集目录
    target_data_folder = "C:/Users/Scm97/Desktop/dejahu/split_data"  # todo 数据集分割之后存放的目录
    data_set_split(src_data_folder, target_data_folder)

注:路径中最好不要出现中文

数据集划分之后,记住训练集和测试集的位置,接下来,我们就可以开始训练我们的模型了。

下面以花卉识别,我给大家演示一下,data是演示目录,目录下存放的是5个子文文件夹,对应5种花卉,每个子文件夹下存放了相应的花卉图片,split_data是新建的空文件夹,用于存放分割之后的数据集,这时候只需要修改代码种的两处即可。

image-20210616125843872

image-20210616130151242

代码默认训练集占80%,测试集占20%,修改完成之后右键直接执行即可。

image-20210616130259610

执行之后你就可以得到划分好的数据集

image-20210616130407052

这个时候记住训练集和测试集的目录,开始大干一场吧。

测试集目录为:C:/Users/Scm97/Desktop/dejahu/split_data/train

训练集目录为:C:/Users/Scm97/Desktop/dejahu/split_data/test

环境搭建

本次教程需要大家实现配置好python的环境,我们需要使用到anaconda和pycharm,不熟悉环境配置的同学可以看我得这篇博客,我在这里就不再进行赘述了。

如何在pycharm中配置anaconda的虚拟环境_dejavu的博客-CSDN博客

训练模型

模型训练的代码种,以cnn模型的训练为例,train_cnn.py是训练cnn模型的代码,只需要修改三处即可,如下所示

image-20210616131845980

train_mobilnet.py是训练mobilenet模型的代码,训练的模型将会保存在models目录下,这里也是只需修改三处即可。

image-20210616131957245

注:代码最后一行的epochs指的是跑的训练的轮数,这里默认是30,大家可以根据自己的需要增加或减少训练的轮数

修改之后直接运行即可,等代码跑完后模型就会保存在models目录下

image-20210616140111560

另外,在results目录下你可以找到模型训练的过程图

image-20210616140207484

模型训练的过程中会输出数据集的类名,这里记录一下,在后面的模型使用中会用到。

image-20210616134038441

测试模型

模型的测试的代码为test_model.py,也是只需要改动几处代码即可完成测试

改动如下:

image-20210616140427881

image-20210616140521754

测试的基本流程是:加载数据、加载模型、测试、保存结果

测试之后在命令行中会输出每个模型的准确率,并且会在results目录下生成相应的热力图

image-20210616140754011
image-20210616140824493

热力图中对应了每个类别的准确率,如下所示,是mobilenet测试的热力图。

heatmap_mobilenet

使用模型

模型的时候中,我们通过Pyqt5来构建图形化界面,用户可以上传图片,并在系统中调用我们训练好的模型进行图片类别的预测。

window.py代码中修改四处即可完成基本功能

image-20210616141502873

启动看看吧!

image-20210616141525770

快去试试你自己的数据集吧!

来源:肆十二

物联沃分享整理
物联沃-IOTWORD物联网 » 手把手教你用tensorflow2.3训练自己的分类数据集

发表评论