简介高效的 CV 入门指南: 100 行实现 ConvNext 图像分类

简介高效的 CV 入门指南: 100 行实现 ConvNext 图像分类

  • 概述
  • ConvNext
  • 架构创新
  • 参数效率
  • 计算效率
  • 100 行代码实现 ConvNext 图像分类
    • ConvNext tiny
    • ConvNext Base

概述

在当今快速发展的人工智能领域, 计算机视觉 (Computer Vision, CV) 已称为一个关键的研究和应用领域. CV 可以使计算机理解图像和视频内容. CV 的核心目标是模拟和扩展人类的数据额系统功能, 使得机器能从图像或视频中自动提取, 处理, 分析和理解有用信息.

随着深度学习 (Deep Learning) 和神经网络 (Neural Network) 的兴起, 计算机视觉领域已经取得了显著的进步. 这些技术使得计算机能够通过学习大量的图像数据, 来识别和分类对象, 场景和活动. 应用敢为广泛, 从简单的图像分类到复杂的场景理解, 计算机视觉正逐渐成为日常生活和工业应用中不可或缺的一部分.

CV 技术已广泛应用于医疗成像, 自动驾驶汽车, 监控系统, 人脸识别, VR 等领域. 举个栗子: 在医疗领域, CV 能够帮我们诊断疾病, 通过分析医学图像来辅助医生做出更准确的诊断.

ConvNext

ConvNext (Convolutional Network Net Generation), 即下一代卷积神经网络, 是近些年来 CV 领域的一个重要发展. ConvNext 由 Facebook AI Research 提出, 仅仅通过卷积结构就达到了与 Transformer 结构相媲美的 ImageNet Top-1 准确率, 这在近年来以 Transformer 为主导的视觉问题解决趋势中显得尤为突出.

ConvNext

架构创新

ConvNext 网络的设计理念是在保持卷积神经网络 (Convolutional Neural Network, CNN) 核心优势的同时, 引入创新以解决传统 CNN 在处理复杂图像任务时遇到的限制. 这种新架构通过对传统卷积层的优化和改进, 显著提高了性能.

在传统的 CNN 中, 卷积层 (Convolution Layer) 通过滑动窗口对输入图像进行特征提取, 这一过程虽然有效, 但在处理大规模或复杂数据时往往会遇到性能瓶颈. ConvNext 通过引入改进的卷积策略, 如深度可分离卷积 (Depthwise Separable Convolution), 优化了这一过程. 深度可分离卷积将标准句那几分解为两个部分: 深度卷积核逐点卷积, 这样不仅减少了模型的参数数量, 还提高了计算效率.

深度可分离卷积

此外, ConvNext 还采用了更大的卷积核 (Kernel) 和调整后的步长, 以覆盖更广泛的输入特征并减少信息丢失. 这种设计使得 ConvNext 能够更有效地捕捉图像中的细节和上下文信息, 从而提高了模型在复杂视觉任务上的表现.

参数效率

参数效率是衡量神经网络设计优劣的重要指标之一. ConvNext 通过一系列创新设计显著提高了参数效率. 首先, 通过采用深度可分离卷积. ConvNext 大幅减少了模型参数的数量. 在传统卷积中, 参数数量随着输入通道和输出通道输的乘积线性增长. 而深度可分离卷积通过分离这两个过程, 显著降低了参数数量.

其次, ConvNext 在设计时还考虑到了参数的重用性. 通过精心设计的残差连接 (Residual Connect) 和注意力机制 (Attention Mechanism), 使得模型能够在不增加额外参数的情况下, 复用已有的特征表示. 这种设计不仅提高了模型的学习能力, 还进一步优化了参数的使用效率.

计算效率

在提高计算效率方面, ConvNext 采取了多项措施以优化网络结构, 减少计算资源的需求. 首先, 通过使用深度可分离卷积, ConvNext 大幅减低了每层所需的乘法-加法操作 (MACs). 这直接减少了模型的计算负担.

其次, ConvNext 在网络设计中引入了层标准化 (Layer Normalization) 和 GELU(Gaussian Error Linear Unit) 激活函数. 这些设计优化了网络的前向传播过程, 减少了计算过程中的冗余操作. 蹭标准化有助于稳定训练过程, 减少训练时所需的计算资源, 而 GELU 激活函数则提供了飞线性激活则提供了非线性激活的同时保持了较高的计算效率.

GELU 激活函数

最后, ConvNext 通过精简网络结构和优化数据流动路径, 减少了内存访问次数和数据传输量, 以提高模型运行效率. 这使得 ConvNext 在资源受限的设备上也能有效运行.

100 行代码实现 ConvNext 图像分类

ConvNext tiny

