tensorflow fashion_mnist数据集模型训练及预测
创始人
2024-03-16 09:30:19
0

✨ 博客主页:小小马车夫的主页
✨ 所属专栏:Tensorflow

文章目录

  • 前言
  • 一、环境
  • 二、fashion_mnist数据集介绍
  • 三、fashion_mnist数据集下载和展示
  • 四、数据预处理
  • 五、构建模型和训练模型
  • 六、模型预测
  • 总结


前言

前面介绍mnist手写数字集训练,本文对fashion_mnist数据集训练和预测进行简要介绍。


一、环境

MacOS: 13.0
Python: 3.9.13
Tensorflow: 2.11.0

二、fashion_mnist数据集介绍

fashion_mnist数据集和mnist数据集类似,都是28x28的灰度图片,区分是fashion_mnist数据集是服装图片,具体分类如下图:

分类英文描述中文描述
0t-shirtT恤
1trouser牛仔裤
2pullover套衫
3dress裙子
4coat外套
5sandal凉鞋
6shirt衬衫
7sneaker运动鞋
8bag
9ankle boot短靴

三、fashion_mnist数据集下载和展示

运用tensorflow下载fashion_mnist数据集与mnist类似,代码如下:

import tensorflow as tf
from tensorflow import keras
import numpy as npfashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print(train_images.shape, train_labels.shape)
print(test_images.shape, test_labels.shape)

输出:

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)

可以看到训练集是60000张28x28的灰度图,测试集是10000张28x28的灰度图。
一些样例展示如下:
fashion_mnist

四、数据预处理

数据预处理主要是对图片归一化处理,如下:

train_images=train_images / 255.
test_images = test_images / 255.

五、构建模型和训练模型

模型构建

model = keras.Sequential()
model.add(keras.layers.Flatten(input_shape=(28, 28)))
model.add(keras.layers.Dense(128, activation=tf.nn.relu))
model.add(keras.layers.Dense(10, activation=tf.nn.softmax))
model.summary()

模型结构如下:

Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================flatten (Flatten)           (None, 784)               0         dense (Dense)               (None, 128)               100480    dense_1 (Dense)             (None, 10)                1290      =================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________

模型训练

class MyCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs={}):#loss小于0.25就停止训练if logs.get('loss') < 0.25:self.model.stop_training = True
callbacks = MyCallback()
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.sparse_categorical_crossentropy, metrics=['acc'])
h = model.fit(train_images, train_labels, batch_size=32, epochs=15, validation_data=(test_images_scaled, test_labels), callbacks=[callbacks])

查看结果

Epoch 1/15
1875/1875 [==============================] - 11s 5ms/step - loss: 0.5031 - acc: 0.8239 - val_loss: 0.4201 - val_acc: 0.8499
Epoch 2/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.3774 - acc: 0.8648 - val_loss: 0.4333 - val_acc: 0.8482
Epoch 3/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.3371 - acc: 0.8773 - val_loss: 0.3662 - val_acc: 0.8667
Epoch 4/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.3145 - acc: 0.8845 - val_loss: 0.3697 - val_acc: 0.8667
Epoch 5/15
1875/1875 [==============================] - 10s 5ms/step - loss: 0.2929 - acc: 0.8921 - val_loss: 0.3404 - val_acc: 0.8794
Epoch 6/15
1875/1875 [==============================] - 10s 5ms/step - loss: 0.2805 - acc: 0.8958 - val_loss: 0.3453 - val_acc: 0.8793
Epoch 7/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.2683 - acc: 0.9009 - val_loss: 0.3452 - val_acc: 0.8778
Epoch 8/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.2566 - acc: 0.9032 - val_loss: 0.3370 - val_acc: 0.8820
Epoch 9/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.2480 - acc: 0.9065 - val_loss: 0.3482 - val_acc: 0.8789

用图标显示损失曲线和准确率曲线

loss_list = h.history['loss']
acc_list = h.history['acc']
test_loss_list = h.history['val_loss']
test_acc_list = h.history['val_acc']plt.rcParams['font.sans-serif'] = ['Songti SC']
plt.rcParams['axes.unicode_minus'] = Falseplt.figure(figsize=(20, 10))plt.subplot(221)
plt.ylabel('loss')
plt.plot(loss_list, color='blue', marker='.', label='train_loss')
plt.plot(test_loss_list, color='red', marker='.', label='val_loss')
plt.legend(loc='upper left')
plt.title('损失曲线', fontsize=16)plt.subplot(222)
plt.ylabel('acc')
plt.plot(acc_list, color='blue', marker='.', label='train_acc')
plt.plot(test_acc_list, color='red', marker='.', label='val_acc')
plt.legend(loc='upper left')
plt.title('准确率曲线', fontsize=16)
plt.show()

输出:
在这里插入图片描述

六、模型预测

选一个图像进行预测:

image = tf.cast(test_images[1], tf.float32)
image = tf.reshape(image, [1, 28, 28])
np.argmax(model.predict(image))
print(test_labels[1])
plt.imshow(test_images[1])
plt.show()

输出:

1/1 [==============================] - 0s 408ms/step
2

predict

总结

本文主要介绍了tensorflow fashion_mnist的下载、训练、预测,模型用的全连接网络。

如果觉得有些帮助或觉得文章还不错,请关注一下博主,你的关注是我持续写作的动力。另外,如果有什么问题,可以在评论区留言,或者私信博主,博主看到后会第一时间进行回复。
【间歇性的努力和蒙混过日子,都是对之前努力的清零】
欢迎转载,转载请注明出处:https://blog.csdn.net/xxm524/article/details/128160073

相关内容

热门资讯

监控摄像头接入GB28181平... 流程简介将监控摄像头的视频在网站和APP中直播,要解决的几个问题是:1&...
Windows10添加群晖磁盘... 在使用群晖NAS时,我们需要通过本地映射的方式把NAS映射成本地的一块磁盘使用。 通过...
protocol buffer... 目录 目录 什么是protocol buffer 1.protobuf 1.1安装  1.2使用...
在Word、WPS中插入AxM... 引言 我最近需要写一些文章,在排版时发现AxMath插入的公式竟然会导致行间距异常&#...
Fluent中创建监测点 1 概述某些仿真问题,需要创建监测点,用于获取空间定点的数据࿰...
educoder数据结构与算法...                                                   ...
MySQL下载和安装(Wind... 前言:刚换了一台电脑,里面所有东西都需要重新配置,习惯了所...
MFC文件操作  MFC提供了一个文件操作的基类CFile,这个类提供了一个没有缓存的二进制格式的磁盘...
有效的括号 一、题目 给定一个只包括 '(',')','{','}'...
【Ctfer训练计划】——(三... 作者名:Demo不是emo  主页面链接:主页传送门 创作初心ÿ...