365天深度学习训练营-第J5周:DenseNet+SE-Net实战
创始人
2024-06-01 07:41:57
0

目录

 一、前言

二、论文解读

三、代码实现

1、定义SE块

2、se模块插入到Dense net代码实现


 一、前言

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊|接辅导、项目定制
● 难度:夯实基础⭐⭐
● 语言:Python3、Pytorch3
● 时间:3月4日-3月10日
🍺要求:
1. 在DenseNet系列算法中插入SE-Net通道注意力机制,并完成猴痘病识别
2. 改进思路是否可以迁移到其他地方呢
3. 测试集accuracy到达89%(拔高,可选)

二、论文解读

论文:Squeeze-and-Excitation Networks

"Squeeze-and-Excitation Networks"是一篇由Jie Hu,Li Shen和Gang Sun在2018年发表的论文,介绍了一种基于自适应特征重标定的网络结构,用于提高卷积神经网络(CNN)的性能。这个网络结构被称为“Squeeze-and-Excitation(SE)网络”。

SE网络的核心思想是利用自适应的特征重标定来增强网络的表达能力,使其能够更好地处理不同类别之间的差异。具体来说,SE网络在每个通道上引入一个Squeeze操作和一个Excitation操作。

Squeeze操作将每个通道的特征图压缩成一个数值,并且将其作为该通道的全局特征描述符。这可以通过使用全局平均池化来实现。然后,Excitation操作将该描述符作为输入,并生成一个权重向量,该向量可以动态地调整每个通道的权重,以强化重要的特征并抑制不重要的特征。这可以通过使用一系列全连接层和非线性激活函数来实现。

SE网络可以通过简单地添加SE块来嵌入到任何卷积神经网络中,而不需要对网络架构进行大规模修改。这使得SE网络非常易于实现,并且在多个视觉任务上都可以提高性能。

在论文中,作者通过在ImageNet和CIFAR-10数据集上进行实验,证明了SE网络在各种任务上都可以提高CNN的性能,包括图像分类、目标检测和语义分割等。作者还展示了SE网络对于不同类别之间的差异建模能力强于其他网络结构。

总的来说,SE网络通过引入自适应的特征重标定来增强CNN的表达能力,并在各种视觉任务上取得了显著的性能提升。

三、代码实现

1、定义SE块

SE网络的代码实现相对简单,可以通过在现有卷积神经网络的基础上添加SE块来完成。

首先,需要定义SE块的代码实现,例如:

