ModelNet10/40数据集的下载及dataset代码分析

1.初识ModelNet

ModelNet10/40是一个3d图像分类的一个数据集,它里面的图像全部都是CAD手工绘制的点云图像。下面讲解一下如何下载,并打开数据集。
1.去它的官网下载你想要的数据集,然后解压。网址为:http://modelnet.cs.princeton.edu/
2.解压完成后,你会发现是.off的文件,无法直接打开。

3.这个时候需要下载一个软件,交meshlab,才能打开.off文件,像正常2d图像一样去显示出来。
下载的网站为:https://www.meshlab.net/#download

根据自己的电脑版本选择下载,然后安装。
4. 找到之前下载的数据集,然后右键打开方式,选meshlab


打开之后就是如下的3d图像,长按左键可以更改视图

2.数据集简介和dataset代码分析

ModelNet是一个很基础的点云图像分类数据集,在pointnet和pointnet++都有使用这个数据集进行分类、下面简要分析一下它的代码,代码连接为(https://github.com/yanx27/Pointnet_Pointnet2_pytorch),并展示一下如何读入3d点云数据集。

2.1 数据集下载


首先先去github上将代码下载下来,然后阅读readme文件,查看应该如何运行代码。
通过readme可以发现,使用离线的处理后的数据可以点击连接下载,然后把下载好的数据集放到指定的文件夹data/modelnet40_normal_resampled 下面。

然后在train_classficiation 下填写指定的参数就可以运行训练程序了。

2.2 离线的数据集


下载完成后,在对于的文件夹下就可以找到我们的数据集,它们都变成了tex文本,每个文件夹下放的都是一个类别。例如,上图打开的就是飞机这个类别的点云数据。每一个txt文件相当于2d图像中的一张图像。

继续打开可以发现它有很多行数据,每行有6个数据。这里的每行数据表示该图像有多少个点组成,例如1万行就是由1万个点组成的,每行数据的前三个是这个点在空间中的xyz坐标,后三个是这个点的颜色信息。至于他们其中为什么有负数,是因为他们都被压缩到-1,1这个区间内了。


这个路径下接着往下可以看到有一些单独的txt文件。
名字决定命运,有名字就可以看出来他们存放的是训练和测试使用的图像的名称,然后shape_name就是该数据集中模型的名称。

2.3 dataset代码分析


找到文件夹中的modelnet dataloader,并划到最下面,将路径更改为你存放数据集的路径,我们首先查看一些modelnet 的dataloader(dataset被封装在里面)会给我什么。

可以发现,当我们向dataloader索要数据时,它会给我们返回两个值,一个是图像数据,第二个是它的标签。其中12表示minibatch为12,1024表示将点云数据降采样到1024个点,6就是它的xyz和rgb特征信息。每一张图像都会有一个标签表示,用来区分它们的类别。

当我们向dataloader索要数据集,dataloader会找到dataset,通过它来获取我们的数据,并将它打包至我们设定的批次的大小。
dataset主要的功能就是根据路径去读取图片,并对它进行一些预处理。self.__getitem__和self.__len__是它最重要的两个方法,第一个是告诉它从哪里去读取图片,第二个是告诉它数据集有多少张图片。有了这两个参数,dataset就会根据我们的需求去读取图片。
self.__getitem__方法默认有一个index参数,用来标记读的是第几张图片,这个参数是默认的。从上图可以看到,我们将图片的路径都存在了self.datapath下,然后根据index去读取对应索引处的图片。
ModelNet dataset中init部分就是告诉dataset从那读图片,其余重要的代码我都进行了注释,还有不懂的地方可以根据单步调试查看更多的细节。

'''
@author: Xu Yan
@file: ModelNet.py
@time: 2021/3/19 15:51
'''
import os
import numpy as np
import warnings
import pickle

from tqdm import tqdm
from torch.utils.data import Dataset

warnings.filterwarnings('ignore')


def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc


def farthest_point_sample(point, npoint):
    """
    Input:
        xyz: pointcloud data, [N, D]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [npoint, D]
    """
    N, D = point.shape
    xyz = point[:,:3]
    centroids = np.zeros((npoint,))
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)
    point = point[centroids.astype(np.int32)]
    return point