import logging
import tensorflow as tf
from load_image import ImageDataGenerator3# 定义超参数
EPOCHS = 20  # 迭代次数
BATCH_SIZE = 16  # 一次训练的样本数目
learning_rate = 5e-5
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # 优化器
loss = tf.losses.BinaryCrossentropy()  # 损失
logging.basicConfig(filename='../model/convnext_tiny/training_log.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')class TrainingLoggingCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs=None):if logs is not None:logging.info(f"Epoch {epoch + 1}/{EPOCHS}")logging.info(f"loss: {logs['loss']} - accuracy: {logs['accuracy']}")logging.info(f"val_loss: {logs['val_loss']} - val_accuracy: {logs['val_accuracy']}")logging.info(f"lr: {self.model.optimizer.lr.numpy()}")class convnext_tiny(tf.keras.Model):def __init__(self):super(convnext_tiny, self).__init__()self.base_model = tf.keras.applications.convnext.ConvNeXtTiny(input_shape=(512, 512, 3), include_top=False, weights="imagenet")self.average_pooling_layer = tf.keras.layers.GlobalAveragePooling2D()self.output_layer = tf.keras.layers.Dense(1, activation="sigmoid")def call(self, inputs):x = self.base_model(inputs)x = self.average_pooling_layer(x)output = self.output_layer(x)return outputdef lr_schedule(epoch, lr):if epoch < 1:return lrelse:return lr * 0.95def main():# 获取数据image_generator = ImageDataGenerator3('../final_dataset-5_turns_chusai/train-metadata.json', batch_size=BATCH_SIZE)# 分割数据集,假设image_generator可以处理分割train_generator, val_generator = image_generator.split_data(test_size=0.2)# 建立模型inception = convnext_tiny()# 调试输出summaryinception.build(input_shape=[None, 512, 512, 3])print(inception.summary())# 配置模型inception.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])# 保存checkpoint = tf.keras.callbacks.ModelCheckpoint("../model/convnext_tiny/convnext_tiny.tf", monitor='val_loss',verbose=1, save_best_only=True, mode='min')# 学习率调度lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_schedule)early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',  # Monitor the validation losspatience=2,  # Number of epochs with no improvement after which training will be stoppedverbose=1,restore_best_weights=True# Whether to restore model weights from the epoch with the best value of the monitored quantity)# 训练inception.fit(train_generator, validation_data=val_generator, epochs=EPOCHS,callbacks=[checkpoint, lr_scheduler, TrainingLoggingCallback(), early_stopping])if __name__ == '__main__':main()

输出结果:

Model: "convnext_tiny"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================convnext_tiny (Functional)  (None, 16, 16, 768)       27820128  global_average_pooling2d (  multiple                  0         GlobalAveragePooling2D)                                         dense (Dense)               multiple                  769       =================================================================
Total params: 27820897 (106.13 MB)
Trainable params: 27820897 (106.13 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
None

ConvNext Base

import logging
import numpy as np
import tensorflow as tf
from load_image import ImageDataGenerator3# 定义超参数
EPOCHS = 20  # 迭代次数
BATCH_SIZE = 8  # 一次训练的样本数目
learning_rate = 5e-6
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # 优化器
loss = tf.losses.BinaryCrossentropy()  # 损失
logging.basicConfig(filename='../model/convnext_base/training_log.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')class TrainingLoggingCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs=None):if logs is not None:logging.info(f"Epoch {epoch + 1}/{EPOCHS}")logging.info(f"loss: {logs['loss']} - accuracy: {logs['accuracy']}")logging.info(f"val_loss: {logs['val_loss']} - val_accuracy: {logs['val_accuracy']}")logging.info(f"lr: {self.model.optimizer.lr.numpy()}")class convnext_base(tf.keras.Model):def __init__(self):super(convnext_base, self).__init__()self.base_model = tf.keras.applications.convnext.ConvNeXtBase(input_shape=(512, 512, 3), include_top=False, weights="imagenet")self.average_pooling_layer = tf.keras.layers.GlobalAveragePooling2D()self.output_layer = tf.keras.layers.Dense(1, activation="sigmoid")def call(self, inputs):x = self.base_model(inputs)x = self.average_pooling_layer(x)output = self.output_layer(x)return outputdef lr_schedule(epoch, lr):if epoch < 1:return lrelse:return lr * 0.95def main():# 获取数据image_generator = ImageDataGenerator3('../final_dataset-5_turns_chusai/train-metadata.json', batch_size=BATCH_SIZE)# 分割数据集,假设image_generator可以处理分割train_generator, val_generator = image_generator.split_data(test_size=0.125)# 建立模型inception = convnext_base()# 调试输出summaryinception.build(input_shape=[None, 512, 512, 3])print(inception.summary())# 配置模型inception.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])# 保存checkpoint = tf.keras.callbacks.ModelCheckpoint("../model/convnext_base/convnext_base.tf", monitor='val_loss',verbose=1, save_best_only=True, mode='min')# 学习率调度lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_schedule)early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',  # Monitor the validation losspatience=2,  # Number of epochs with no improvement after which training will be stoppedverbose=1,restore_best_weights=True# Whether to restore model weights from the epoch with the best value of the monitored quantity)# 训练inception.fit(train_generator, validation_data=val_generator, epochs=EPOCHS,callbacks=[checkpoint, lr_scheduler, TrainingLoggingCallback(), early_stopping])if __name__ == '__main__':main()

输出结果:

