CIFAR-10数据集分类的ResNet实现详解(Python版)

一、数据集下载

下载地址:CIFAR-10 and CIFAR-100 datasets(https://www.cs.toronto.edu/~kriz/cifar.html)

进入下载地址后选择CIFAR-10 python version进行下载

下载完成后解压,得到 cifar-10-batches-py 目录,其中包含如下文件:data_batch_1 ~ data_batch_5、test_batch、batches.meta 和 readme.html。

其中 data_batch_1 到 data_batch_5 是已划分好的训练集,每个文件有 10000 张图片,共 50000 张;test_batch 是测试集,有 10000 张图片。

二、实现数据读取部分

import pickle
import os
import numpy as np
def get_cifar(root=""):
    """Load the five CIFAR-10 training batch files from *root* and stack them.

    Each file ``data_batch_1`` .. ``data_batch_5`` is unpickled with
    ``encoding='latin1'`` and must be a dict holding ``"data"`` and
    ``"labels"`` entries.

    Security note: ``pickle.load`` can execute arbitrary code, so only point
    *root* at trusted files.

    Returns:
        (images, labels): the ``"data"`` arrays concatenated along axis 0, and
        all ``"labels"`` concatenated into a single NumPy array.
    """
    def load_file(filename):
        with open(filename, 'rb') as handle:
            return pickle.load(handle, encoding='latin1')

    names = ('data_batch_1', 'data_batch_2', 'data_batch_3',
             'data_batch_4', 'data_batch_5')
    # Read every batch first, then stack images and labels in file order.
    batches = [load_file(os.path.join(root, name)) for name in names]
    images = np.concatenate([batch["data"] for batch in batches])
    labels = np.concatenate([batch["labels"] for batch in batches])
    return images, labels
def get_cif_test(root=""):
    """Load the CIFAR-10 ``test_batch`` file from *root*.

    The file is unpickled with ``encoding='latin1'`` and must be a dict
    holding ``"data"`` and ``"labels"`` entries. Only load trusted files:
    ``pickle.load`` can execute arbitrary code.

    Returns:
        (images, labels): a copy of the ``"data"`` array and the ``"labels"``
        converted to a NumPy array.
    """
    path = os.path.join(root, 'test_batch')
    with open(path, 'rb') as handle:
        batch = pickle.load(handle, encoding='latin1')
    # Concatenating a one-element list yields a fresh array.
    images = np.concatenate([batch["data"]])
    labels = np.concatenate([batch["labels"]])
    return images, labels
def get_cif_data(root=""):
    """Return ``(train_images, train_labels, test_images, test_labels)``.

    Calls ``get_cifar`` and then ``get_cif_test`` on the same *root*
    directory and joins their two-element results into one 4-tuple.
    """
    train_part = get_cifar(root=root)
    test_part = get_cif_test(root=root)
    return train_part + test_part
# Demo: runs only when this file is executed directly as a script.
# __name__ equals "__main__" in that case; comparing against "main" never matches.
if __name__ == "__main__":
    train_dataset, label_dataset, test_dataset, test_label_dataset = get_cif_data(
        root="D:/shuju/shuzi")
    # Reshape each flat row into a (3, 32, 32) channel-first image.
    train_dataset = np.reshape(train_dataset, [len(train_dataset), 3, 32, 32])
    test_dataset = np.reshape(test_dataset, [len(test_dataset), 3, 32, 32])
    label_dataset = np.array(label_dataset)
    test_label_dataset = np.array(test_label_dataset)

其中 get_cifar 函数用来读取训练集数据,get_cif_test 函数用来读取测试集数据,get_cif_data 将两者组合,一次性返回训练集与测试集的图像和标签。上述代码需保存为 get_data.py,供下文训练脚本通过 import get_data 调用。

三、构建模型和训练

import numpy as np
import get_data
import torch
import torchvision.models as models

# Load CIFAR-10 through the get_data module. Reshape the flat rows into
# (N, 3, 32, 32) float32 images scaled to [0, 1].
train_dataset, label_dataset, test_dataset, test_label_dataset = get_data.get_cif_data(
    root="D:/shuju/shuzi")
train_dataset = np.reshape(train_dataset, [len(train_dataset), 3, 32, 32]).astype(np.float32) / 255
test_dataset = np.reshape(test_dataset, [len(test_dataset), 3, 32, 32]).astype(np.float32) / 255
label_dataset = np.array(label_dataset)
test_label_dataset = np.array(test_label_dataset)

device = "cuda" if torch.cuda.is_available() else "cpu"
# CIFAR-10 has 10 classes; torchvision's resnet18 defaults to a 1000-way head.
model = models.resnet18(num_classes=10)
model = model.to(device)
# Optional on PyTorch 2.x: model = torch.compile(model)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
loss_fn = torch.nn.CrossEntropyLoss()
batch_size = 128
# Number of full training batches per epoch; a trailing partial batch is dropped.
train_num = len(label_dataset) // batch_size
for epoch in range(63):
    # Switch back to training mode: the evaluation below puts the model in eval mode.
    model.train()
    # Reshuffle every epoch so the mini-batches are not identical each epoch.
    order = np.random.permutation(len(label_dataset))
    train_loss = 0.0
    total_correct = 0
    for i in range(train_num):
        idx = order[i * batch_size:(i + 1) * batch_size]
        x_batch = torch.from_numpy(train_dataset[idx]).to(device)
        y_batch = torch.from_numpy(label_dataset[idx]).to(device)

        pred = model(x_batch)
        loss = loss_fn(pred, y_batch.long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        total_correct += (pred.argmax(1) == y_batch).sum().item()
    train_loss /= train_num
    train_accuracy = total_correct / (train_num * batch_size)

    # Evaluate on the full test set. Eval mode makes BatchNorm use (and not
    # update) its running statistics; no_grad skips building an autograd
    # graph. Batching bounds peak memory.
    model.eval()
    test_correct = 0
    with torch.no_grad():
        for start in range(0, len(test_label_dataset), batch_size):
            x_test = torch.from_numpy(test_dataset[start:start + batch_size]).to(device)
            y_test = torch.from_numpy(test_label_dataset[start:start + batch_size]).to(device)
            test_correct += (model(x_test).argmax(1) == y_test).sum().item()
    test_accuracy = test_correct / len(test_label_dataset)
    print("epoch:", epoch, "train_loss",
          round(train_loss, 2), ";accuracy:", round(train_accuracy, 2), 'test_accuracy', round(test_accuracy, 2))

作者:一尾清风915

物联沃分享整理
物联沃-IOTWORD物联网 » CIFAR-10数据集分类的ResNet实现详解(Python版)

发表回复