深度学习11. CNN经典网络 LeNet-5实现CIFAR-10
创始人
2025-05-30 04:51:56
0

深度学习11. CNN经典网络 LeNet-5实现CIFAR-10

  • 一、CIFAR-10介绍
  • 二、PyTorch的 transforms 介绍
  • 三、实现步骤
    • 1. 准备数据
    • 2. 模型定义
    • 3. 训练与测试
  • 四、完整代码

在这里插入图片描述

本文在前节程序基础上,实现对CIFAR-10的训练与测试,以加深对LeNet-5网络的理解 。

首先,要了解LeNet-5并不适合训练 CIFAR-10 , 最后的正确率不会太理想 。

一、CIFAR-10介绍

CIFAR-10是一个常用的图像分类数据集,由10类共计60,000张32x32大小的彩色图像组成,每类包含6,000张图像。这些图像被平均分为了5个训练批次和1个测试批次,每个批次包含10,000张图像。CIFAR-10数据集中的10个类别分别为:飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。

相比之下,MNIST是一个手写数字分类数据集,由10个数字(0-9)共计60,000个训练样本和10,000个测试样本组成,每个样本是一个28x28的灰度图像。

与MNIST相比,CIFAR-10更具挑战性,因为它是一个彩色图像数据集,每张图像包含更多的信息和细节,难度更高。此外,CIFAR-10的类别也更加多样化,更加贴近实际应用场景。因此,CIFAR-10更适合用于测试和评估具有更高难度的图像分类模型,而MNIST则更适合用于介绍和入门级别的模型训练和测试。

二、PyTorch的 transforms 介绍

PyTorch中的transforms是用于对数据进行预处理和增强的工具,主要用于图像数据的处理,它可以方便地对数据进行转换,使其符合神经网络的输入要求。

transforms的方法:

  • ToTensor : 将数据转换为PyTorch中的张量格式。
  • Normalize:对数据进行标准化,使其均值为0,方差为1,以便网络更容易训练。
  • Resize:调整图像大小。
  • RandomCrop:随机裁剪图像的一部分。
  • CenterCrop:从图像的中心裁剪出一部分。
  • RandomHorizontalFlip :以一定的概率随机水平翻转图像,以增加训练集的多样性。
  • RandomVerticalFlip:以一定的概率随机垂直翻转图像,以增加训练集的多样性。
  • RandomRotation:以一定的概率随机旋转图像。
  • ColorJitter:随机调整图像的亮度、对比度、饱和度和色调。
  • RandomErasing:随机擦除图像中的一部分区域,以增加训练集的多样性。

使用transforms可以方便地进行数据预处理和增强,提高模型的鲁棒性和泛化能力。在实际应用中,可以根据具体问题和需求进行选择和组合。

三、实现步骤

1. 准备数据

下面定义加载CIFAR-10数据集,首先会对图片进行一些处理:

  • transforms.RandomHorizontalFlip():随机水平翻转图像
  • transforms.RandomCrop(32, padding=4):随机裁剪图像,大小为32x32,边缘填充4个像素
  • transforms.ToTensor():将图像转换为张量,并归一化到[0,1]范围内
  • transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)):将张量标准化,使每个通道的均值为0.5,标准差为0.5

对数据的处理可以增加数据的多样性和丰富性,以提高神经网络的泛化能力和准确率。

transform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=32,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=32,shuffle=False, num_workers=2)

2. 模型定义

class LeNet5(nn.Module):def __init__(self):super(LeNet5, self).__init__()# 定义卷积层C1,输入通道数为1,输出通道数为6,卷积核大小为5x5self.conv1 = nn.Conv2d(3, 6, kernel_size=5, stride=1)# 定义池化层S2,池化核大小为2x2,步长为2self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)# 定义卷积层C3,输入通道数为6,输出通道数为16,卷积核大小为5x5self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)# 定义池化层S4,池化核大小为2x2,步长为2self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)# 定义全连接层F5,输入节点数为16x4x4=256,输出节点数为120self.fc1 = nn.Linear(16 * 5 * 5, 120)# 定义全连接层F6,输入节点数为120,输出节点数为84self.fc2 = nn.Linear(120, 84)# 定义输出层,输入节点数为84,输出节点数为10self.fc3 = nn.Linear(84, 10)def forward(self, x):# 卷积层C1x = self.conv1(x)# 池化层S2x = self.pool1(torch.relu(x))# 卷积层C3x = self.conv2(x)# 池化层S4x = self.pool2(torch.relu(x))# 全连接层F5x = x.view(-1, 16 * 5 * 5)x = self.fc1(x)x = torch.relu(x)# 全连接层F6x = self.fc2(x)x = torch.relu(x)# 输出层x = self.fc3(x)return x

和上节类似 , 这个模型定义了经典的 LeNet-5。它由两个卷积层、两个池化层和三个全连接层组成,层间通过一定的非线性激活函数进行连接。

  • 模型中的第一个卷积层(C1)的输入通道数是3,即输入的是3通道的图像数据;输出通道数为6,表示该层有6个卷积核,每个卷积核可以提取出一种特征,卷积核大小为5x5。
  • 第一个池化层(S2)的池化核大小为2x2,步长为2,可以将特征图的大小降低一半。
  • 卷积层(C3),输入通道数为6,输出通道数为16,卷积核大小为5x5。然后再经过一个池化层(S4),池化核大小为2x2,步长为2,同样可以将特征图的大小降低一半。
  • 接下来是三个全连接层(F5、F6、F7),其中 F5 的输入节点数为16x5x5=400,输出节点数为120;F6 的输入节点数为120,输出节点数为84;输出层的输入节点数为84,输出节点数为10,表示对10个类别进行分类。

