mmrotate旋转目标检测框架的学习与使用

目录

前言

一、环境配置

1. 下载checkpoint文件

2. 运行demo

二、制作自己的数据集

1. 标注数据

2. 标签格式转换

3. 可视化数据集

4. 数据集裁剪

三、 修改配置文件

1. 修改classes

2. 修改训练参数

四、训练并测试

1. 训练

2. 测试

3. 预测

五、总结

参考资料


前言

mmrotate旋转目标检测框架的学习与使用是在AutoDL的服务器上进行的,配置为ubuntu18.04,GPU RTX 2080 Ti * 1,以下是一个学习过程的记录,希望能对大家有所帮助。

一、环境配置

MMRotate 依赖 PyTorchMMCV 和 MMDetection,以下是安装的简要步骤。 更详细的安装指南请参考 安装文档

conda create -n open-mmlab python=3.7 pytorch==1.7.0 cudatoolkit=10.1 torchvision -c pytorch -y
conda activate open-mmlab
pip install openmim
mim install mmcv-full
mim install mmdet
git clone https://github.com/open-mmlab/mmrotate.git
cd mmrotate
pip install -r requirements/build.txt
pip install -v -e .

一切顺利之后,为了验证是否正确安装了 MMRotate,我们需进行以下操作。

1. 下载checkpoint文件

新建一个checkpoint文件夹,将checkpoint文件下载到该文件夹下,checkpoint文件可根据自己的需求进行下载,我下载的是s2anet_r50_fpn_1x_dota_le135-5dfcf396.pth和s2anet_r50_fpn_fp16_1x_dota_le135-5cac515c.pth

2. 运行demo

运行指令形式如下:

python demo/image_demo.py \
    ${IMG_ROOT} \
    ${CONFIG_FILE} \
    ${CHECKPOINT_FILE} \
    ${OUT_FILE}

IMG_ROOT:为待检测图像
CONFIG_FILE:为配置文件
CHECKPOINT_FILE:为训练好的权重。

OUT_FILE:输出图片的保存zaizh

输入以下指令:

 ​python ./demo/image_demo.py ./demo/demo.jpg ./configs/s2anet/s2anet_r50_fpn_fp16_1x_dota_le135.py ./checkpoint/s2anet_r50_fpn_fp16_1x_dota_le135-5cac515c.pth --out-file ./demo/result.jpg

 应该正常运行的,但是出现了错误(如果运行正常,会在demo文件夹下看到结果result.jpg):

有可能是在开始安装mmcv-full的时候,没有指定版本,选择直接安装,如下:

pip install mmcv-full

采用这样默认安装mmcv-full的方式,如果与你环境里的cuda和torch版本不匹配,就容易出现上面报错,卸载掉原来的mmcv。

pip uninstall mmcv-full

打开mmcv,点击下图已经建立好的对应版本,进行安装。

我的是torch1.7.0  cuda10.1,对应的install形式如下:

pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.0/index.html

 需要指定mmcv_version,前面报错有提示: Please install mmcv>=1.4.5, <=1.6.0,故选择1.6.0进行安装

pip install mmcv-full==1.6.0 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.0/index.html

 安装完成后,输入指令再次进行测试:

 python ./demo/image_demo.py ./demo/demo.jpg ./configs/s2anet/s2anet_r50_fpn_fp16_1x_dota_le135.py ./checkpoint/s2anet_r50_fpn_fp16_1x_dota_le135-5cac515c.pth --out-file ./demo/result.jpg

原图与输出结果的对比;

二、制作自己的数据集

1. 标注数据

使用的工具是rolabelimg,rolabelimg安装过程太过繁琐,为了节省精力,已提供rolabelimg.exe,可自行下载

链接:https://pan.baidu.com/s/1j9uV-_2zYC7pb6YTzaeLXQ  提取码:qxmg

为了进行后续操作,新建data文件夹,在data文件夹下新建DOTA文件夹,将标注好的数据放到DOTA文件夹下

2. 标签格式转换

新建voc2dota.py文件,将xml标签转换为DOTA的标签,注意代码中的图片和标签地址,以及图片格式,根据实际情况自行修改