Model: "convnext_base"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================convnext_base (Functional)  (None, 16, 16, 1024)      87566464  global_average_pooling2d (  multiple                  0         GlobalAveragePooling2D)                                         dense (Dense)               multiple                  1025      =================================================================
Total params: 87567489 (334.04 MB)
Trainable params: 87567489 (334.04 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
None

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/479390.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

学习鸿蒙基础(4)

1.条件渲染 ArkTS提供了渲染控制的能力。条件渲染可根据应用的不同状态&#xff0c;使用if、else和else if渲染对应状态下的UI内容。 当if、else if后跟随的状态判断中使用的状态变量值变化时&#xff0c;条件渲染语句会进行更新。。 Entry Component struct PageIfElse {Stat…

统信操作系统下数据库管理利器

PL/SQL是一款荷兰公司开发的数据库管理软件&#xff0c;尽管只支持Oracle一种数据库&#xff0c;但是在这一种数据库的支持上深度耕耘了30年&#xff0c;做到了Oracle管理的极致&#xff0c;从而拥有量海量的用户。 当然&#xff0c;随着时间的推移&#xff0c;PL/SQL也出现了一…

鸿蒙-基于ArkTS声明式开发的简易备忘录,适合新人学习,可用于大作业

本文地址&#xff1a;https://blog.csdn.net/qq_40785165/article/details/136161182?spm1001.2014.3001.5502&#xff0c;转载请附上此链接 大家好&#xff0c;我是小黑&#xff0c;一个还没秃头的程序员~~~ 不知不觉已经有很长一段时间没有分享过自己写的东西了&#xff0…

QPainter绘图与QPen、QFont及QBrush详细用法

一.基本绘图属性&#xff1a; 1.基本绘图类&#xff1a; Qt的绘图功能基于QPainter、QPainterDevice和QPainterEngine三个类。QPainter类在窗口和其他绘制设备上执行低级绘制。它提供高度优化的功能来完成大多数图形用户界面程序所需的工作&#xff0c;包括绘制基本的点、线、…

阿里云 OSS

阿里云对象存储服务&#xff08;Object Storage Service&#xff0c;简称 OSS&#xff09; OSS 为 Object Storage Service&#xff0c;即对象存储服务。是阿里云提供的海量、安全、低成本、高可靠的云存储服务。 OSS 具有与平台无关的 RESTful API 接口&#xff0c;可以在任…

深度学习(16)--基于经典网络架构resnet训练图像分类模型

目录 一.项目介绍 二.项目流程详解 2.1.引入所需的工具包 2.2.数据读取和预处理 2.3.加载resnet152模型 2.4.初始化模型 2.5.设置需要更新的参数 2.6.训练模块设置 2.7.再次训练所有层 2.8.测试网络效果 三.完整代码 一.项目介绍 使用PyTorch工具包调用经典网络架构…

vue3项目配置按需自动导入API组件unplugin-auto-import

场景应用&#xff1a;避免写一大堆的import&#xff0c;比如关于Vue和Vue Router的 1、安装unplugin-auto-import npm i -D unplugin-auto-import 2、配置vite.config import AutoImport from unplugin-auto-import/vite//按需自动加载API插件 AutoImport({ imports: ["…

项目发布前如何打tag标签及标签命名规范

项目发布前如何打tag标签及标签命名规范 1.问题背景 我们知道git分支可以理解为一系列提交流水组成的线&#xff0c;如果我们开发的项目到了一个比较重要的阶段&#xff0c;比如项目发布上线&#xff0c;处于方便后期代码的追溯和维护的考虑&#xff0c;如何在繁杂的日志提交…

一般小红书种草达人多少钱,投放注意事项

在互联网时代&#xff0c;小红书成为了广大消费者了解、评价和分享美妆、服饰、生活方式等方面的平台之一。平台上诸多用户的种草帖&#xff0c;已经成为了很多人发现新品、了解产品真实情况的重要渠道。同时众多品牌也纷纷加入了进来&#xff0c;今天我们和大家来分享下一般小…

力扣面试150 验证回文串 双指针 Character API

Problem: 125. 验证回文串 文章目录 思路复杂度Code 思路 &#x1f468;‍&#x1f3eb; 参考题解 Character.isLetterorDigit(char c)&#xff1a;判读字符 c 是否是字母或者数字 Character.toLowerCase(char c)&#xff1a;将字符 c 转换为小写字母 复杂度 时间复杂度: …

51_蓝桥杯_数码管静态显示

一 电路 二 数码管静态显示工作原理 三 代码 代码1 实现第一个数码管显示数字6的功能 #include "reg52.h"unsigned char code SMG_Duanma[18] {0xc0,0xf9.0xa4,0x99,0x92,0x82,0xf8,0x80,0x90,0x88,0x80,0xc0,0x86,0x8e,0xbf,0x7f};void Delay(unsignde int t) {wh…

【微服务生态】Docker

文章目录 一、基础篇1. 简介2. 下载与安装3. 常用命令3.1 帮助启动类3.2 镜像命令3.3 容器命令 4. Docker 容器数据券5. Docker 镜像5.1 commit 生成镜像5.2 Docker Registry5.3 发布镜像 6. Docker 常规安装软件 二、高级篇1. Dockerfile1.1 概述1.2 基础知识1.3 Dockerfile常…