class ModelNetDataLoader(Dataset):
    def __init__(self, root, args, split='train', process_data=False):
        self.root = root # 数据集根路径
        self.npoints = args.num_point  # 采样点的数量
        self.process_data = process_data  # 将txt文件转换为.dat文件,仅在第一次读取的时候执行
        self.uniform = args.use_uniform_sample #  使用FPS(最远点采样)下采样数据
        self.use_normals = args.use_normals  # 是否使用RGB颜色信息
        self.num_category = args.num_category  # 选择数据集的类型,如 10分类和40分类

        if self.num_category == 10:
            self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt') # 获取点云的路径
        else:
            self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')

        self.cat = [line.rstrip() for line in open(self.catfile)]
        # 将分类的名称从txt文件中读出来放入列表,例如 下面读的就是10分类的结果
        # ['bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor', 'night_stand', 'sofa', 'table', 'toilet']
        self.classes = dict(zip(self.cat, range(len(self.cat))))  # 将类别名字和索引对应起来放入自动  例如 0<--->airplane

        shape_ids = {}
        if self.num_category == 10:  # 将待训练的点云的名字从之前划分好的txt文件中读取出来。即将训练 测试 集中点云的名字拿出来 放入shae id字典 如 batchub_001
            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))]
            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))]
        else:
            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]

        assert (split == 'train' or split == 'test')
        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]] # 去除shape_ids中的下划线,拿到图像的名称存入列表  如  batchub_001--->bathtub
        self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
                         in range(len(shape_ids[split]))]
        # 遍历之前的shapeids , 将路径进行组合,找到对于图像对应的路径,并生成一个元组,将其存入self.datapath。其中,第一个元素是它的名称,第二个元素是它对于的路径。详情如下
        # ('bathtub', 'D:\\1Apython\\Pycharm_pojie\\3d\\Pointnet_Pointnet2_pytorch-master\\data\\modelnet40_normal_resampled\\bathtub\\bathtub_0001.txt')
        print('The size of %s data is %d' % (split, len(self.datapath)))

        if self.uniform:
            self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints))
        else:
            self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints))

        if self.process_data:
            "第一次运行会处理一些数据,将txt文件转换为.dat 文件 存入 self.save中"
            if not os.path.exists(self.save_path):
                print('Processing data %s (only running in the first time)...' % self.save_path)
                self.list_of_points = [None] * len(self.datapath)
                self.list_of_labels = [None] * len(self.datapath)

                for index in tqdm(range(len(self.datapath)), total=len(self.datapath)):
                    fn = self.datapath[index]
                    cls = self.classes[self.datapath[index][0]]
                    cls = np.array([cls]).astype(np.int32)
                    point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)

                    if self.uniform:
                        point_set = farthest_point_sample(point_set, self.npoints)
                    else:
                        point_set = point_set[0:self.npoints, :]

                    self.list_of_points[index] = point_set
                    self.list_of_labels[index] = cls

                with open(self.save_path, 'wb') as f:
                    pickle.dump([self.list_of_points, self.list_of_labels], f)
            else:
                print('Load processed data from %s...' % self.save_path)
                with open(self.save_path, 'rb') as f:
                    self.list_of_points, self.list_of_labels = pickle.load(f)

    def __len__(self):
        return len(self.datapath)

    def _get_item(self, index):
        if self.process_data:
            # self.process_data就第一次执行的时候会 为True ,大多数情况都是执行下面的代码
            point_set, label = self.list_of_points[index], self.list_of_labels[index]
        else:
            fn = self.datapath[index] # 从self.datapath中获取点云数据,格式为(‘类别名称’,路径)
            cls = self.classes[self.datapath[index][0]] # 根据索引和类别名称的对应关系 将名称与索引对应起来 例如  airplane<---->0
            label = np.array([cls]).astype(np.int32) # 转换为数组形式
            point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) # 使用np读点云数据 点云数据结果为 10000x6
            """
            读取txt文件我们通常使用 numpy 中的 loadtxt()函数

            numpy.loadtxt(fname, dtype=, comments='#', delimiter=None, converters=None, skiprows=0, usecols=None, unpack=False, ndmin=0)

            注:loadtxt的功能是读入数据文件,这里的数据文件要求每一行数据的格式相同。 delimiter:数据之间的分隔符。如使用逗号","。
            """
            if self.uniform:  # 默认为False,没有使用FPS算法进行筛选,降采样到self.npoints个点
                point_set = farthest_point_sample(point_set, self.npoints)
            else:  # 取前1024个数
                point_set = point_set[0:self.npoints, :]
                
        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) # 只对点的前三个维度进行归一化,即坐标的归一化
        if not self.use_normals: # 如果不使用rgb信息 ,则 返只返回点云的xyz信息
            point_set = point_set[:, 0:3]

        return point_set, label[0] # 返回读取到的 经过降采样的点云数据 和标签

    def __getitem__(self, index):
        return self._get_item(index) # 调用self._get_item获取点云数据


if __name__ == '__main__':
    import torch
    import argparse

    def parse_args():
        '''PARAMETERS'''
        parser = argparse.ArgumentParser('training')
        parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
        parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
        parser.add_argument('--batch_size', type=int, default=4, help='batch size in training')
        parser.add_argument('--model', default='pointnet2_cls_ssg', help='model name [default: pointnet_cls]')
        parser.add_argument('--num_category', default=10, type=int, choices=[10, 40], help='training on ModelNet10/40')
        parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training')
        parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training')
        parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
        parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training')
        parser.add_argument('--log_dir', type=str, default=None, help='experiment root')
        parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate')
        parser.add_argument('--use_normals', action='store_true', default=True, help='use normals')
        parser.add_argument('--process_data', action='store_true', default=True, help='save data offline') # 处理数据数据
        parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling')
        return parser.parse_args()

    args = parse_args()
    root = r'D:\1Apython\Pycharm_pojie\3d\Pointnet_Pointnet2_pytorch-master\data\modelnet40_normal_resampled'
    data = ModelNetDataLoader(root=root,args=args, split='train')
    DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)
    for point, label in DataLoader:
        print('point.shape:\n',point.shape)
        print('label.shape:\n',label.shape)

来源:正在学习的浅语

物联沃分享整理
物联沃-IOTWORD物联网 » ModelNet10/40数据集的下载及dataset代码分析

发表评论