# 将xml格式的标签转化为txt格式,并输出标注好的图片
import os
import xml.etree.ElementTree as ET
import math
import cv2 as cv

def voc_to_dota(xml_path, xml_name):
    txt_name = xml_name[:-4] + '.txt'
    txt_path = xml_path[:-4] + '/txt_label'
    if not os.path.exists(txt_path):
        os.makedirs(txt_path)
    txt_file = os.path.join(txt_path, txt_name)
    file_path = os.path.join(xml_path, file_list[i])
    tree = ET.parse(os.path.join(file_path))
    root = tree.getroot()
    # print(root[6][0].text)
    image_path = '../data/DOTA/images/'
    out_path = '../data/DOTA/outputImg/'
    filename = image_path + xml_name[:-4] + '.Jpeg'
    img = cv.imread(filename)
    with open(txt_file, "w+", encoding='UTF-8') as out_file:
        # out_file.write('imagesource:null' + '\n' + 'gsd:null' + '\n')
        for obj in root.findall('object'):
            name = obj.find('name').text
            difficult = obj.find('difficult').text
            # print(name, difficult)
            robndbox = obj.find('robndbox')
            cx = float(robndbox.find('cx').text)
            cy = float(robndbox.find('cy').text)
            w = float(robndbox.find('w').text)
            h = float(robndbox.find('h').text)
            angle = float(robndbox.find('angle').text)
            # print(cx, cy, w, h, angle)
            p0x, p0y = rotatePoint(cx, cy, cx - w / 2, cy - h / 2, -angle)
            p1x, p1y = rotatePoint(cx, cy, cx + w / 2, cy - h / 2, -angle)
            p2x, p2y = rotatePoint(cx, cy, cx + w / 2, cy + h / 2, -angle)
            p3x, p3y = rotatePoint(cx, cy, cx - w / 2, cy + h / 2, -angle)

            # 找最左上角的点
            dict = {p0y:p0x, p1y:p1x, p2y:p2x, p3y:p3x}
            list = find_topLeftPopint(dict)
            #print((list))
            if list[0] == p0x:
                list_xy = [p0x, p0y, p1x, p1y, p2x, p2y, p3x, p3y]
            elif list[0] == p1x:
                list_xy = [p1x, p1y, p2x, p2y, p3x, p3y, p0x, p0y]
            elif list[0] == p2x:
                list_xy = [p2x, p2y, p3x, p3y, p0x, p0y, p1x, p1y]
            else:
                list_xy = [p3x, p3y, p0x, p0y, p1x, p1y, p2x, p2y]

            # 在原图上画矩形 看是否转换正确
            cv.line(img, (int(list_xy[0]), int(list_xy[1])), (int(list_xy[2]), int(list_xy[3])), color=(255, 0, 0), thickness=3)
            cv.line(img, (int(list_xy[2]), int(list_xy[3])), (int(list_xy[4]), int(list_xy[5])), color=(0, 255, 0), thickness= 3)
            cv.line(img, (int(list_xy[4]), int(list_xy[5])), (int(list_xy[6]), int(list_xy[7])), color=(0, 0, 255), thickness = 2)
            cv.line(img, (int(list_xy[6]), int(list_xy[7])), (int(list_xy[0]), int(list_xy[1])), color=(255, 255, 0), thickness = 2)
            cv.imwrite(out_path + xml_name[:-4] + '.Jpeg', img)
            data = str(list_xy[0]) + " " + str(list_xy[1]) + " " + str(list_xy[2]) + " " + str(list_xy[3]) + " " + \
                   str(list_xy[4]) + " " + str(list_xy[5]) + " " + str(list_xy[6]) + " " + str(list_xy[7]) + " "
            data = data + name + " " + difficult + "\n"
            out_file.write(data)


def find_topLeftPopint(dict):
    dict_keys = sorted(dict.keys())  # y值
    temp = [dict[dict_keys[0]], dict[dict_keys[1]]]
    minx = min(temp)
    if minx == temp[0]:
        miny = dict_keys[0]
    else:
        miny = dict_keys[1]
    return [minx, miny]


