
Note: environment is Ubuntu 18.04 + Anaconda, with Python 3.8 and PyTorch 1.9.

Using YOLOv5 in ROS

  • Preface
  • I. First, wrap YOLOv5
  • II. Steps
  • 1. Create a new script
  • 2. Modify detect.py
  • Modify yolov5/utils/datasets.py
  • Modify detect.py again
  • Results

    Preface

    Note: the blog post I referred to:
    Wrapping YOLOv5: https://www.pythonheidong.com/blog/article/851830/44a42d351037d307d02d/


    Note: what follows is my approach; the example below can serve as a reference.

    I. First, wrap YOLOv5

    Download ultralytics' yolov5 from GitHub.
    Modify the source of detect.py in ultralytics/yolov5 so that other Python programs can call it through import.
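The wrapping follows a common Python pattern: do the expensive setup (loading the weights) once in a constructor, expose a method that takes images already in memory, and keep any script behaviour behind an `if __name__ == "__main__":` guard so that `import detect` does not start anything by itself. Below is a minimal, dependency-free sketch of that pattern; `DetectorAPI` and `load_model` are hypothetical placeholders, not YOLOv5 code, and the "inference" simply echoes its input.

```python
def load_model(weights):
    # Hypothetical stand-in for the expensive step (attempt_load in YOLOv5).
    return {"weights": weights, "loaded": True}


class DetectorAPI:
    """Load once in __init__, then call detect() as often as needed."""

    def __init__(self, weights):
        self.model = load_model(weights)  # runs only once per instance
        self.calls = 0

    def detect(self, images):
        if not isinstance(images, list):
            raise TypeError("images must be a list")
        self.calls += 1
        # Placeholder "inference": one (image, labels) pair per input image.
        return [(img, []) for img in images]


if __name__ == "__main__":
    # Runs only when the file is executed directly, not on import.
    api = DetectorAPI("weights/yolov5s.pt")
    print(api.detect(["frame"]))
```

Because the model is created in the constructor, a camera loop pays the loading cost once and then only calls `detect()` per frame.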

    II. Steps

    1. Create a new script

    The code is as follows (example):

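    import cv2
    import detect  # the modified detect.py; run this script from the yolov5 root directory

    cap = cv2.VideoCapture(0)  # default camera
    a = detect.detectapi(weights='weights/yolov5s.pt')
    while True:
        rec, img = cap.read()
        if not rec:  # no frame could be read from the camera
            break

        result, names = a.detect([img])
        img = result[0][0]  # annotated image for the first (and only) input picture
        # Optional: draw the labels yourself from the returned label list.
        # for cls, (x1, y1, x2, y2), conf in result[0][1]:  # labels of the first picture
        #     print(cls, x1, y1, x2, y2, conf)
        #     cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0))
        #     cv2.putText(img, names[cls], (x1, y1 - 20), cv2.FONT_HERSHEY_DUPLEX, 1.5, (255, 0, 0))
        cv2.imshow("video", img)

        if cv2.waitKey(1) == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()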
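`a.detect()` returns `result`, a list with one `(annotated_image, labels)` pair per input image, where each label is `(class_index, [x1, y1, x2, y2], confidence)` (built in `detectapi.detect` in the next step), together with `names`, which maps a class index to its name. As a small illustration of consuming that structure, here is a sketch of a hypothetical helper `summarize` that counts detections per class name above a confidence threshold; the sample data is made up.

```python
from collections import Counter


def summarize(result, names, min_conf=0.0):
    """Count detections per class name for each image in `result`.

    `result` is assumed to have the shape returned by detectapi.detect():
    a list of (image, labels) pairs, where each label is
    (class_index, [x1, y1, x2, y2], confidence).
    """
    summaries = []
    for _image, labels in result:
        counts = Counter(names[cls] for cls, _box, conf in labels if conf >= min_conf)
        summaries.append(dict(counts))
    return summaries


if __name__ == "__main__":
    # Hand-made data in the same shape as detectapi.detect() output.
    names = ["person", "bicycle", "car"]
    result = [
        (None, [(0, [10, 20, 110, 220], 0.91),
                (2, [300, 40, 420, 160], 0.55),
                (0, [150, 30, 230, 210], 0.30)]),
    ]
    print(summarize(result, names, min_conf=0.5))  # [{'person': 1, 'car': 1}]
```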
    

    2. Modify detect.py

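    Add the code below directly to detect.py; the existing code does not need to change (alternatively, create a new file and copy the original code into it):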

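    # Add a parameter object. The original parameters come from the command-line parser; here the caller
    # supplies them in code instead, so we need an object that provides roughly the same parameters.
    # What I want: pass in a list of images read by cv2 and get back a list of labelled images, plus, for each
    # image, the class index, position and confidence of every label, and the class names.
    # This only needs two parameters, the weights file and the input; the rest keep the original command-line defaults.
    class simulation_opt:  # parameter object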
    
        def __init__(self,weights,img_size=640,conf_thres=0.25,iou_thres=0.45,device='',view_img=False,
                     classes=None,agnostic_nms=False,augment=False,update=False,exist_ok=False):
            self.weights=weights
            self.source=None
            self.img_size=img_size
            self.conf_thres=conf_thres
            self.iou_thres=iou_thres
            self.device=device
            self.view_img=view_img
            self.classes=classes
            self.agnostic_nms=agnostic_nms
            self.augment=augment
            self.update=update
            self.exist_ok=exist_ok
    
    # Add a new class obtained by trimming the original detect function. You can copy the original detect code first and then edit it.
    class detectapi:
        def __init__(self, weights, img_size=640):
            # The constructor does the one-time preparation: initialize the parameters and load the model.
            ''' Delete:
            source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
            webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
                ('rtsp://', 'rtmp://', 'http://'))
            ''' # Replace with:
            self.opt = simulation_opt(weights=weights, img_size=img_size)
            weights, imgsz = self.opt.weights, self.opt.img_size
    
            ''' Delete:
            # Directories
            #save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
            #(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
            '''
            # Initialize
            set_logging()
            self.device = select_device(self.opt.device)
            self.half = self.device.type != 'cpu'  # half precision only supported on CUDA

            # Load model
            self.model = attempt_load(weights, map_location=self.device)  # load FP32 model
            self.stride = int(self.model.stride.max())  # model stride
            self.imgsz = check_img_size(imgsz, s=self.stride)  # check img_size
            if self.half:
                self.model.half()  # to FP16

            # Second-stage classifier
            self.classify = False
            if self.classify:
                self.modelc = load_classifier(name='resnet101', n=2)  # initialize
                # load_state_dict() does not return the model, so move it to the device in a separate call
                self.modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=self.device)['model'])
                self.modelc.to(self.device).eval()
            '''
            self.names and self.colors were moved here from later in the original code. names holds the class names; colors are the box colors used when drawing.
            '''
            # read names and colors
            self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
            self.colors = [[random.randint(0, 255) for _ in range(3)] for _ in self.names]
    
    
        @torch.no_grad()  # inference only: no autograd graph is needed
        def detect(self, source):  # call this method to run detection
            if not isinstance(source, list):
                raise TypeError('source must be a list containing pictures read by cv2')
    
    
            ''' Delete:
            if webcam:
                view_img = check_imshow()
                cudnn.benchmark = True  # set True to speed up constant image size inference
                dataset = LoadStreams(source, img_size=imgsz, stride=stride)
            else:
                save_img = True
                dataset = LoadImages(source, img_size=imgsz, stride=stride)
            ''' # Replace with:
            # Set Dataloader
            dataset = MyLoadImages(source, img_size=self.imgsz, stride=self.stride)
            # The original code loads the dataset from a path; here source already holds the loaded images,
            # so the dataset class has to be rewritten (the modified code is given later; it goes into utils/datasets.py).

            ''' Moved to the end of the constructor. names holds the class names; colors are the box colors used when drawing.
            names = model.module.names if hasattr(model, 'module') else model.names
            colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
            '''
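            # Run inference
            if self.device.type != 'cpu':
                # Warm-up pass. As written it runs on every detect() call; moving it into __init__ would run it only once.
                self.model(torch.zeros(1, 3, self.imgsz, self.imgsz).to(self.device).type_as(next(self.model.parameters())))
            result = []
            ''' Delete (path is not needed because nothing is saved; vid_cap is not needed because no video is processed):
            for path, img, im0s, vid_cap in dataset:
            ''' # Replace with: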
            for img, im0s in dataset:
                img = torch.from_numpy(img).to(self.device)
                img = img.half() if self.half else img.float()  # uint8 to fp16/32
                img /= 255.0  # 0 - 255 to 0.0 - 1.0
                if img.ndimension() == 3:
                    img = img.unsqueeze(0)
    
                # Inference
                # t1 = time_synchronized()  # only measures inference time, can be omitted
                pred = self.model(img, augment=self.opt.augment)[0]

                # Apply NMS
                pred = non_max_suppression(pred, self.opt.conf_thres, self.opt.iou_thres, classes=self.opt.classes, agnostic=self.opt.agnostic_nms)
                # t2 = time_synchronized()  # only measures inference time, can be omitted
    
                # Apply Classifier
                if self.classify:
                    pred = apply_classifier(pred, self.modelc, img, im0s)
    
    
                ''' Delete:
                for i, det in enumerate(pred):  # detections per image
                if webcam:  # batch_size >= 1
                    p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
                else:
                    p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
    
                p = Path(p)  # to Path
                save_path = str(save_dir / p.name)  # img.jpg
                txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
                s += '%gx%g ' % img.shape[2:]  # print string
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
                if len(det):
                    # Rescale boxes from img_size to im0 size
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
    
                    # Print results
                    for c in det[:, -1].unique():
                        n = (det[:, -1] == c).sum()  # detections per class
                        s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
    
                    # Write results
                    for *xyxy, conf, cls in reversed(det):
                        if save_txt:  # Write to file
                            xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                            line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh)  # label format
                            with open(txt_path + '.txt', 'a') as f:
                                f.write(('%g ' * len(line)).rstrip() % line + '\n')
    
                        if save_img or view_img:  # Add bbox to image
                            label = f'{names[int(cls)]} {conf:.2f}'
                            plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
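                ''' # Replace with:
                # Process detections
                det = pred[0]  # The original code saves images, hence all the save-path handling. pred is a list
                # with one element per image in the batch. This API processes one image at a time,
                # so pred has exactly one element and can be taken directly, without a for loop.
                im0 = im0s.copy()  # im0s is the caller's image (same object); work on a copy so the original is not modified
                # s += '%gx%g ' % img.shape[2:]  # print string
                # gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
                result_txt = []
                # One image may contain several detected objects, so there may be several labels.
                # result_txt gets one element per detected object; each element records the class index,
                # the box position on the image, and the confidence.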
                if len(det):
                    # Rescale boxes from img_size to im0 size
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
    
                    # Print results
                    '''
                    for c in det[:, -1].unique():
                        n = (det[:, -1] == c).sum()  # detections per class
                        s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
                    '''
                    # Write results
    
                    for *xyxy, conf, cls in reversed(det):
    
                        #xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (int(cls.item()), [int(_.item()) for _ in xyxy], conf.item())  # label format
                        result_txt.append(line)
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        plot_one_box(xyxy, im0, label=label, color=self.colors[int(cls)], line_thickness=3)
                result.append((im0, result_txt))  # for each image, return the annotated image and its list of labels
            return result, self.names
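simulation_opt plays the role of the argparse namespace that detect.py normally builds from the command line. If you prefer, the same options object can be written as a standard-library dataclass, which keeps the defaults in one place and lets you derive variants with `dataclasses.replace`. This is an alternative sketch, not something the code above requires; `DetectOptions` is a hypothetical name, and its fields and defaults are copied from simulation_opt.

```python
from dataclasses import dataclass, replace
from typing import List, Optional


@dataclass
class DetectOptions:
    # Same fields and defaults as simulation_opt; the attribute names match
    # the options that detect() reads (self.opt.conf_thres, ...).
    weights: str
    source: Optional[str] = None
    img_size: int = 640
    conf_thres: float = 0.25
    iou_thres: float = 0.45
    device: str = ''
    view_img: bool = False
    classes: Optional[List[int]] = None
    agnostic_nms: bool = False
    augment: bool = False
    update: bool = False
    exist_ok: bool = False


if __name__ == "__main__":
    opt = DetectOptions(weights="weights/yolov5s.pt")
    strict = replace(opt, conf_thres=0.5)  # copy with one field changed
    print(opt.conf_thres, strict.conf_thres)  # 0.25 0.5
```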
    

    At this point, some of the modules used in detect.py have not been imported yet.
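One of those helpers, non_max_suppression from utils/general.py, is where conf_thres, iou_thres and agnostic_nms from simulation_opt end up. The idea behind it can be sketched in plain Python as greedy per-class non-maximum suppression. This is a simplified illustration, not YOLOv5's implementation: the real function works on batched tensors, scores boxes by objectness times class probability, calls torchvision.ops.nms, supports the classes filter and caps the number of detections per image. `iou` and `greedy_nms` are hypothetical names.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(dets, conf_thres=0.25, iou_thres=0.45, agnostic=False):
    """Greedy NMS over detections given as (box, conf, cls) tuples.

    Detections with conf <= conf_thres are dropped. The rest are visited from
    highest to lowest confidence; a detection is kept unless it overlaps an
    already-kept detection of the same class (any class if agnostic=True)
    with IoU > iou_thres.
    """
    candidates = sorted((d for d in dets if d[1] > conf_thres),
                        key=lambda d: d[1], reverse=True)
    kept = []
    for box, conf, cls in candidates:
        rivals = [k for k in kept if agnostic or k[2] == cls]
        if all(iou(box, k[0]) <= iou_thres for k in rivals):
            kept.append((box, conf, cls))
    return kept


if __name__ == "__main__":
    dets = [((0, 0, 10, 10), 0.9, 0),
            ((1, 1, 11, 11), 0.8, 0),    # overlaps the first box (IoU about 0.68)
            ((0, 0, 10, 10), 0.7, 1),    # same place, different class
            ((50, 50, 60, 60), 0.2, 0)]  # below conf_thres
    print(greedy_nms(dets))  # [((0, 0, 10, 10), 0.9, 0), ((0, 0, 10, 10), 0.7, 1)]
    print(greedy_nms(dets, agnostic=True))  # [((0, 0, 10, 10), 0.9, 0)]
```

Raising conf_thres removes weak boxes before suppression; raising iou_thres lets more overlapping boxes of the same class survive.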

    Modify yolov5/utils/datasets.py

    Copy the following code into it directly:

    class MyLoadImages:  # for inference
        def __init__(self, path, img_size=640, stride=32):
            for img in path:  # path is the list of cv2 images passed to detect()
                if not isinstance(img, np.ndarray) or img.ndim != 3:
                    raise TypeError('there is an object in source that is not a picture read by cv2')
    
            '''
            p = str(Path(path).absolute())  # os-agnostic absolute path
            if '*' in p:
                files = sorted(glob.glob(p, recursive=True))  # glob
            elif os.path.isdir(p):
                files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
            elif os.path.isfile(p):
                files = [p]  # files
            else:
                raise Exception(f'ERROR: {p} does not exist')
    
            images = [x for x in files if x.split('.')[-1].lower() in img_formats]
            videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
            ni, nv = len(images), len(videos)
            '''
            self.img_size = img_size
            self.stride = stride
            self.files = path
            self.nf = len(path)
            #self.video_flag = [False] * ni + [True] * nv
            self.mode = 'image'
            #if any(videos):
                #self.new_video(videos[0])  # new video
            #else:
                #self.cap = None
            #assert self.nf > 0, f'No images or videos found in {p}. ' \
                                #f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'
    
        def __iter__(self):
            self.count = 0
            return self
    
        def __next__(self):
            if self.count == self.nf:
                raise StopIteration
            path = self.files[self.count]  # despite the name, this is the cv2 (BGR) image itself, not a file path
            '''
            if self.video_flag[self.count]:
                # Read video
                self.mode = 'video'
                ret_val, img0 = self.cap.read()
                if not ret_val:
                    self.count += 1
                    self.cap.release()
                    if self.count == self.nf:  # last video
                        raise StopIteration
                    else:
                        path = self.files[self.count]
                        self.new_video(path)
                        ret_val, img0 = self.cap.read()
    
                self.frame += 1
                print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='')
            '''
            # Read image
            self.count += 1
            #img0 = cv2.imread(path)  # BGR
            #assert img0 is not None, 'Image Not Found ' + path
            #print(f'image {self.count}/{self.nf} {path}: ', end='')
    
            # Padded resize
            img = letterbox(path, self.img_size, stride=self.stride)[0]
    
            # Convert
            img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
            img = np.ascontiguousarray(img)
    
            return img, path  # (preprocessed CHW RGB image, original image)
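MyLoadImages passes every frame through YOLOv5's letterbox, and detectapi.detect later uses scale_coords to map the predicted boxes back onto the original frame. The arithmetic behind the two can be sketched in plain Python. This is my reading of letterbox (with its default auto=True) and scale_coords; the names `letterbox_geometry` and `scale_box` are hypothetical, and the sketch only computes sizes and coordinates instead of resizing pixels.

```python
def letterbox_geometry(h, w, new_size=640, stride=32, auto=True, scaleup=True):
    """Return (ratio, (new_h, new_w), (top, bottom, left, right)).

    The image is scaled to fit inside new_size x new_size with its aspect
    ratio preserved, then padded. With auto=True only the padding needed to
    reach a multiple of stride is kept (a minimum rectangle).
    """
    r = min(new_size / h, new_size / w)
    if not scaleup:  # only shrink, never enlarge
        r = min(r, 1.0)
    new_w, new_h = int(round(w * r)), int(round(h * r))
    dw, dh = new_size - new_w, new_size - new_h  # padding for a square output
    if auto:
        dw, dh = dw % stride, dh % stride
    dw, dh = dw / 2, dh / 2  # split the padding between both sides
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    return r, (new_h, new_w), (top, bottom, left, right)


def _clip(v, hi):
    return min(max(v, 0.0), float(hi))


def scale_box(img1_shape, box, img0_shape):
    """Map box (x1, y1, x2, y2) from the letterboxed shape img1_shape=(h, w)
    back to the original shape img0_shape=(h, w), clipped to that image."""
    gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
    pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
    pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2
    x1, y1, x2, y2 = box
    h0, w0 = img0_shape
    return (_clip((x1 - pad_x) / gain, w0), _clip((y1 - pad_y) / gain, h0),
            _clip((x2 - pad_x) / gain, w0), _clip((y2 - pad_y) / gain, h0))


if __name__ == "__main__":
    r, (nh, nw), (top, bottom, left, right) = letterbox_geometry(720, 1280)
    print(r, nh + top + bottom, nw + left + right)  # 0.5 384 640
    print(scale_box((384, 640), (100, 12, 300, 212), (720, 1280)))  # (200.0, 0.0, 600.0, 400.0)
```

For a 1280x720 camera frame and img_size=640, the network input is 640x384 rather than 640x640, so each box has to be shifted by the 12-pixel top padding and divided by the 0.5 scale before it is drawn on the original frame.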
    

    Modify detect.py again

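    Back in detect.py, some modules still need to be imported.
    set_logging and apply_classifier (as well as check_img_size, non_max_suppression and scale_coords) are in utils/general.py, and plot_one_box is in utils/plots.py:

    from utils.general import *
    from utils.plots import plot_one_box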
    

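    attempt_load is in models/experimental.py, and MyLoadImages (added above) is in utils/datasets.py:

    from models.experimental import attempt_load
    from utils.datasets import MyLoadImages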
    

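    load_classifier did not seem to exist, so I added the following code to utils/torch_utils.py and then imported it together with select_device (from utils.torch_utils import select_device, load_classifier):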

    def load_classifier(name='resnet101', n=2):
        # Loads a pretrained model reshaped to n-class output
        model = torchvision.models.__dict__[name](pretrained=True)
    
        # ResNet model properties
        # input_size = [3, 224, 224]
        # input_space = 'RGB'
        # input_range = [0, 1]
        # mean = [0.485, 0.456, 0.406]
        # std = [0.229, 0.224, 0.225]
    
        # Reshape output to n classes
        filters = model.fc.weight.shape[1]
        model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
        model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
        model.fc.out_features = n
        return model
    

    Results

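    Run the first script directly in your yolov5 virtual environment to see the detection results.
    This is my first step toward connecting YOLOv5 to ROS; the next step is to replace the input data with messages from ROS.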


    Source: 山上的稻草人

    Shared and compiled by 物联沃 (IOTWORD)