Keras实战之图像分类识别

文章目录

    • 整体流程
      • 数据加载与预处理
      • 搭建网络模型
      • 优化网络模型
        • 学习率
        • Drop-out操作
        • 权重初始化方法对比
        • 正则化
        • 加载模型进行测试

实战:利用Keras框架搭建神经网络模型实现基本图像分类识别,使用自己的数据集进行训练测试。

问:为什么选择Keras?
答:使用Keras便捷快速。用起来简单,入门容易,上手快。没有tensorflow那么复杂的规范。

整体流程

  1. 读取数据
  2. 数据预处理
  3. 切分数据集(分为训练集和测试集)
  4. 搭建网络模型(初始化参数)
  5. 训练网络模型
  6. 评估测试模型(通过对比不同参数下损失函数不断优化模型)
  7. 保存模型到本地

(1)手动配置参数,设置数据存储路径、模型保存路径、图片保存路径

# 输入参数,手动设置数据存储路径、模型保存路径、图片保存路径等
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--dataset", required=True,
	help="path to input dataset of images")
ap.add_argument("-m", "--model", required=True,
	help="path to output trained model")
ap.add_argument("-l", "--label-bin", required=True,
	help="path to output label binarizer")
ap.add_argument("-p", "--plot", required=True,
	help="path to output accuracy/loss plot")
args = vars(ap.parse_args())

在这里插入图片描述

数据加载与预处理

# 拿到图像数据路径,方便后续读取
imagePaths = sorted(list(utils_paths.list_images(args["dataset"])))
random.seed(42)
random.shuffle(imagePaths)
# 数据洗牌前设置随机种子确保后面调参过程中训练数据集一样

# 遍历读取数据
for imagePath in imagePaths:
	# 读取图像数据,由于使用神经网络,需要输入数据给定成一维
	image = cv2.imread(imagePath)
	# 而最初获取的图像数据是三维的,则需要将三维数据进行拉长
	image = cv2.resize(image, (32, 32)).flatten()
	data.append(image)

	# 读取标签,通过读取数据存储位置文件夹来判断图片标签
	label = imagePath.split(os.path.sep)[-2]
	labels.append(label)

# scale图像数据,归一化
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)

# 转换标签,one-hot格式
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)

数据预处理:①通过数据除以255进行数据归一化;②对数据标签进行格式转换。

搭建网络模型

  1. 创建序列结构
model = Sequential()
  1. 添加全连接层
  • 第一层全连接层Dense设计512个神经元,当前输入特征个数(输入神经元个数)为3072,设置激活函数为"relu";
  • 第二层设计256个神经元;
  • 第三层设计类别数个神经元(即3个),并作softmax操作得到最终分类类别。
# 第一层
model.add(Dense(512, input_shape=(3072,),activation="relu"))
# 第二层
model.add(Dense(256, activation="relu",))
# 第三层
model.add(Dense(len(lb.classes_), activation="softmax",))
  1. 初始化参数
# 学习率
INIT_LR = 0.01
# 迭代次数
EPOCHS = 200
  1. 训练网络模型
# 给定损失函数和评估方法
opt = SGD(lr=INIT_LR) # 指定优化器为梯度下降的优化器
model.compile(loss="categorical_crossentropy", optimizer=opt,
	metrics=["accuracy"])

# 训练网络模型
H = model.fit(trainX, trainY, validation_data=(testX, testY),
	epochs=EPOCHS, batch_size=32)
  1. 测试网络模型

使用上面训练所得网络模型对测试集进行预测,并对比预测解国和数据集真实结果打印结果报告(包括准确率、recall、f1-score),并将损失函数以折线图的效果直观展示出来

predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1),
	predictions.argmax(axis=1), target_names=lb.classes_))
  1. 评估结果
    在这里插入图片描述
    在这里插入图片描述
    从损失函数图像中可看出,模型出现明显过拟合现象,故而该初始参数所构建的模型效果较差,需要通过调参优化模型。

优化网络模型

学习率

对比学习率为0.01和0.001的损失函数图像。

在这里插入图片描述

train_loss与val_loss之间差异仍然存在,但是可看出学习率越大,过拟合现象越明显。

Drop-out操作

Dropout操作:在搭建网络模型中,通过设置一0到1范围内的参数从而防止过拟合。
在这里插入图片描述

权重初始化方法对比

(1)RandomNormal随机高斯初始化

kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)
model.add(Dense(512, input_shape=(3072,),activation="relu",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))
model.add(Dense(256, activation="relu",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))

在这里插入图片描述
图中可看出,添加RandomNormal初始化后,过拟合现象减弱了一丢丢。

(2)TruncatedNormal截断

kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None)

相比于正常高斯分布截断了两边,只取小于2倍stddev的值

model.add(Dense(512, input_shape=(3072,), activation="relu" ,kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))
model.add(Dense(256, activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))