# 转换成四点坐标
def rotatePoint(xc, yc, xp, yp, theta):
    xoff = xp - xc
    yoff = yp - yc
    cosTheta = math.cos(theta)
    sinTheta = math.sin(theta)
    pResx = cosTheta * xoff + sinTheta * yoff
    pResy = - sinTheta * xoff + cosTheta * yoff
    # pRes = (xc + pResx, yc + pResy)
    # 保留一位小数点
    return float(format(xc + pResx, '.1f')), float(format(yc + pResy, '.1f'))
    # return xc + pResx, yc + pResy

if __name__ == '__main__':
    root_path = '../data/DOTA/xml'
    file_list = os.listdir(root_path)
    for i in range(0, len(file_list)):
        if ('.xml' in file_list[i]) or ('.XML' in file_list[i]):
            voc_to_dota(root_path, file_list[i])
            print('----------------------------------------{}{}----------------------------------------'
                  .format(file_list[i], ' has Done!'))
        else:
            print(file_list[i] + ' is not xml file')

运行成功后,会生成txt_label文件

3. 可视化数据集

tools/misc/browse_dataset.py帮助用户浏览检测的数据集(包括图像和检测框的标注),或将图像
保存到指定目录,指令形式如下:

python tools/misc/browse_dataset.py ${CONFIG} [-h] [–skip-type ${SKIP_TYPE[SKIP_TYPE.
,→..]}] [–output-dir ${OUTPUT_DIR}] [–not-show] [–show-interval ${SHOW_INTERVAL}]

python tools/misc/browse_dataset.py configs/s2anet/s2anet_r50_fpn_1x_dota_le135.py  --output-dir out

运行成功后,在mmrotate/out文件下可以看到标注好的图片,以检验转换是否成功 

4. 数据集裁剪

 将上述数据集分成 train、test、val 、trainval等几部分(我是手动划分的:train 80%,test 10%,val 10%)以便于训练,其中每个部分的文件夹下都包含有 images(图像) 和 labelTxt(对应的txt标签)

 

将 train、test、val 中的图片进行裁剪,在mmrotate/tools/data/dota/split/ 路径下img_split.py文件(裁剪脚本) 以及 mmrotate/tools/data/dota/split/split_configs/ 路径下的配置文件,其文件内容就是img_split.py的配置信息,我们需要修改其中的参数,让其加载上述的train、test、val中的图像及标签,并进行裁剪,以ss_train.json为例,需要修改的地方有图片地址,标签地址,分割后保存地址,以及保存图片的格式(若按照前面的操作设置文件名,则无需进行修改,只需要注意图片格式),ss_test.json和ss_val.json的修改同理

 修改完上述文件后,运行如下指令,进行数据裁剪:

python tools/data/dota/split/img_split.py --base-json tools/data/dota/split/split_configs/ss_train.json
python tools/data/dota/split/img_split.py --base-json tools/data/dota/split/split_configs/ss_test.json
python tools/data/dota/split/img_split.py --base-json tools/data/dota/split/split_configs/ss_val.json
python tools/data/dota/split/img_split.py --base-json tools/data/dota/split/split_configs/ss_trainval.json

裁剪完后的数据集路径为mmrotate/data/split_ss_dota,结果如下:

 

三、 修改配置文件

数据集准备好之后,接下来,需要修改训练相关的配置信息:

1. 修改classes

以s2anet_r50_fpn_1x_dota_le135.py为例,将其中的num_classes=15改成num_classes=1(根据自己数据集的类别数量进行修改)

 同时,修改 mmrotate/mmrotate/datasets/dota.py 文件中的类别名称,注意只有一个类别时不要去掉逗号,修改图像数据集的后缀

最后,修改训练使用的数据集路径:找到并打开 mmrotate/configs/_base_/datasets/dotav1.py 文件,修改其中的 data_root 路径为自己裁剪的数据集路径

上述配置完成后即可训练,即可进行训练,其他可进行修改的地方如下:

2. 修改训练参数

