torchvision.datasets.CIFAR10模块使用讲解

torchvision.dataset数据集

所有的数据集都是torch.utils.data.Dataset的子类, 它们实现了__getitem__和__len__方法。因此,它们都可以传递给torch.utils.data.DataLoader.

 transform = transforms.Compose(
  [transforms.ToTensor(),
 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

 trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
  download=True, transform=transform)
 trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
  shuffle=True, num_workers=2)

目前为止,收录的数据集包括:

__all__ = ('LSUN', 'LSUNClass',
           'ImageFolder', 'DatasetFolder', 'FakeData',
           'CocoCaptions', 'CocoDetection',
           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST',
           'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
           'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
           'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
           'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
           'VisionDataset', 'USPS', 'Kinetics400', "Kinetics", 'HMDB51', 'UCF101',
           'Places365', 'Kitti', "INaturalist", "LFWPeople", "LFWPairs"
           )

torchvision.datasets.CIFAR10模块

CIFAR-10和CIFAR-100为8000万张微小图像数据集的子集。它们是由 Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton.收集的。
CIFAR-10数据集包含10类60000幅32×32彩色图像,每个类6000幅图像。训练图像50000张,测试图像10000张。该数据集被分为5个训练批和1个测试批,每个批包含10000张图像。测试批正好包含从每个类中随机选择的1000张图像。训练批以随机顺序包含剩余的图像,但有些训练批可能包含一个类的图像多于另一个类。在它们之间,训练批恰好包含来自每个类的5000张图像。
以下是数据集中的类,以及每个类的10张随机图片:

这些类是完全互斥的。汽车和卡车之间没有重叠。“汽车”包括轿车、越野车之类的东西。“卡车”只包括大卡车。这两项都不包括皮卡。

torchvision.datasets.CIFAR10源码

class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:

        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data: Any = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self) -> None:
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not check_integrity(path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            data = pickle.load(infile, encoding='latin1')
            self.classes = data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    def _check_integrity(self) -> bool:
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self) -> None:
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    def extra_repr(self) -> str:
        return "Split: {}".format("Train" if self.train is True else "Test")

参数说明

root (string):数据集所在目录的根目录 如果download设置为True。“cifar-10-batches-py '”存在,则将被保存至该目录
train :如果为True,则从训练集创建数据集,否则从测试集创建。
transform::(bool,可选)一个接受PIL图像的函数/变换 并返回转换后的版本。 如”transforms“,“RandomCrop”
target_transform:(可调用,可选):一个作用于目的转换函数。
download (bool,可选):如果为true,则从internet下载数据集 ,将其放在根目录中。 如果数据集已经下载,则不会 再次下载。

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
 download=True, transform=transform)

来源:liu_jie_bin

物联沃分享整理
物联沃-IOTWORD物联网 » torchvision.datasets.CIFAR10模块使用讲解

发表评论