PyTorch数据归一化处理:transforms.Normalize及计算图像数据集的均值和方差

PyTorch数据归一化处理:transforms.Normalize及计算图像数据集的均值和方差

  • 1.数据归一化处理:transforms.Normalize
  • 1.1 理解torchvision
  • 1.2 数据标准化Normalize
  • 2.计算图像数据集的均值和方差
  • 2.1 使用PyTorch计算图像数据集的均值和方差(推荐)
  • 2.2 使用opencv和numpy计算图像数据集的均值和方差
  • 2.3 计算某个目录下所有图片的均值和方差
  • 参考资料
  • # Data
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)
    
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)
    
    

    1.数据归一化处理:transforms.Normalize

    1.1 理解torchvision

  • torchvision.transforms:常用的图像预处理方法
  • torchvision.datasets:常用的数据集Dataset实现
  • torchvision.models:常用的CV(预训练)模型实现
  • torchvision.transforms:常用的数据预处理方法,提升泛化能力,包括:数据中心化、数据标准化、缩放、裁剪、旋转、翻转、填充、噪声添加、灰度变换、线性变换、仿射变换、亮度、饱和度及对比度变换等

    数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力。

    1.2 数据标准化Normalize

    功能:逐channel的对图像进行标准化(均值变为0,标准差变为1),可以加快模型的收敛
    output = (input – mean) / std
    mean:各通道的均值
    std:各通道的标准差
    inplace:是否原地操作

    思考:

    (1)据我所知,归一化就是要把图片3个通道中的数据整理到[-1, 1]区间。
    x = (x – mean(x))/std(x)
    只要输入数据集x确定了,mean(x)和std(x)也就是确定的数值了,为什么Normalize()函数还需要输入mean和std的数值呢????

    (2)RGB单个通道的值是[0, 255],所以一个通道的均值应该在127附近才对。
    如果Normalize()函数去计算 x = (x – mean)/std ,因为RGB是[0, 255],算出来的x就不可能落在[-1, 1]区间了。

    (3)在我看的了论文代码里面是这样的:
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    为什么就确定了这一组数值,这一组数值是怎么来的? 为什么这三个通道的均值都是小于1的值呢?

    理解:

    (1)针对第一个问题,mean 和 std 肯定要在normalize()之前自己先算好再传进去的,不然每次normalize()就得把所有的图片都读取一遍算出mean和std

    (2)针对第二个问题,有两种情况
    (a )如果是imagenet数据集,那么ImageNet的数据在加载的时候就已经转换成了[0, 1].
    (b) 应用了torchvision.transforms.ToTensor,其作用是将数据归一化到[0,1](是将数据除以255),transforms.ToTensor()会把HWC会变成C *H *W(拓展:格式为(h,w,c),像素顺序为RGB)

    (3)针对第三个问题:[0.485, 0.456, 0.406]这一组平均值是从imagenet训练集中抽样算出来的。

    继续有疑问:

    ToTensor 已经[0,1]为什么还要[0.485, 0.456, 0.406]?那么归一化后,为什么还要接一个Normalize()呢?Normalize()是对数据按通道进行标准化,即减去均值,再除以方差

    解答:

    别人的解答:数据如果分布在(0,1)之间,可能实际的bias,就是神经网络的输入b会比较大,而模型初始化时b=0的,这样会导致神经网络收敛比较慢,经过Normalize后,可以加快模型的收敛速度。因为对RGB图片而言,数据范围是[0-255]的,需要先经过ToTensor除以255归一化到[0,1]之后,再通过Normalize计算过后,将数据归一化到[-1,1]。

    是否可以这样理解:[0,1]只是范围改变了, 并没有改变分布,mean和std处理后可以让数据正态分布😂

    2.计算图像数据集的均值和方差

    2.1 使用PyTorch计算图像数据集的均值和方差(推荐)

    Pytorch图像预处理时,通常使用transforms.Normalize(mean, std)对图像按通道进行标准化,即减去均值,再除以方差。这样做可以加快模型的收敛速度。其中参数mean和std分别表示图像每个通道的均值和方差序列。

    Imagenet数据集的均值和方差为:mean=(0.485, 0.456, 0.406)std=(0.229, 0.224, 0.225),因为这是在百万张图像上计算而得的,所以我们通常见到在训练过程中使用它们做标准化。而对于特定的数据集,选择这个值的结果可能并不理想。接下来给出计算特定数据集的均值和方差的方法。

    import torch
    from torchvision.datasets import ImageFolder
    
    
    def getStat(train_data):
        '''
        Compute mean and variance for training data
        :param train_data: 自定义类Dataset(或ImageFolder即可)
        :return: (mean, std)
        '''
        print('Compute mean and variance for training data.')
        print(len(train_data))
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=1, shuffle=False, num_workers=0,
            pin_memory=True)
        mean = torch.zeros(3)
        std = torch.zeros(3)
        for X, _ in train_loader:
            for d in range(3):
                mean[d] += X[:, d, :, :].mean()
                std[d] += X[:, d, :, :].std()
        mean.div_(len(train_data))
        std.div_(len(train_data))
        return list(mean.numpy()), list(std.numpy())
    
    
    if __name__ == '__main__':
        train_dataset = ImageFolder(root=r'./data/food/', transform=None)
        print(getStat(train_dataset))
    
    

    ./data/ready_chinese_food/的目录结构如下:

    getState()方法接收一个Dataset类(ImageFolder),然后累加所有图像三个通道的均值和方差,最后除以图像总数并返回。

    这里用食品数据集尚做的测试,测试集返回的结果如下所示:

    Compute mean and variance for training data.
    10000
    ([0.4940607, 0.4850613, 0.45037037], [0.20085774, 0.19870903, 0.20153421])
    

    2.2 使用opencv和numpy计算图像数据集的均值和方差

    import os
    import random
    
    import cv2
    import numpy as np
    
    # calculate means and std
    train_txt_path = './data/Label/TR.txt'
    base_path = './data/food'
    
    CNum = 66071  # 挑选多少图片进行计算
    
    img_h, img_w = 256, 256
    imgs = np.zeros([img_w, img_h, 3, 1])
    means, stdevs = [], []
    
    with open(train_txt_path, 'r') as f:
        lines = f.readlines()
        random.shuffle(lines)  # shuffle , 随机挑选图片
    
        for i in range(CNum):
            # img_path = os.path.join(base_path, lines[i].rstrip().split()[0])
            img_path = base_path + lines[i].rstrip().split()[0]
    
            img = cv2.imread(img_path)
            img = cv2.resize(img, (img_h, img_w))
            img = img[:, :, :, np.newaxis]
    
            imgs = np.concatenate((imgs, img), axis=3)
    
    imgs = imgs.astype(np.float32) / 255.
    
    for i in range(3):
        pixels = imgs[:, :, i, :].ravel()  # 拉成一行
        means.append(np.mean(pixels))
        stdevs.append(np.std(pixels))
    
    # cv2 读取的图像格式为BGR,PIL/Skimage读取到的都是RGB不用转
    means.reverse()  # BGR --> RGB
    stdevs.reverse()
    
    print("normMean = {}".format(means))
    print("normStd = {}".format(stdevs))
    print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))
    
    

    2.3 计算某个目录下所有图片的均值和方差

    import numpy as np
    import cv2
    import os
     
    # img_h, img_w = 32, 32
    img_h, img_w = 32, 48   #根据自己数据集适当调整,影响不大
    means, stdevs = [], []
    img_list = []
     
    imgs_path = 'D:/database/VOCdevkit/VOC2012/JPEGImages/'
    imgs_path_list = os.listdir(imgs_path)
     
    len_ = len(imgs_path_list)
    i = 0
    for item in imgs_path_list:
        img = cv2.imread(os.path.join(imgs_path,item))
        img = cv2.resize(img,(img_w,img_h))
        img = img[:, :, :, np.newaxis]
        img_list.append(img)
        i += 1
        print(i,'/',len_)    
     
    imgs = np.concatenate(img_list, axis=3)
    imgs = imgs.astype(np.float32) / 255.
     
    for i in range(3):
        pixels = imgs[:, :, i, :].ravel()  # 拉成一行
        means.append(np.mean(pixels))
        stdevs.append(np.std(pixels))
     
    # BGR --> RGB , CV读取的需要转换,PIL读取的不用转换
    means.reverse()
    stdevs.reverse()
    
    print("normMean = {}".format(means))
    print("normStd = {}".format(stdevs))
    

    参考资料

    1. https://blog.csdn.net/PanYHHH/article/details/107896526
    2. https://blog.csdn.net/weixin_38533896/article/details/85951903
    3. https://blog.csdn.net/dcrmg/article/details/102467434

    来源:紫芝

    物联沃分享整理
    物联沃-IOTWORD物联网 » PyTorch数据归一化处理:transforms.Normalize及计算图像数据集的均值和方差

    发表评论