修改线程数、batch_size和测试集标签路径(mmrotate/configs/base/datasets/dotav1.py)

修改训练epoches(mmrotate/configs/_base_/schedules/schedule_1x.py)

注意:有的模型修改schedule_1x.py无效,因为其在继承了schedule_1x.py后进行了改写,例如,oriented_reppoints_r50_fpn_40e_dota_ms_le135.py里对上述参数已进行了设置,用此模型进行训练,只能在oriented_reppoints_r50_fpn_40e_dota_ms_le135.py中修改相应参数,而修改schedule_1x.py无效

修改模型训练的日志打印和load from加载预训练模型(configs/_base_/default_runtime.py),此处特别容易犯错,更换网络进行训练一定要及时替换相应的预训练模型,log_config封装了多个记录器挂钩,并允许设置间隔。现在MMCV支持WandbLoggerHook, MlflowLoggerHook和TensorboardLoggerHook。详细用法可以在官方文档中找到。

四、训练并测试

1. 训练

 python tools/train.py ${CONFIG_FILE} [optional arguments] 

–work_dir ${YOUR_WORK_DIR}

训练主要有两个参数: config:使用的模型文件 ; – -work-dir:训练得到的模型及配置信息保存的路径,新建如下路径work-dir/run/s2anet/保存训练结果,输入以下指令进行训练

python tools/train.py configs/s2anet/s2anet_r50_fpn_1x_dota_le135.py --work-dir work-dir/run/s2anet/

2. 测试

 python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [optional arguments]

测试主要有三个参数: config:使用的模型文件 ;checkpoint:训练得到的模型权重文件; –show-dir:预测结果存放的路径。新建如下路径work-dir/run/s2anet/保存测试结果,输入以下指令进行测试

python tools/test.py configs/s2anet/s2anet_r50_fpn_1x_dota_le135.py work-dir/run/s2anet/latest.pth --show-dir work-dir/output/s2anet/ --eval mAP

 运行报错,然后查看了一下test数据集,发现标签文件为空

 然后查看mmrotate/tools/data/dota/split/split_configs/ss_test.json文件,会发现没有标签文件路径,这就是造成上述现象的原因(DOTA数据集是在线验证,本地没有标签也可以检测训练效果,自行标注的数据集是离线验证,没有标签无法检测训练效果)

 解决办法是在ss_test.json文件中添加上路径,重新进行原数据的分割,或者修改mmrotate/configs/_base_/datasets/dotav1.py为测试集的路径为验证集或训练集路径

 然后再输入指令,可得到测试结果,更多测试操作见官方文档

python tools/test.py configs/s2anet/s2anet_r50_fpn_1x_dota_le135.py work-dir/run/s2anet/latest.pth --show-dir work-dir/output/s2anet/ --eval mAP

 

3. 预测

测试是为了检验训练效果,得到mAP、FPS等指标,进行模型性能评价,而预测则是将训练好的模型进行部署应用。输入以下指令进行单张图片的预测:

python ./demo/image_demo.py ./demo/00008.Jpeg ./configs/s2anet/s2anet_r50_fpn_1x_dota_le135.py ./work-dir/run/s2anet/latest.pth --out-file ./demo/result.Jpeg

要想实现文件夹下多张图片或者视频的预测,可以参考mmdection的相关资料对demo.py进行修改。

五、总结

上述是我以S2A为例从环境配置到训练自己的数据集全过程,在实际使用中,如果自己的图片尺寸不是特别大,可以跳过裁剪步骤,直接进行训练。

参考资料

Welcome to MMRotate’s documentation! — mmrotate 文档https://mmrotate.readthedocs.io/zh_CN/latest/

复现 S2ANet RTX 2080Tihttp://t.csdn.cn/vwruC

mmrotate旋转目标检测框架从环境配置到训练自己的数据集http://t.csdn.cn/1cGDN

基于MMRotate训练自定义数据集 做旋转目标检测 2022-3-30http://t.csdn.cn/A1vSE

物联沃分享整理
物联沃-IOTWORD物联网 » mmrotate旋转目标检测框架的学习与使用

发表评论