''' Squeeze Excitation Module '''
class SEModule(nn.Module):def __init__(self, in_channel, filter_sq=16):super(SEModule, self).__init__()self.se = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(),nn.Linear(in_channel, in_channel//filter_sq),nn.ReLU(True),nn.Linear(in_channel//filter_sq, in_channel),nn.Sigmoid())#self.se = nn.Sequential(#    nn.AdaptiveAvgPool2d((1,1)),#    nn.Conv2d(in_channel, in_channel//filter_sq, kernel_size=1),#    nn.ReLU(),#    nn.Conv2d(in_channel//filter_sq, in_channel, kernel_size=1),#    nn.Sigmoid()#)def forward(self, inputs):x = self.se(inputs)s1, s2 = x.size(0), x.size(1)x = torch.reshape(x, (s1, s2, 1, 1))x = inputs * xreturn x

在这个SE块中,我们首先使用nn.AdaptiveAvgPool2d(1)对输入的特征图进行全局平均池化,然后通过两个全连接层和非线性激活函数来计算出每个通道的权重向量。最后,将该向量应用于原始特征图,以生成SE增强的特征图。

2、se模块插入到Dense net代码实现

在将SE模块插入到DenseNet的代码实现中,您需要进行以下步骤:

定义DenseNet模型的主体部分。您可以选择使用一个或多个密集块(dense block),每个密集块包含多个卷积层和批标准化层。每个密集块之间还可以插入跳跃连接(skip connection)以简化模型的训练。在这个例子中,我们将使用一个具有四个密集块的DenseNet模型:

''' Basic unit of DenseBlock (using bottleneck layer) '''
class DenseLayer(nn.Sequential):def __init__(self, in_channel, growth_rate, bn_size, drop_rate):super(DenseLayer, self).__init__()self.add_module('norm1', nn.BatchNorm2d(in_channel))self.add_module('relu1', nn.ReLU(inplace=True))self.add_module('conv1', nn.Conv2d(in_channel, bn_size*growth_rate,kernel_size=1, stride=1, bias=False))self.add_module('norm2', nn.BatchNorm2d(bn_size*growth_rate))self.add_module('relu2', nn.ReLU(inplace=True))self.add_module('conv2', nn.Conv2d(bn_size*growth_rate, growth_rate,kernel_size=3, stride=1, padding=1, bias=False))self.drop_rate = drop_ratedef forward(self, x):new_feature = super(DenseLayer, self).forward(x)if self.drop_rate>0:new_feature = F.dropout(new_feature, p=self.drop_rate, training=self.training)return torch.cat([x, new_feature], 1)''' DenseBlock '''
class DenseBlock(nn.Sequential):def __init__(self, num_layers, in_channel, bn_size, growth_rate, drop_rate):super(DenseBlock, self).__init__()for i in range(num_layers):layer = DenseLayer(in_channel+i*growth_rate, growth_rate, bn_size, drop_rate)self.add_module('denselayer%d'%(i+1,), layer)''' Transition layer between two adjacent DenseBlock '''
class Transition(nn.Sequential):def __init__(self, in_channel, out_channel):super(Transition, self).__init__()self.add_module('norm', nn.BatchNorm2d(in_channel))self.add_module('relu', nn.ReLU(inplace=True))self.add_module('conv', nn.Conv2d(in_channel, out_channel,kernel_size=1, stride=1, bias=False))self.add_module('pool', nn.AvgPool2d(2, stride=2))

在这个例子中,我们在每个密集块之后插入了SE块。可以根据需要更改密集块的数量和增长率,并调整SE块的超参数来进一步改进模型的性能。

''' DenseNet-BC model '''
class DenseNet(nn.Module):def __init__(self, growth_rate=32, block_config=(6,12,24,16), init_channel=64, bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=1000):''':param growth_rate: (int) number of filters used in DenseLayer, `k` in the paper:param block_config: (list of 4 ints) number of layers in eatch DenseBlock:param init_channel: (int) number of filters in the first Conv2d:param bn_size: (int) the factor using in the bottleneck layer:param compression_rate: (float) the compression rate used in Transition Layer:param drop_rate: (float) the drop rate after each DenseLayer:param num_classes: (int) number of classes for classification'''super(DenseNet, self).__init__()# first Conv2dself.features = nn.Sequential(OrderedDict([('conv0', nn.Conv2d(3, init_channel, kernel_size=7, stride=2, padding=3, bias=False)),('norm0', nn.BatchNorm2d(init_channel)),('relu0', nn.ReLU(inplace=True)),('pool0', nn.MaxPool2d(3, stride=2, padding=1))]))# DenseBlocknum_features = init_channelfor i, num_layers in enumerate(block_config):block = DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)self.features.add_module('denseblock%d'%(i+1), block)num_features += num_layers*growth_rateif i!=len(block_config)-1:transition = Transition(num_features, int(num_features*compression_rate))self.features.add_module('transition%d'%(i+1), transition)num_features = int(num_features*compression_rate)# SE Moduleself.features.add_module('SE-module', SEModule(num_features))# final BN+ReLUself.features.add_module('norm5', nn.BatchNorm2d(num_features))self.features.add_module('relu5', nn.ReLU(inplace=True))# classification layerself.classifier = nn.Linear(num_features, num_classes)# params initializationfor m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1)elif isinstance(m, nn.Linear):nn.init.constant_(m.bias, 0)def forward(self, x):x = self.features(x)x = F.avg_pool2d(x, 7, stride=1).view(x.size(0), -1)x = self.classifier(x)return x

相关内容

热门资讯

监控摄像头接入GB28181平... 流程简介将监控摄像头的视频在网站和APP中直播,要解决的几个问题是:1&...
Windows10添加群晖磁盘... 在使用群晖NAS时,我们需要通过本地映射的方式把NAS映射成本地的一块磁盘使用。 通过...
protocol buffer... 目录 目录 什么是protocol buffer 1.protobuf 1.1安装  1.2使用...
在Word、WPS中插入AxM... 引言 我最近需要写一些文章,在排版时发现AxMath插入的公式竟然会导致行间距异常&#...
【PdgCntEditor】解... 一、问题背景 大部分的图书对应的PDF,目录中的页码并非PDF中直接索引的页码...
修复 爱普生 EPSON L4... L4151 L4153 L4156 L4158 L4163 L4165 L4166 L4168 L4...
Fluent中创建监测点 1 概述某些仿真问题,需要创建监测点,用于获取空间定点的数据࿰...
educoder数据结构与算法...                                                   ...
MySQL下载和安装(Wind... 前言:刚换了一台电脑,里面所有东西都需要重新配置,习惯了所...
MFC文件操作  MFC提供了一个文件操作的基类CFile,这个类提供了一个没有缓存的二进制格式的磁盘...