Hand-Implementing a ResNet50 Network with Keras and TensorFlow
创始人
2024-03-15 08:34:08

Preface

In the article 基于tensorflow的ResNet50V2网络识别动物 (recognizing animals with a ResNet50V2 network in TensorFlow), we used a network that Keras already provides to complete an image-classification task. At that point Xiao Ming asked: how do I write a neural network myself and train it?
This article defines a neural network by hand, based on TensorFlow.

The ResNet50 network

[Figure: ResNet50 network structure]
Structurally, the difference from the network in the earlier article is that the input shape becomes (224, 224, 3).
ResNet50 is built from two basic blocks, the Conv Block and the Identity Block.

Overall architecture

[Figure: overall ResNet50 architecture]

Conv Block architecture

[Figure: Conv Block architecture]

Identity Block architecture

[Figure: Identity Block architecture]
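Before the full implementation, here is a minimal sketch of the key difference between the two blocks (illustrative only; the helper names bottleneck, identity_block_sketch and conv_block_sketch are made up for this sketch, and the complete code is in resnet50.py below). The Identity Block adds the block's input directly onto the output of its three convolutions, while the Conv Block first passes the input through an extra 1x1 projection convolution so that channel count and stride match before the addition.

# Illustrative sketch of the two residual block types (see resnet50.py below for the real code)
from keras.layers import Conv2D, BatchNormalization, Activation, add

def bottleneck(x, f1, f2, f3, k, strides=(1, 1)):
    # 1x1 -> kxk -> 1x1 convolution stack shared by both block types
    x = Activation('relu')(BatchNormalization()(Conv2D(f1, (1, 1), strides=strides)(x)))
    x = Activation('relu')(BatchNormalization()(Conv2D(f2, k, padding='same')(x)))
    return BatchNormalization()(Conv2D(f3, (1, 1))(x))

def identity_block_sketch(x, f1, f2, f3, k):
    # shortcut is the unchanged input: shapes already match
    return Activation('relu')(add([bottleneck(x, f1, f2, f3, k), x]))

def conv_block_sketch(x, f1, f2, f3, k, strides=(2, 2)):
    # shortcut is a 1x1 projection conv so channels/stride match the main path
    shortcut = BatchNormalization()(Conv2D(f3, (1, 1), strides=strides)(x))
    return Activation('relu')(add([bottleneck(x, f1, f2, f3, k, strides), shortcut]))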

Model training

Hand-built model

Model code (resnet50.py)

# Imports required by the model
from keras import layers
from keras.layers import Input, Activation, BatchNormalization, Flatten
from keras.layers import Dense, Conv2D, MaxPooling2D, ZeroPadding2D, AveragePooling2D
from keras.models import Model


def identity_block(input_tensor, kernel_size, filters, stage, block):
    filters1, filters2, filters3 = filters
    name_base = str(stage) + block + '_identity_block_'

    x = Conv2D(filters1, (1, 1), name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base + 'bn1')(x)
    x = Activation('relu', name=name_base + 'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base + 'bn2')(x)
    x = Activation('relu', name=name_base + 'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base + 'bn3')(x)

    # Identity shortcut: the input is added back unchanged
    x = layers.add([x, input_tensor], name=name_base + 'add')
    x = Activation('relu', name=name_base + 'relu4')(x)
    return x


def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    filters1, filters2, filters3 = filters
    res_name_base = str(stage) + block + '_conv_block_res_'
    name_base = str(stage) + block + '_conv_block_'

    x = Conv2D(filters1, (1, 1), strides=strides, name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base + 'bn1')(x)
    x = Activation('relu', name=name_base + 'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base + 'bn2')(x)
    x = Activation('relu', name=name_base + 'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base + 'bn3')(x)

    # Projection shortcut: a 1x1 conv so the shortcut matches the main path's shape
    shortcut = Conv2D(filters3, (1, 1), strides=strides, name=res_name_base + 'conv')(input_tensor)
    shortcut = BatchNormalization(name=res_name_base + 'bn')(shortcut)

    x = layers.add([x, shortcut], name=name_base + 'add')
    x = Activation('relu', name=name_base + 'relu4')(x)
    return x


def ResNet50(input_shape=[224, 224, 3], classes=1000):
    img_input = Input(shape=input_shape)

    x = ZeroPadding2D((3, 3))(img_input)
    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

    x = AveragePooling2D((7, 7), name='avg_pool')(x)
    x = Flatten()(x)
    x = Dense(classes, activation='softmax', name='fc1000')(x)

    model = Model(img_input, x, name='resnet50')
    return model
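As a quick sanity check (assuming the file above is saved as resnet50.py somewhere on the import path), building the model and printing the summary should report roughly 25.6 million parameters and an output shape of (None, 1000), matching the summary shown in the model-execution section below.

# Sanity check for the hand-built model (assumes the file is saved as resnet50.py)
from resnet50 import ResNet50

model = ResNet50(input_shape=[224, 224, 3], classes=1000)
model.summary()               # total params should be about 25.6M
print(model.output_shape)     # (None, 1000)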

Model training (resnet50_model_train.py)

This is a lightly modified version of the training code from 基于tensorflow的ResNet50V2网络识别动物 (the earlier ResNet50V2 article).

import os
import pandas as pd

# Model
import keras
from keras.preprocessing.image import ImageDataGenerator

# Callbacks
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Hand-built ResNet50 (resnet50.py above)
import tensorflow as tf
import resnet50

root_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Training Data/'
valid_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Validation Data/'
test_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Testing Data/'

class_names = sorted(os.listdir(root_path))
n_classes = len(class_names)
print(f"Total Number of Classes : {n_classes} \nClass Names : {class_names}")

class_dis = [len(os.listdir(root_path + name)) for name in class_names]

train_gen = ImageDataGenerator(rescale=1/255., rotation_range=10, horizontal_flip=True)
valid_gen = ImageDataGenerator(rescale=1/255.)
test_gen = ImageDataGenerator(rescale=1/255.)

# Load Data
train_ds = train_gen.flow_from_directory(root_path, class_mode='binary', target_size=(224, 224), shuffle=True, batch_size=32)
valid_ds = valid_gen.flow_from_directory(valid_path, class_mode='binary', target_size=(224, 224), shuffle=True, batch_size=32)
test_ds = test_gen.flow_from_directory(test_path, class_mode='binary', target_size=(224, 224), shuffle=True, batch_size=32)

with tf.device("/GPU:0"):
    # Hand-built model
    model = resnet50.ResNet50()
    model.summary()

    model_file = "ResNet50_V1.h5"
    # Resume from previously saved weights if they exist
    if os.path.exists(model_file):
        model.load_weights(model_file)

    # Callbacks
    cbs = [EarlyStopping(patience=5, restore_best_weights=True),
           ModelCheckpoint(model_file, save_best_only=True)]

    # Model
    opt = tf.keras.optimizers.Adam(learning_rate=2e-3)
    model.compile(loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

    # Model Training
    history = model.fit(train_ds, validation_data=valid_ds, callbacks=cbs, epochs=200)

    data = pd.DataFrame(history.history)
    print(data)
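Note that test_ds is created but never used above. If you also want a test-set score after training, a minimal sketch (reusing the model and test_ds objects from the script above) could be appended at the end of the with block:

# Optional: evaluate the trained model on the held-out test set
# (test_loss / test_acc are illustrative variable names)
test_loss, test_acc = model.evaluate(test_ds)
print(f"Test loss: {test_loss:.4f}, test accuracy: {test_acc:.4f}")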

Model execution

The GPU is running at close to full utilization.
[Figure: GPU utilization during training]
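If the GPU does not seem to be used on your machine, you can confirm that TensorFlow actually sees it before starting the training; a minimal check:

# List the GPUs visible to TensorFlow; an empty list means training falls back to the CPU
import tensorflow as tf
print(tf.config.list_physical_devices('GPU'))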

Model: "resnet50"
__________________________________________________________________________________________________
 Layer (type)                               Output Shape          Param #   Connected to
==================================================================================================
 input_1 (InputLayer)                       [(None, 224, 224, 3)] 0         []
 zero_padding2d (ZeroPadding2D)             (None, 230, 230, 3)   0         ['input_1[0][0]']
 conv1 (Conv2D)                             (None, 112, 112, 64)  9472      ['zero_padding2d[0][0]']
 bn_conv1 (BatchNormalization)              (None, 112, 112, 64)  256       ['conv1[0][0]']
 activation (Activation)                    (None, 112, 112, 64)  0         ['bn_conv1[0][0]']
 max_pooling2d (MaxPooling2D)               (None, 55, 55, 64)    0         ['activation[0][0]']
 2a_conv_block_conv1 (Conv2D)               (None, 55, 55, 64)    4160      ['max_pooling2d[0][0]']
 2a_conv_block_bn1 (BatchNormalization)     (None, 55, 55, 64)    256       ['2a_conv_block_conv1[0][0]']
 2a_conv_block_relu1 (Activation)           (None, 55, 55, 64)    0         ['2a_conv_block_bn1[0][0]']
 2a_conv_block_conv2 (Conv2D)               (None, 55, 55, 64)    36928     ['2a_conv_block_relu1[0][0]']
 2a_conv_block_bn2 (BatchNormalization)     (None, 55, 55, 64)    256       ['2a_conv_block_conv2[0][0]']
 2a_conv_block_relu2 (Activation)           (None, 55, 55, 64)    0         ['2a_conv_block_bn2[0][0]']
 2a_conv_block_conv3 (Conv2D)               (None, 55, 55, 256)   16640     ['2a_conv_block_relu2[0][0]']
 2a_conv_block_res_conv (Conv2D)            (None, 55, 55, 256)   16640     ['max_pooling2d[0][0]']
 2a_conv_block_bn3 (BatchNormalization)     (None, 55, 55, 256)   1024      ['2a_conv_block_conv3[0][0]']
 2a_conv_block_res_bn (BatchNormalization)  (None, 55, 55, 256)   1024      ['2a_conv_block_res_conv[0][0]']
 2a_conv_block_add (Add)                    (None, 55, 55, 256)   0         ['2a_conv_block_bn3[0][0]', '2a_conv_block_res_bn[0][0]']
 2a_conv_block_relu4 (Activation)           (None, 55, 55, 256)   0         ['2a_conv_block_add[0][0]']
 2b_identity_block_conv1 (Conv2D)           (None, 55, 55, 64)    16448     ['2a_conv_block_relu4[0][0]']
 2b_identity_block_bn1 (BatchNormalization) (None, 55, 55, 64)    256       ['2b_identity_block_conv1[0][0]']
 2b_identity_block_relu1 (Activation)       (None, 55, 55, 64)    0         ['2b_identity_block_bn1[0][0]']
 2b_identity_block_conv2 (Conv2D)           (None, 55, 55, 64)    36928     ['2b_identity_block_relu1[0][0]']
 ... (middle layers omitted; roughly 50 weight layers in total, hence ResNet50) ...
 5c_identity_block_bn3 (BatchNormalization) (None, 7, 7, 2048)    8192      ['5c_identity_block_conv3[0][0]']
 5c_identity_block_add (Add)                (None, 7, 7, 2048)    0         ['5c_identity_block_bn3[0][0]', '5b_identity_block_relu4[0][0]']
 5c_identity_block_relu4 (Activation)       (None, 7, 7, 2048)    0         ['5c_identity_block_add[0][0]']
 avg_pool (AveragePooling2D)                (None, 1, 1, 2048)    0         ['5c_identity_block_relu4[0][0]']
 flatten (Flatten)                          (None, 2048)          0         ['avg_pool[0][0]']
 fc1000 (Dense)                             (None, 1000)          2049000   ['flatten[0][0]']
==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________


Training results

Epoch 1/200
625/625 [==============================] - 245s 376ms/step - loss: 0.3225 - accuracy: 0.8965 - val_loss: 0.7825 - val_accuracy: 0.7670
Epoch 2/200
625/625 [==============================] - 226s 361ms/step - loss: 0.2998 - accuracy: 0.9018 - val_loss: 0.9110 - val_accuracy: 0.7060
Epoch 3/200
625/625 [==============================] - 223s 357ms/step - loss: 0.2644 - accuracy: 0.9113 - val_loss: 1.2283 - val_accuracy: 0.6760
Epoch 4/200
625/625 [==============================] - 223s 357ms/step - loss: 0.2465 - accuracy: 0.9188 - val_loss: 0.9871 - val_accuracy: 0.7500
Epoch 5/200
625/625 [==============================] - 225s 360ms/step - loss: 0.2307 - accuracy: 0.9234 - val_loss: 1.1059 - val_accuracy: 0.6720
Epoch 6/200
625/625 [==============================] - 228s 365ms/step - loss: 0.2016 - accuracy: 0.9341 - val_loss: 0.5819 - val_accuracy: 0.8370
Epoch 7/200
625/625 [==============================] - 227s 363ms/step - loss: 0.1859 - accuracy: 0.9380 - val_loss: 0.8662 - val_accuracy: 0.7740
Epoch 8/200
625/625 [==============================] - 223s 356ms/step - loss: 0.1732 - accuracy: 0.9419 - val_loss: 0.6927 - val_accuracy: 0.8130
Epoch 9/200
625/625 [==============================] - 221s 353ms/step - loss: 0.1631 - accuracy: 0.9446 - val_loss: 0.7033 - val_accuracy: 0.8090
Epoch 10/200
625/625 [==============================] - 220s 352ms/step - loss: 0.1416 - accuracy: 0.9528 - val_loss: 0.8072 - val_accuracy: 0.8130
Epoch 11/200
625/625 [==============================] - 221s 352ms/step - loss: 0.1403 - accuracy: 0.9524 - val_loss: 0.9740 - val_accuracy: 0.7570

        loss  accuracy  val_loss  val_accuracy
0   0.322520   0.89655  0.782460         0.767
1   0.299801   0.90180  0.910993         0.706
2   0.264406   0.91130  1.228291         0.676
3   0.246451   0.91880  0.987062         0.750
4   0.230691   0.92340  1.105917         0.672
5   0.201550   0.93405  0.581927         0.837
6   0.185873   0.93800  0.866195         0.774
7   0.173195   0.94185  0.692714         0.813
8   0.163083   0.94460  0.703307         0.809
9   0.141572   0.95280  0.807219         0.813
10  0.140329   0.95240  0.974021         0.757
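Since the training script already collects history.history into a pandas DataFrame, plotting the curves takes only a few lines; a minimal sketch, assuming the data DataFrame from the script above:

# Plot training vs. validation curves from the history DataFrame built in the training script
import matplotlib.pyplot as plt

data[['loss', 'val_loss']].plot(title='Loss')
data[['accuracy', 'val_accuracy']].plot(title='Accuracy')
plt.show()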

Analysis of the training results

Judging from the training output, the results are not as good as those of ResNet50V2. I adjusted some of the parameters, but it did not help much. If you are interested, you can modify this network yourself to improve validation accuracy.
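One direction worth trying (untested here, offered only as a suggestion): the training accuracy keeps climbing while the validation loss fluctuates, so lowering the learning rate when val_loss plateaus might help, for example by appending a ReduceLROnPlateau callback to the cbs list in the training script.

# Possible (untested) tweak: halve the learning rate whenever val_loss stops improving
from keras.callbacks import ReduceLROnPlateau

cbs.append(ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-5))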

Model validation

Model validation code

from keras.models import load_model
import tensorflow as tf
from tensorflow.keras.utils import load_img, img_to_array
import numpy as np
import os
import matplotlib.pyplot as plt

root_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Training Data/'
class_names = sorted(os.listdir(root_path))

model = load_model('./ResNet50_V1.h5')
model.summary()


def load_image(path):
    '''This function will load the image present at the given location,
    rescale it to [0, 1] and resize it to 224x224.'''
    image = tf.cast(tf.image.resize(img_to_array(load_img(path)) / 255., (224, 224)), tf.float32)
    return image


i_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Validation Data/Gorilla/Gorilla (3).jpeg'
image = load_image(i_path)

preds = model.predict(image[np.newaxis, ...])[0]
print(preds)
pred_class = class_names[np.argmax(preds)]
confidence_score = np.round(preds[np.argmax(preds)], 2)

# Configure Title
title = f"Pred : {pred_class}\nConfidence : {confidence_score:.2}"
print(title)

plt.figure(figsize=(25, 8))
plt.title(title)
plt.imshow(image)
plt.show()

while True:
    path = input("input:")
    if path == "q!":
        exit()
    image = load_image(path)
    preds = model.predict(image[np.newaxis, ...])[0]
    print(preds)
    pred_class = class_names[np.argmax(preds)]
    confidence_score = np.round(preds[np.argmax(preds)], 2)
    # Configure Title
    title = f"Pred : {pred_class}\nConfidence : {confidence_score:.2}"
    print(title)
    plt.figure(figsize=(25, 8))
    plt.title(title)
    plt.imshow(image)
    plt.show()
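Usage: the script first predicts on the hard-coded gorilla image, printing the raw class probabilities together with the predicted class and confidence, and displays the picture; it then loops, asking for an image path on each iteration, and exits when you type q!.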
