(一)使用pytorch搭建模型并训练
为了学习一下使用TensorRT进行推理的全过程,便想着写一个TensorRT推理手写数字分类的小例程。这个例程包括使用pytorch进行LeNet网络的搭建、训练、保存pytorch格式的模型(pth)、将模型(pth)转为onnx通用格式、使用tensorRT解析onnx模型进行推理等。
本节介绍使用pytorch进行手写数字分类网络的搭建,并进行训练。
网络结构图如下所示:
结构说明:输入是单通道的12828的灰度图像,经过卷积、池化、卷积、池化后shape变为5044(50为通道数)。将其展平后维度为1*800,然后连接一个维度为500的线性层C5,C5层的输出经过ReLU函数激活后再连接一个维度为10的线性层C6,C6层的输出就为网络的输出。
一般来说,我们要求的是输入图片属于某一类的概率,所有我们要将C6的输出通过softmax函数进行转换。
新建model.py文件,包含以下代码:
# 搭建网络模型
import torch
import torch.nn as nn
import torch.nn.functional as Ffrom torchinfo import summary #用来打印网络层的信息
# from torchkeras import summary module 'torch.backends' has no attribute 'mps'class Net(nn.Module):def __init__(self) -> None:super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 20, kernel_size=5)self.conv2 = nn.Conv2d(20, 50, kernel_size=5)self.fc1 = nn.Linear(800, 500)self.fc2 = nn.Linear(500, 10)def forward(self, x):x = F.max_pool2d(self.conv1(x), kernel_size=2, stride=2)x = F.max_pool2d(self.conv2(x), kernel_size=2, stride=2)x = x.view(-1, 800) # 将其展平x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1) # 要使用NLLLoss()损失函数,所以输出要先经过log_softmaxif __name__ == "__main__":net = Net()summary(net, (1,1,28,28))
新建train.py,包含以下代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variablefrom model import Net
import numpy as np
import os
import torch.utils.data
from random import randint
class MnistModel(object):def __init__(self):self.batch_size = 64 # 训练batch_sizeself.test_batch_size = 100 # 测试batch_size self.learning_rate = 0.0025 #学习率self.sgd_momentum = 0.9self.log_interval = 100# 构造数据self.train_loader = torch.utils.data.DataLoader(datasets.MNIST("./tmp/mnist/data",train=True,download=True,transform=transforms.Compose( # 预处理:对训练数据只进行标准化[transforms.ToTensor(),transforms.Normalize((0.1307),(0.3081,))])),batch_size=self.batch_size,shuffle=True,num_workers=4,timeout=600,)self.test_loader = torch.utils.data.DataLoader(datasets.MNIST("./tmp/mnist/data",train=False,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307),(0.3081,))])),batch_size = self.test_batch_size,shuffle = True,num_workers=4,timeout=600,)self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")self.network = Net()self.network.to(self.device) def learn(self, num_epochs=2): # 训练两个epoch# # Train the network for a single epochdef train(epoch):self.network.train()optimizer = optim.SGD(self.network.parameters(), lr=self.learning_rate, momentum=self.sgd_momentum) # 使用SGD优化器for batch, (data, target) in enumerate(self.train_loader):data, target = Variable(data.to(self.device)), Variable(target.to(self.device))optimizer.zero_grad()output = self.network(data)loss = F.nll_loss(output, target).to(self.device)loss.backward()optimizer.step()if batch % self.log_interval == 0: #每100个batch打印一次信息print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(epoch,batch * len(data),len(self.train_loader.dataset),100.0 * batch / len(self.train_loader),loss.data.item(),))# Test the networkdef test(epoch):self.network.eval()test_loss = 0correct = 0for data, target in self.test_loader:with torch.no_grad():data, target = Variable(data.to(self.device)), Variable(target.to(self.device))output = self.network(data)test_loss += F.nll_loss(output, target).data.item()pred = output.data.max(1)[1] # 输出最大值的索引为预测的类别correct += pred.eq(target.data).cpu().sum()test_loss /= len(self.test_loader)v # 测试集每一个batch的平均损失print("\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(test_loss, correct, len(self.test_loader.dataset), 100.0 * correct / len(self.test_loader.dataset)))for e in range(num_epochs):train(e + 1)test(e + 1)
train_model = MnistModel()
train_model.learn()
总的来说,训练代码中没有太值得让人注意的地方。如果非要说有,那我觉得以下三点可能是要注意的地方:
可以看到,两个epoch后,模型的准确率为99%,所以我选择停止训练,然后保存模型。
在pytorch中保存模型有两种形式,一种是保存整个网络,一种是只保存网络中的参数。
保存整个网络的方法如下:
# 保存整个网络
torch.save(net, path)
# 加载网络
model = torch.load(path)
只保存网络中的参数的方法如下:
# 保存
torch.save(net.state_dict(), path)
# 加载
model = model.load_state_dict(torch.load(path))
在这个demo中,我们只需要在train.py后加上
torch.save(net.state_dict(), './model.pth')
就可以保存模型为model.pth文件。
本节我们进行了模型的搭建、训练以及保存模型。下一节我们将介绍如何将我们保存的pth文件转为onnx通用格式,同时对我们转成的onnx文件进行检查和验证。