最后,网络输出了分类结果。在前向传播过程中,经过卷积、池化、全连接等操作,每层的输出都要经过一定的非线性激活函数,这里使用的是 ReLU 函数(即 Rectified Linear Unit)。

3. 训练与测试

model = LeNet5()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
if __name__ == '__main__':# 定义模型保存路径和文件名model_path = 'model.pth'if os.path.exists(model_path):# 存在,直接加载模型model.load_state_dict(torch.load(model_path))print('Loaded model from', model_path)else:# 训练模型for epoch in range(epochs):model.train()for images, labels in train_loader:# 将数据放入模型optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 在测试集上测试模型model.eval()correct = 0with torch.no_grad():for images, labels in test_loader:# 将数据放入模型outputs = model(images)_, predicted = torch.max(outputs, 1)correct += (predicted == labels).sum().item()accuracy = 100 * correct / len(testset)print('Epoch [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'.format(epoch + 1, epochs, loss.item(), accuracy))torch.save(model.state_dict(), 'model.pth')for i in range(10):img, label = next(iter(test_loader))img = img[i].unsqueeze(0)# 使用模型进行预测model.eval()with torch.no_grad():output = model(img)# 解码预测结果pred = output.argmax(dim=1).item()print(f'Predicted class: {pred}, actual value: {label[i]}')

四、完整代码

import osimport torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 定义 LeNet-5 模型
class LeNet5(nn.Module):def __init__(self):super(LeNet5, self).__init__()# 定义卷积层C1,输入通道数为1,输出通道数为6,卷积核大小为5x5self.conv1 = nn.Conv2d(3, 6, kernel_size=5, stride=1)# 定义池化层S2,池化核大小为2x2,步长为2self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)# 定义卷积层C3,输入通道数为6,输出通道数为16,卷积核大小为5x5self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)# 定义池化层S4,池化核大小为2x2,步长为2self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)# 定义全连接层F5,输入节点数为16x4x4=256,输出节点数为120self.fc1 = nn.Linear(16 * 5 * 5, 120)# 定义全连接层F6,输入节点数为120,输出节点数为84self.fc2 = nn.Linear(120, 84)# 定义输出层,输入节点数为84,输出节点数为10self.fc3 = nn.Linear(84, 10)def forward(self, x):# 卷积层C1x = self.conv1(x)# print('卷积层C1后的形状:', x.shape)# 池化层S2x = self.pool1(torch.relu(x))# print('池化层S2后的形状:', x.shape)# 卷积层C3x = self.conv2(x)# print('卷积层C3后的形状:', x.shape)# 池化层S4x = self.pool2(torch.relu(x))# print('池化层S4后的形状:', x.shape)# 全连接层F5x = x.view(-1, 16 * 5 * 5)x = self.fc1(x)# print('全连接层F5后的形状:', x.shape)x = torch.relu(x)# 全连接层F6x = self.fc2(x)# print('全连接层F6后的形状:', x.shape)x = torch.relu(x)# 输出层x = self.fc3(x)# print('输出层后的形状:', x.shape)return x# 设置超参数
batch_size = 128
learning_rate = 0.01
epochs = 10# CIFAR-10
# 准备数据
transform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=32,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=32,shuffle=False, num_workers=2)# 实例化模型和优化器
model = LeNet5()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
if __name__ == '__main__':# 定义模型保存路径和文件名model_path = 'model.pth'if os.path.exists(model_path):# 存在,直接加载模型model.load_state_dict(torch.load(model_path))print('Loaded model from', model_path)else:# 训练模型for epoch in range(epochs):model.train()for images, labels in train_loader:# 将数据放入模型optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 在测试集上测试模型model.eval()correct = 0with torch.no_grad():for images, labels in test_loader:# 将数据放入模型outputs = model(images)_, predicted = torch.max(outputs, 1)correct += (predicted == labels).sum().item()accuracy = 100 * correct / len(testset)print('Epoch [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'.format(epoch + 1, epochs, loss.item(), accuracy))torch.save(model.state_dict(), 'model.pth')for i in range(10):img, label = next(iter(test_loader))img = img[i].unsqueeze(0)# 使用模型进行预测model.eval()with torch.no_grad():output = model(img)# 解码预测结果pred = output.argmax(dim=1).item()print(f'Predicted class: {pred}, actual value: {label[i]}')

最后训练准确率仅有47.13%。

相关内容

热门资讯

监控摄像头接入GB28181平... 流程简介将监控摄像头的视频在网站和APP中直播,要解决的几个问题是:1&...
Windows10添加群晖磁盘... 在使用群晖NAS时,我们需要通过本地映射的方式把NAS映射成本地的一块磁盘使用。 通过...
protocol buffer... 目录 目录 什么是protocol buffer 1.protobuf 1.1安装  1.2使用...
Fluent中创建监测点 1 概述某些仿真问题,需要创建监测点,用于获取空间定点的数据࿰...
educoder数据结构与算法...                                                   ...
MySQL下载和安装(Wind... 前言:刚换了一台电脑,里面所有东西都需要重新配置,习惯了所...
MFC文件操作  MFC提供了一个文件操作的基类CFile,这个类提供了一个没有缓存的二进制格式的磁盘...
在Word、WPS中插入AxM... 引言 我最近需要写一些文章,在排版时发现AxMath插入的公式竟然会导致行间距异常&#...
有效的括号 一、题目 给定一个只包括 '(',')','{','}'...
【Ctfer训练计划】——(三... 作者名:Demo不是emo  主页面链接:主页传送门 创作初心ÿ...