yolov5 python API(供其他程序调用)

你的yolov5🚀是否只局限于detect.py?如果其他程序要调用yolov5,就需要制作一个detect.py的python API。Python中一切皆对象,制作detect API实际上就是制作detect类。

目录

  • 前言
  • 一、总体思路
  • 二、制作detect类
  • 三、调用detect类
  • 结语

  • 前言

    yolov5源码版本:截止2022.2.3
    链接:https://github.com/ultralytics/yolov5
    

    作为一个“CV”主义者,在此之前在各平台都没有找到合适的API代码。其中有一篇不错的文章https://www.pythonheidong.com/blog/article/851830/44a42d351037d307d02d/
    可惜代码版本过于“久远”,部分函数已经不适用了。本文以一种简单粗暴的方式制作与detect.py功能一样的API,即使源码更新,按照我的方法也能快速制作一个API供其他程序调用。

    一、总体思路

    其他程序调用yolo,实际上就是把图像传给detect.py。为了最大化实现detect.py的所有功能,最直接的方式是把摄像头或视频流的帧图像存储到‘data/myimages’目录中,再从‘runs/detect/myexp’中读取检测后的结果图像。这种方法增加了处理时间,不过实测存储和读取图像这部分的延迟很低,即便是在树莓派上。

    二、制作detect类

    在detect.py中添加以下代码

    class DetectAPI:
        """Reusable, in-process wrapper around YOLOv5's detect.py inference loop.

        Create one instance, then call run() whenever the images under ``source``
        have changed (e.g. once per camera frame). The model is loaded and warmed
        up on the first run() call and cached on the instance, so later calls skip
        that cost. Constructor arguments mirror detect.py's command-line options.
        """

        def __init__(self, weights='weights/yolov5s.pt', data='data/coco128.yaml', imgsz=None, conf_thres=0.25,
                     iou_thres=0.45, max_det=1000, device='0', view_img=False, save_txt=False,
                     save_conf=False, save_crop=False, nosave=False, classes=None, agnostic_nms=False, augment=False,
                     visualize=False, update=False, project='runs/detect', name='myexp', exist_ok=False, line_thickness=3,
                     hide_labels=False, hide_conf=False, half=False, dnn=False, source='data/myimages'):
            """Store detection options.

            Args:
                imgsz: inference size as ``[h, w]``, a single int (square), a
                    one-element sequence (square), or None for 640x640.
                source: image file / directory / URL / stream to run on. Defaults
                    to the ``data/myimages`` folder a caller writes frames into.
                All other arguments have the same meaning as the matching
                detect.py options.
            """
            # Normalize imgsz to a two-element [h, w] list (copying any caller list).
            if imgsz is None:
                imgsz = [640, 640]
            elif isinstance(imgsz, int):
                imgsz = [imgsz, imgsz]
            else:
                imgsz = list(imgsz)
                if len(imgsz) == 1:
                    imgsz *= 2
            self.imgsz = imgsz
            self.weights = weights
            self.data = data
            self.source = source
            self.conf_thres = conf_thres
            self.iou_thres = iou_thres
            self.max_det = max_det
            self.device = device
            self.view_img = view_img
            self.save_txt = save_txt
            self.save_conf = save_conf
            self.save_crop = save_crop
            self.nosave = nosave
            self.classes = classes
            self.agnostic_nms = agnostic_nms
            self.augment = augment
            self.visualize = visualize
            self.update = update
            self.project = project
            self.name = name
            self.exist_ok = exist_ok
            self.line_thickness = line_thickness
            self.hide_labels = hide_labels
            self.hide_conf = hide_conf
            self.half = half
            self.dnn = dnn
            # Filled lazily by _load_model() on the first run() call.
            self._device = None
            self._model = None
            self._imgsz = None

        def _load_model(self):
            """Load, configure and warm up the model once; return (device, model, imgsz)."""
            if self._model is None:
                device = select_device(self.device)
                model = DetectMultiBackend(self.weights, device=device, dnn=self.dnn, data=self.data)
                imgsz = check_img_size(self.imgsz, s=model.stride)  # check image size

                # FP16 supported on limited backends with CUDA
                self.half &= (model.pt or model.jit or model.onnx or model.engine) and device.type != 'cpu'
                if model.pt or model.jit:
                    if self.half:
                        model.model.half()
                    else:
                        model.model.float()

                model.warmup(imgsz=(1, 3, *imgsz), half=self.half)  # warmup
                self._device, self._model, self._imgsz = device, model, imgsz
            return self._device, self._model, self._imgsz

        def run(self):
            """Run detection on ``self.source`` and return the detected labels.

            Follows detect.py's run(): annotated images/videos (and optionally
            label .txt files and crops) are written under ``project/name``.

            Returns:
                list[str]: one entry per detected box, over every image processed
                in this call -- ``'<class> <conf>'``, ``'<class>'`` when hide_conf
                is set, or ``'None'`` when hide_labels is set. Empty when nothing
                was detected.
            """
            source = str(self.source)
            save_img = not self.nosave and not source.endswith('.txt')  # save inference images
            is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
            is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
            webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
            if is_url and is_file:
                source = check_file(source)  # download

            # Directories
            save_dir = increment_path(Path(self.project) / self.name, exist_ok=self.exist_ok)  # increment run
            (save_dir / 'labels' if self.save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

            # Load model (cached after the first call)
            device, model, imgsz = self._load_model()
            stride, names, pt = model.stride, model.names, model.pt

            # Dataloader
            view_img = self.view_img
            if webcam:
                # Only display if requested AND the environment supports cv2.imshow.
                view_img = view_img and check_imshow()
                cudnn.benchmark = True  # set True to speed up constant image size inference
                dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
                bs = len(dataset)  # batch_size
            else:
                dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
                bs = 1  # batch_size
            vid_path, vid_writer = [None] * bs, [None] * bs

            # Run inference
            labels = []  # accumulated over all images so it is always defined
            dt, seen = [0.0, 0.0, 0.0], 0
            for path, im, im0s, vid_cap, s in dataset:
                t1 = time_sync()
                im = torch.from_numpy(im).to(device)
                im = im.half() if self.half else im.float()  # uint8 to fp16/32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                if len(im.shape) == 3:
                    im = im[None]  # expand for batch dim
                t2 = time_sync()
                dt[0] += t2 - t1

                # Inference
                visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if self.visualize else False
                pred = model(im, augment=self.augment, visualize=visualize)
                t3 = time_sync()
                dt[1] += t3 - t2

                # NMS
                pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, self.classes, self.agnostic_nms,
                                           max_det=self.max_det)
                dt[2] += time_sync() - t3

                # Process predictions
                for i, det in enumerate(pred):  # per image
                    seen += 1
                    if webcam:  # batch_size >= 1
                        p, im0, frame = path[i], im0s[i].copy(), dataset.count
                        s += f'{i}: '
                    else:
                        p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)

                    p = Path(p)  # to Path
                    save_path = str(save_dir / p.name)  # im.jpg
                    txt_path = str(save_dir / 'labels' / p.stem) + (
                        '' if dataset.mode == 'image' else f'_{frame}')  # im.txt
                    s += '%gx%g ' % im.shape[2:]  # print string
                    gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
                    imc = im0.copy() if self.save_crop else im0  # for save_crop
                    annotator = Annotator(im0, line_width=self.line_thickness, example=str(names))
                    if len(det):
                        # Rescale boxes from img_size to im0 size
                        det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()

                        # Print results
                        for c in det[:, -1].unique():
                            n = (det[:, -1] == c).sum()  # detections per class
                            s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                        # Write results
                        for *xyxy, conf, cls in reversed(det):
                            c = int(cls)  # integer class
                            label = None if self.hide_labels else (names[c] if self.hide_conf else f'{names[c]} {conf:.2f}')
                            # Collect labels regardless of save/view settings -- they are the API's return value.
                            labels.append(str(label))

                            if self.save_txt:  # Write to file
                                xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                                line = (cls, *xywh, conf) if self.save_conf else (cls, *xywh)  # label format
                                with open(txt_path + '.txt', 'a') as f:
                                    f.write(('%g ' * len(line)).rstrip() % line + '\n')

                            if save_img or self.save_crop or view_img:  # Add bbox to image
                                annotator.box_label(xyxy, label, color=colors(c, True))
                                if self.save_crop:
                                    save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)

                    # Print time (inference-only)
                    LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)')

                    # Stream results
                    im0 = annotator.result()
                    if view_img:
                        cv2.imshow(str(p), im0)
                        cv2.waitKey(1)  # 1 millisecond

                    # Save results (image with detections)
                    if save_img:
                        if dataset.mode == 'image':
                            cv2.imwrite(save_path, im0)
                        else:  # 'video' or 'stream'
                            if vid_path[i] != save_path:  # new video
                                vid_path[i] = save_path
                                if isinstance(vid_writer[i], cv2.VideoWriter):
                                    vid_writer[i].release()  # release previous video writer
                                if vid_cap:  # video
                                    fps = vid_cap.get(cv2.CAP_PROP_FPS)
                                    w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                                    h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                                else:  # stream
                                    fps, w, h = 30, im0.shape[1], im0.shape[0]
                                    save_path += '.mp4'
                                vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                            vid_writer[i].write(im0)

            # Release writers explicitly so video files are finalized before returning.
            for writer in vid_writer:
                if isinstance(writer, cv2.VideoWriter):
                    writer.release()

            # Print results (max() guards against an empty source)
            t = tuple(x / max(seen, 1) * 1E3 for x in dt)  # speeds per image
            LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
            if self.save_txt or save_img:
                s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if self.save_txt \
                    else ''
                LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
            if self.update:
                strip_optimizer(self.weights)  # update model (to fix SourceChangeWarning)

            return labels
            
    

    代码与detect.py中原有的run函数基本相同。run函数的思路是加载模型和图片,进行模型预测和推理。类中将图像来源(source)设为'data/myimages'目录,可自行修改。函数会返回识别到的物体标签以及对应的置信度(形如'person 0.87'),可用于其他处理。

    三、调用detect类

    下面给出使用这个API的一个例程,需要将yolov5源码文件夹放到程序目录中。

    """Example: feed webcam frames to YOLOv5's DetectAPI and show the annotated result."""
    import os
    import sys

    import cv2

    # Path to the yolov5 source folder (the one containing detect.py); adjust as needed.
    # Made absolute so it stays valid after the os.chdir() below.
    YOLO_DIR = os.path.abspath('你的目录/yolov5-master')

    # "yolov5-master" contains a hyphen, so "import yolov5-master.detect" is a
    # SyntaxError. Put the folder itself on sys.path and import detect.py as a
    # top-level module instead.
    sys.path.insert(0, YOLO_DIR)
    # DetectAPI's default paths (weights, data/myimages, runs/detect) are relative,
    # so work from inside the yolov5 folder for them to resolve.
    os.chdir(YOLO_DIR)

    from detect import DetectAPI  # noqa: E402  (requires the sys.path setup above)

    image_dir = os.path.join('data', 'myimages')
    os.makedirs(image_dir, exist_ok=True)
    result_path = os.path.join('runs', 'detect', 'myexp', 'test.jpg')

    video_capture = cv2.VideoCapture(0)
    detect_api = DetectAPI(exist_ok=True)

    try:
        while True:
            ret, frame = video_capture.read()
            if not ret:  # camera unavailable / stream ended
                break
            cv2.imwrite(os.path.join(image_dir, 'test.jpg'), frame)

            label = detect_api.run()
            print(str(label))

            image = cv2.imread(result_path, flags=1)
            if image is not None:  # no result image is written when nosave=True
                cv2.imshow("video", image)

            if cv2.waitKey(1) == 27:  # press ESC to quit
                break
    finally:
        video_capture.release()
        cv2.destroyAllWindows()
    
    

    实例化对象时传入参数exist_ok=True,生成的myexp目录会被重复使用并覆盖,而不会依次生成myexp2、myexp3等新目录,方便用于实时处理。

    结语

    本文假设你已经能够成功运行detect.py,并在此基础上制作API接口。在使用IP摄像头或者视频流时,修改类中图像来源(source)等参数即可。

    物联沃分享整理
    物联沃-IOTWORD物联网 » yolov5 python API(供其他程序调用)

    发表评论