deeplearning4j~实现简单模型训练和测试

news/2025/1/9 9:54:03/文章来源:https://www.cnblogs.com/lori/p/18657779

DeepLearning4j (DL4J) 是一个开源的深度学习库,专为 Java 和 Scala 设计。它可以用于构建、训练和部署深度学习模型。以下是关于如何使用 DL4J 的基本指南以及一个简单的模型训练示例。

本例中使用了MNIST数据集,MNIST(modified national institute of standard and technology)数据集是由Yann LeCun及其同事于1994年创建一个大型手写数字数据库(包含0~9十个数字)。MNIST数据集的原始数据来源于美国国家标准和技术研究院(national institute of standard and technology)的两个数据集:special database 1和special database 3。它们分别由NIST的员工和美国高中生手写的0-9的数字组成。原始的这两个数据集由128×128像素的黑白图像组成。LeCun等人将其进行归一化和尺寸调整后得到的是28×28的灰度图像。

DeepLearning4j 使用指南

安装与配置

  1. 环境要求

    • Java Development Kit (JDK) 8 或以上版本
    • Maven(推荐)或 Gradle 用于项目管理
  2. 创建 Maven 项目
    在你的 IDE 中创建一个新的 Maven 项目,并在 pom.xml 文件中添加以下依赖:

    <dependencies><!-- DL4J Core --><dependency><groupId>org.deeplearning4j</groupId><artifactId>deeplearning4j-core</artifactId><version>1.0.0-M1.1</version></dependency><!-- ND4J (Numpy for Java) --><dependency><groupId>org.nd4j</groupId><artifactId>nd4j-native-platform</artifactId><version>1.0.0-M1.1</version></dependency><!-- DataVec for data preprocessing --><dependency><groupId>org.datavec</groupId><artifactId>datavec-api</artifactId><version>1.0.0-M1.1</version></dependency>
    </dependencies>
    
  3. 更新 Maven 依赖
    确保你的 IDE 更新了 Maven 依赖,下载所需的库。

简单的模型训练

下面是一个使用 DL4J 训练简单神经网络的示例,目标是对手写数字进行分类(MNIST 数据集)。