在这里插入图片描述

对比stddev取不同值时的loss函数图可得,TruncatedNormal中stddev值越小,过拟合风险越低,模型效果越好。TruncatedNormal消除过拟合的效果RandomNormal好。

正则化
kernel_regularizer=regularizers.l2(0.01)

正则化后,损失函数loss = 初始loss + aR(W)。正则化惩罚W,让稳定的W减少过拟合。

model.add(Dense(512, input_shape=(3072,), activation="relu" ,kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(256, activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))

对比正则化前后取迭代150到200的loss波动图,可发现正则化后虽然开始时loss值较大,但后期过拟合现象有明显减弱
在这里插入图片描述
再对比正则化参数l2 = 0.01和0.05的结果可得,l2越大,W的惩罚力度越大,过拟合风险越小
在这里插入图片描述

加载模型进行测试
# 导入所需工具包
from keras.models import load_model
import argparse
import pickle
import cv2

# 设置输入参数
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required=True,
	help="path to input image we are going to classify")
ap.add_argument("-m", "--model", required=True,
	help="path to trained Keras model")
ap.add_argument("-l", "--label-bin", required=True,
	help="path to label binarizer")
ap.add_argument("-w", "--width", type=int, default=28,
	help="target spatial dimension width")
ap.add_argument("-e", "--height", type=int, default=28,
	help="target spatial dimension height")
ap.add_argument("-f", "--flatten", type=int, default=-1,
	help="whether or not we should flatten the image")
args = vars(ap.parse_args())

# 加载测试数据并进行相同预处理操作
image = cv2.imread(args["image"])
output = image.copy()
image = cv2.resize(image, (args["width"], args["height"]))

# scale the pixel values to [0, 1]
image = image.astype("float") / 255.0

# 对图像进行拉平操作
image = image.flatten()
image = image.reshape((1, image.shape[0]))

# 读取模型和标签
print("[INFO] loading network and label binarizer...")
model = load_model(args["model"])
lb = pickle.loads(open(args["label_bin"], "rb").read())

# 预测
preds = model.predict(image)

# 得到预测结果以及其对应的标签
i = preds.argmax(axis=1)[0]
label = lb.classes_[i]

# 在图像中把结果画出来
text = "{}: {:.2f}%".format(label, preds[0][i] * 100)
cv2.putText(output, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
	(0, 0, 255), 2)

# 绘图
cv2.imshow("Image", output)
cv2.waitKey(0)

分类结果:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
通过预测结果可得:该模型在预测猫上存在较大误差,在预测熊猫上较为准确。或许改进增加迭代次数可进一步优化模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/779478.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

RedHat9 | Zabbix-Server监控服务部署

系统版本以及软件版本 使用的系统版本: Red Hat Enterprise Linux release 9.2 软件版本: zabbix-release-7.0-3.el9.noarchzabbix-web-7.0.0-release1.el9.noarchzabbix-web-mysql-7.0.0-release1.el9.noarchzabbix-web-deps-7.0.0-release1.el9.noar…

基于Android Studio订餐管理项目

目录 项目介绍 图片展示 运行环境 获取方式 项目介绍 能够实现登录,注册、首页、订餐、购物车,我的。 用户注册后,登陆客户端即可完成订餐、浏览菜谱等功能,点餐,加入购物车,结算,以及删减…

【电商纯干货分享】干货速看!电商数据集数据API接口数据分析大全!