代码示例
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;public class MnistExample {public static void main(String[] args) throws Exception {// 加载 MNIST 数据集DataSetIterator mnistTrain = new MnistDataSetIterator(128, true, 12345);// 配置神经网络MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).updater(new Adam(0.001)).list().layer(0, new DenseLayer.Builder().nIn(784).nOut(256).activation(Activation.RELU).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX).nIn(256).nOut(10).build()).build();// 创建并初始化网络MultiLayerNetwork model = new MultiLayerNetwork(conf);model.init();model.setListeners(new ScoreIterationListener(100));  // 每100次迭代输出一次分数// 训练模型for (int i = 0; i < 10; i++) { // 训练10个epochmodel.fit(mnistTrain);}System.out.println("训练完成!");// 加载 MNIST 测试数据集DataSetIterator mnistTest = new MnistDataSetIterator(128, false, 12345);// 评估模型double accuracy = model.evaluate(mnistTest).accuracy();System.out.println("模型准确率: " + accuracy);// 保存模型到文件File modelFile = new File("mnist_model.zip");ModelSerializer.writeModel(model, modelFile, true);}
}
代码说明
  1. 加载数据集:使用 MnistDataSetIterator 加载 MNIST 数据集。
  2. 配置神经网络
    • 使用 NeuralNetConfiguration.Builder 构建神经网络配置。
    • 添加输入层(DenseLayer)和输出层(OutputLayer)。
  3. 创建和初始化模型:使用 MultiLayerNetwork 创建模型并初始化。
  4. 训练模型:通过循环调用 fit() 方法训练模型。

运行示例

确保你的环境已正确设置,然后运行上述代码。模型将在 MNIST 数据集上进行训练,训练完成后会输出“训练完成!”的信息。

模型评估

在训练完模型后,通常需要对其进行评估,以了解模型在未见数据上的表现。你可以使用测试集来评估模型的准确性和其他性能指标。

保存与加载模型

训练完成后,你可能希望保存模型以便以后使用。DL4J 提供了简单的方法来保存和加载模型。

调整与优化模型

根据评估结果,你可能需要调整模型的超参数或架构。可以尝试以下方法:

  • 增加层数或节点数:增加模型的复杂性。
  • 改变学习率:试验不同的学习率以找到最佳值。
  • 使用不同的激活函数:例如,尝试 LeakyReLUELU
  • 正则化:添加 Dropout 层或 L2 正则化以防止过拟合。

部署模型

如果你打算将模型应用于生产环境,可以考虑将其部署为服务。可以使用以下方式之一:

  • REST API:将模型包装为 RESTful 服务,方便客户端调用。
  • 嵌入式应用:将模型嵌入到 Java 应用程序中,直接进行预测。

模型的测试

使用 Java 和 DeepLearning4j 来训练自己的手写数字图像(例如 0 到 9 的标准图像)是一个很好的项目。下面是一个简单的步骤指南,帮助你实现这个目标。

步骤概述

  1. 准备数据:将你的数字图像准备为合适的格式。
  2. 创建和配置模型:使用 DeepLearning4j 创建神经网络模型。
  3. 训练模型:使用你的图像数据训练模型。
  4. 评估和测试模型:验证模型的性能。

1. 准备数据

首先,你需要将你的 0-9 数字图像准备好。假设你有 10 张图像,每张图像都是 28x28 像素的灰度图像,并且它们存储在本地文件系统中。

模型测试的步骤

步骤 1: 使用 MNIST 数据集训练模型

  1. 加载数据集:使用 MnistDataSetIterator 加载 MNIST 数据集。
  2. 构建模型:根据你的需求,构建一个适合的神经网络模型。
  3. 训练模型:使用 MNIST 数据集对模型进行训练。
  4. 保存模型:将训练好的模型保存到文件中(例如,保存为 .zip 文件)。

步骤 2: 准备手写数字图片

  1. 手写数字:自己手写一个数字 1,并拍照或扫描成图片。
  2. 预处理图片
    • 将图片转换为灰度图像。
    • 调整图片大小为 28x28 像素(MNIST 数据集中的标准尺寸)。
    • 对图像进行归一化处理(通常将像素值缩放到 [0, 1] 范围内)。

步骤 3: 比较手写数字与 MNIST 数据集

  1. 加载保存的模型:从 zip 文件中加载之前训练好的模型。
  2. 预测手写数字:将预处理后的手写数字图片输入到模型中进行预测。
  3. 输出结果:模型将输出手写数字的预测结果。你可以将这个结果与 MNIST 数据集中相应的标签进行比较。

注意事项

  • 数据预处理:确保手写数字的预处理方式与训练时一致,包括图像大小、颜色通道和归一化。
  • 模型评估:在比较之前,可以先在测试集上评估模型的性能,以确保其准确性。
  • 可视化结果:可以通过可视化工具(如 matplotlib)展示手写数字及其预测结果,以便更好地理解模型的表现。

示例代码

以下是一个简单的示例代码框架,展示了如何实现这些步骤
[MnistUtils.java]

/*** @author lind* @date 2025/1/7 14:27* @since 1.0.0*/
public class MnistUtils {/*** 将图像转换为灰度图像** @param original* @return*/private static BufferedImage convertToGrayscale(BufferedImage original) {BufferedImage grayImage = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_BYTE_GRAY);Graphics g = grayImage.getGraphics();g.drawImage(original, 0, 0, null);g.dispose();return grayImage;}/*** 调整图像大小** @param original* @param width* @param height* @return*/private static BufferedImage resizeImage(BufferedImage original, int width, int height) {Image scaledImage = original.getScaledInstance(width, height, Image.SCALE_SMOOTH);BufferedImage resizedImage = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);Graphics2D g2d = resizedImage.createGraphics();g2d.drawImage(scaledImage, 0, 0, null);g2d.dispose();return resizedImage;}/*** 加载图像** @param fileName* @return*/public static INDArray loadGrayImg(String fileName) {try {// 1. 加载图片BufferedImage originalImage = ImageIO.read(new File(fileName));// 2. 转换为灰度图像BufferedImage grayImage = convertToGrayscale(originalImage);// 3. 调整大小为 28x28 像素BufferedImage resizedImage = resizeImage(grayImage, 28, 28);// 4. 进行归一化处理return normalizeImage(resizedImage);} catch (IOException e) {e.printStackTrace();}return null;}/*** 对图像进行归一化处理并生成 INDArray** @param image* @return*/private static INDArray normalizeImage(BufferedImage image) {int width = image.getWidth();int height = image.getHeight();double[] normalizedData = new double[width * height]; // 创建一维数组for (int y = 0; y < height; y++) {for (int x = 0; x < width; x++) {// 获取灰度值(0-255)int grayValue = image.getRGB(x, y) & 0xFF; // 只取灰度部分// 归一化到 [0, 1] 范围normalizedData[y * width + x] = grayValue / 255.0; // 填充一维数组}}// 将一维数组转换为 INDArray,并添加批次维度INDArray indArray = Nd4j.create(normalizedData).reshape(1, 784); // reshape to [1, 784]return indArray;}
}

[MnistTest.java]

public static void main(String[] args) throws IOException {// 加载已训练的模型MultiLayerNetwork model = MultiLayerNetwork.load(new File("E:\\github\\lind-deeplearning4j\\mnist_model.zip"), true);// 测试图像路径String testImagePath = "d:\\dlj4\\img\\";// 假设你有10个测试图像,命名为 0.png 到 9.png,当我从MNIST数据集网站下载9张图片后,这个大模型确实可以给我识别出来for (int i = 0; i <= 3; i++) {String fileName = testImagePath + i + ".png";System.out.println("fileName=" + fileName);INDArray testImage = loadGrayImg(fileName);INDArray output = model.output(testImage); // 进行预测// 获取预测结果int predictedClass = Nd4j.argMax(output, 1).getInt(0);System.out.println("测试图像 " + i + " 的预测结果: " + predictedClass);}}

模型测试结果,它会根据0-3的图片,将图片上面的数字分析出来,这个事实上是根据我们训练的MINIST数据集得到的结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/865367.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

云上攻防-云原生K8s安全实战场景攻击Pod污点Taint横向移动容器逃逸

云原生&K8s安全&实战场景&攻击Pod&污点Taint&横向移动&容器逃逸知识点 1、云原生-K8s安全-横向移动-污点Taint 2、云原生-K8s安全-Kubernetes实战场景 一、演示案例-云原生-K8s安全-横向移动-污点Taint 如何判断实战中能否利用污点Taint? 设置污点 ku…

深入解析 Spring AI 系列:项目结构一览

从今天起,我们将以 Spring AI 为主线,开始更新一系列的文章。这些文章将围绕 Spring AI 项目展开,结合我的理解,深入讲解其相关的知识点、技术原理、以及在实际开发过程中涉及到的部分代码实现等内容。通过这些文章,希望能够帮助大家更好地理解和使用 Spring AI。今天的主…

BUG:SWM32开机绘制lvgl框架下的某个自定义控件死机

一.BUG描述 现象1.画了一个关于"模式"的自定义控件,结果开机绘制总是死机。 现象2.用keil进行仿真调试全速运行同样死机,但是如果在异常处加断点,然后单步调试就正常。(注:仿真调试比直接运行的速度要慢) 现象3.把这个异常对象的创建代码删除,再后面加四个打印…

大镖客2--中文地图

下载GMM 从网站 https://gmm.aoe.top/ 下载GMM软件 双击安装GMM即可 下载前置 下载大镖客2前置 从我网盘下载: 通过网盘分享的文件:荒野大镖客2 Mod前置包.gmm链接: https://pan.baidu.com/s/1IRorQ9FO7Dy_LzU70U_uRw?pwd=6666 提取码: 6666 选择游戏 找到RDR2.exe文件 …

CDS标准视图:测量点数据 I_MeasuringPointData

视图名称:测量点数据 I_MeasuringPointData 视图类型:基础视图 视图代码:点击查看代码 @AbapCatalog.sqlViewName: IMEASPOINTDATA @AbapCatalog.compiler.compareFilter: true @AccessControl.authorizationCheck: #CHECK @EndUserText.label: Measuring Point Data@VDM.vi…

vxe-table 实现 excel 选择两个单元格,拖拽自动识别数字规则并根据规则自动填充新的单元格

vxe-table 实现 excel 选择两个单元格,拖拽自动识别数字规则并根据规则自动填充新的单元格 官网:https://vxetable.cn鼠标按住右下角扩展按钮,当选取一个单元格时,自动将当前内容填充到扩展区域的所有单元格中,如果不希望自动识别数字规则,可以同时按住 ctrl 键可取消值自…

《模态测试、分析与仿真》网络研讨会

模态测试的目的和方法 如何利用有限元仿真数据进行建模和优化测点选择 模态参数识别和验证 模态分析结果如何与有限元仿真结果进行相关以上内容来源于网络,如有侵权,请联系删除!

新开了一家宠物公园,如何营销?

宠物公园可以借助看板软件进行营销,通过明确营销目标、利用看板软件进行任务管理、数据可视化与分析、优化客户互动与品牌建设以及整合供应链与提升运营效率等多个方面提升营销效率和效果。宠物公园可以借助看板软件进行营销,通过数据可视化、任务管理、客户互动等多个方面提…

pkusc/wc 做题记录题解

漆黑空心也想被释放,奈何思念比恨更顽强鲜花所有题目来源:search pku in qoj 参考文章:DaiRuiChens blog:PKUSC2024,云浅s blog:THUSC/PKUSC 2024cyfffs blog:PKUWC2024,realskcs blog:PKUWC2024 Day1,PKUWC2024 Day2ix35s blog:PKUSC2023,aleph1022s blog:PKUSC …

荣耀再续 | 蚁景科技在第136届广交会测试赛中再获团体亚军佳绩

近日,在第136届广交会网络平台测试赛中,蚁景科技凭借卓越表现与精湛技术脱颖而出,斩获团体积分132分,再度荣膺第二名的佳绩,续写荣誉篇章。近日,在备受瞩目的第 136 届广交会网络平台测试赛中,蚁景科技凭借卓越表现与精湛技术脱颖而出。此次比赛由中国网络空间安全教育论…

altium designer中布线如何走圆角.如图所示。我按了shift+空格键。可依旧是走45读的角。麻烦指教啊。急求

按TP键,在PCB Editor_Interactive Routing中,看是否选择了Restrict To 90/45,取消勾选即可更多追问追答追问取消了还是没用 追答shift+空格键要多按几次,是循环切换的转载自: https://zhidao.baidu.com/question/536745523.html

nginx下载安装配置

nginx下载安装配置 环境:win10虚拟机Vmware,在linux(centos7.6)下安装nginx(nginx-1.20.2.tar.gz) 注意:在linux中,不是管理员登录,最好在指令前加上sudo提权,避免权限不够执行失败带来的麻烦,以下命令请都加上sudo,这是一个好习惯。 下载 https://nginx.org/ …