数据分析——深入探索中小企业数字化转型,专注提供各行业数据分析干货、分析技巧、工具推荐以及各类超实用分析模板,为钻研于数据分析的朋友们加油充电。 公共参数 名称类型必须描述keyString是调用key(必须以GET方式拼接在URL中&#xff09…

Unity编辑器扩展之Inspector面板扩展

内容将会持续更新,有错误的地方欢迎指正,谢谢! Unity编辑器扩展之Inspector面板扩展 TechX 坚持将创新的科技带给世界! 拥有更好的学习体验 —— 不断努力,不断进步,不断探索 TechX —— 心探索、心进取&#xff…

Java--Super

1.super调用父类的构造方法,必须在构造方法的第一个 2.super必须只能出现在子类的方法或者构造方法中 3.super和this不能同时调用构造该方法 和this差别 1.代表的对象不同 this():代指本身调用者这个对象 super(&a…

Docker-文件分层与数据卷挂载详解(附案例)

文章目录 文件分层数据卷挂载的含义数据卷挂载实践数据卷挂载案例数据卷挂载方式数据卷常用命令容器间数据共享 更多相关内容可查看 文件分层 例:拉取mysql5.7的镜像,在继续拉取mysql5.8的镜像,会出现一部分文件已存在的现象 这种分层技术 是…

同花顺问财选股,使用自然语言的形式调接口选股

http形式的接口调用问财接口:https://stockapi.com.cn/v1/base/xuangu?strategy创业板,竞价涨幅大于1,竞价量比大于1 代码中调用该接口调试数据。

cmake编译源码教程(一)

1、介绍 本次博客介绍使用cmake编译平面点云分割的源代码,其对室内点云以及TLS点云中平面结构进行分割,分割效果如下: 2、编译过程 2.1 源代码下载 首先,下载源代码,如下所示,在该文件夹下新建一个build文件夹,用于后续生成sln工程。 同时,由于该库依赖open…

什么是 DDoS 攻击及如何防护DDOS攻击

自进入互联网时代,网络安全问题就一直困扰着用户,尤其是DDOS攻击,一直威胁着用户的业务安全。而高防IP被广泛用于增强网络防护能力。今天我们就来了解下关于DDOS攻击,以及可以防护DDOS攻击的高防IP该如何正确选择使用。 一、什么是…

springboot云南特色民宿预约系统-计算机毕业设计源码81574

目 录 第 1 章 引 言 1.1 选题背景 1.2 选题目的 1.3 论文结构安排 第 2 章 系统的需求分析 2.1 系统可行性分析 2.1.1 技术方面可行性分析 2.1.2 经济方面可行性分析 2.1.3 法律方面可行性分析 2.1.4 操作方面可行性分析 2.2 系统功能需求分析 2.3 系统性需求分析…

文件读写操作之c语言、c++、windows、MFC、Qt

目录 一、前言 二、c语言文件读写 1.写文件 2.读文件 三、c文件读写 1.写文件 2.读文件 四、windows api文件读写 1.写文件 2.读文件 五、MFC文件读写 1.写文件 2.读文件 六、Qt文件读写 1.写文件 2.读文件 七、总结 一、前言 我们在学习过程中&#xff0c…

idea使用技巧---超实用的mybatisX插件

一、使用原因 传统创建mybatis项目之后,在mapper接口和xml映射文件之间手动切换非常麻烦:不仅需要记住文件的所在位置,而且每次在mapper当中添加一个新的接口,都需要单独手动点开xml再编写sql; eg:在item…

前端面试题22(js中sort常见的用法)

JavaScript 的 sort() 方法是数组的一个非常强大的功能,用于对数组的元素进行排序。这个方法直接修改原数组,并返回排序后的数组。sort() 的默认行为是将数组元素转换为字符串,然后按照字符串的 Unicode 字典顺序进行排序。这意味着如果你试图…

优化路由,优化请求url

1、使用父子关系调整下使其更加整洁 2、比如说我修改了下url,那所有的页面都要更改 优化:把这个url抽出来,新建一个Api文件夹用于存放所有接口的url,在业务里只需要关注业务就可以 使用时 导包 发请求 如果想要更改路径,在这里…

docker-compose Install gitlab 17.1.1

gitlab 前言 GitLab 是一个非常流行的开源 DevOps 平台,用于软件开发项目的整个生命周期管理。它提供了从版本控制、持续集成/持续部署(CI/CD)、项目规划到监控和安全的一系列工具。 前提要求 Linux安装 docker docker-compose 参考Windows 10 ,11 2022 docker docker-c…

Zookeeper分布式锁原理说明【简单易理解】

Zookeeper 非公平锁/公平锁/共享锁 。 1.zookeeper分布式锁加锁原理 如上实现方式在并发问题比较严重的情况下,性能会下降的比较厉害,主要原因是,所有的连接都在对同一个节点进行监听,当服务器检测到删除事件时,要通知…

【Kafka】Kafka生产者开启幂等性后报错:Cluster authorization failed.

文章目录 背景解决服务端配置ACL增加授权 背景 用户业务需求,需要开启生产者的幂等性,生产者加了配置:enable.idempotence true用户使用的集群开启了ACL认证:SASL_PLAINTEXT/SCRAM-SHA-512用户生产消息时报错:org.ap…

惕佫酰假托品合酶的发现-文献精读28

Discovering a mitochondrion-localized BAHD acyltransferase involved in calystegine biosynthesis and engineering the production of 3β-tigloyloxytropane 发现一个定位于线粒体的BAHD酰基转移酶,参与打碗花精生物合成,并工程化生产惕佫酰假托品…

Git在多人开发中的常见用例

前言 作为从一个 svn 转过来的 git 前端开发,在经历过git的各种便捷好处后,想起当时懵懂使用git的胆颤心惊:总是害怕用错指令,又或者遇到报错就慌的场景,想起当时查资料一看git指令这么多,看的头晕眼花&am…

Java继承和多态

一.继承 继承顾名思义即一方可以把另一方的东西啊传承到自己手里。 例如猫和狗都是动物。动物都有吃饭,喝水等行为,也有年龄,体重的属性。 那么我们在定义猫和狗的时候就没必要去重复写,而是我们可以定义一个动物类&#xff0c…