TensorFlow域对抗训练DANN神经网络分析MNIST与Blobs数据集梯度反转层提升目标域适应能力可视化

news/2025/2/22 16:18:05/文章来源:https://www.cnblogs.com/tecdat/p/18730980

全文链接:https://tecdat.cn/?p=39656

原文出处:拓端数据部落公众号

本文围绕基于TensorFlow实现的神经网络对抗训练域适应方法展开研究。详细介绍了梯度反转层的原理与实现,通过MNIST和Blobs等数据集进行实验,对比了不同训练方式(仅源域训练、域对抗训练等)下的分类性能。结果表明,域对抗训练能够有效提升模型在目标域上的适应能力,为解决无监督域适应问题提供了一种有效的途径。

 

在机器学习和深度学习领域,域适应是一个重要的研究方向。不同数据源(即不同域)之间往往存在分布差异,这使得在一个域上训练的模型在另一个域上的性能显著下降。“Unsupervised Domain Adaptation by Backpropagation” 论文提出了一种简单有效的方法,通过随机梯度下降(SGD)和梯度反转层来实现域适应。后续的 “Domain - Adversarial Training of Neural Networks” 对该工作进行了详细阐述和扩展。

梯度反转层

梯度反转层是实现域对抗训练的关键。

 
  1.  
    # 反转 x 关于 y 的梯度,并按 l 进行缩放(默认为 1.0)
  2.  
    y = flip_gradient(x, l)
 
MNIST
构建MNIST - M数据集
实验结果对比

以下是大致的结果:

Blobs - DANN
Blob数据集
 
  1.  
     
  2.  
    # 绘制数据集
  3.  
    plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap='coolwarm', alpha=0.4)
  4.  
    plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap='cool', alpha=0.4)
  5.  
    plt.show()
 

Blob数据集可视化

构建模型

不同训练方式的实验
  • 域分类:设置 grad_scale=-1.0 可以有效关闭梯度反转。仅训练域分类器会创建使类别合并的表示。
 
  1.  
     
  2.  
    train_loss = sess.graph.get_tensor_by_name(train_loss_name + ':0')
  3.  
    train_op = sess.graph.get_operation_by_name(train_op_name)
  4.  
    sess.run(tf.global_variables_initializer())
  5.  
    for i in range(num_batches):
  6.  
    if grad_scale is None:
 
不同训练方式的实验
  • 域分类
 
  1.  
     
  2.  
    F = sess.graph.get_tensor_by_name(feat_tensor_name + ':0')
  3.  
    emb_s = sess.run(F, feed_dict={'X:0': Xs})
  4.  
    emb_t = sess.run(F, feed_dict={'X:0': Xt})
  5.  
    emb_all = np.vstack([emb_s, emb_t])
  6.  
    pca = PCA(n_components=2)
  7.  
    pca_emb = pca.fit_transform(emb_all)
  8.  
    num = pca_emb.shape[0] // 2
  9.  
    plt.scatter(pca_emb[:num, 0], pca_emb[:num, 1], c=ys, cmap='coolwarm', alpha=0.4)
  10.  
    plt.scatter(pca_emb[num:, 0], pca_emb[num:, 1], c=yt, cmap='cool', alpha=0.4)
  11.  
    plt.show()
  12.  
    train_and_evaluate(sess, 'domain_train_op', 'domain_loss', grad_scale=-1.0, verbose=False)
  13.  
    extract_and_plot_pca_feats(sess)
 

运行结果如下:

域分类PCA特征可视化
从结果可以看出,仅训练域分类器时,模型能够很好地区分源域和目标域,但对类别的区分能力较差,这表明这种训练方式创建的表示使类别合并了。

  • 标签分类

运行结果如下:

标签分类PCA特征可视化
在源域上进行标签预测训练时,模型在源域上能够很好地区分不同类别,但在目标域上的类别区分能力较差,说明这种训练方式对目标域的适应能力不足。

  • 域适应

运行结果如下:

域适应PCA特征可视化
使用域对抗损失进行训练时,模型在源域和目标域上的类别分类准确率都较高,说明域对抗训练能够有效提升模型在目标域上的适应能力。

  • 更深的域分类器的域适应

运行结果如下:

更深域分类器的域适应PCA特征可视化
使用更深的域分类器进行域适应训练时,在多次实验中似乎更能可靠地合并域,同时保持较高的类别分类准确率。

MNIST - DANN

数据处理

在数据处理阶段,我们对MNIST和MNIST - M数据集进行了预处理。对于MNIST数据,将其转换为适合卷积神经网络输入的格式,并扩展为三通道图像。MNIST - M数据则直接从之前生成的 pkl 文件中加载。通过计算像素均值,我们对数据进行归一化处理,这有助于提高模型的训练效果。最后,创建了一个混合数据集用于后续的TSNE可视化,方便我们直观地观察模型在不同域上的特征分布情况。

数据可视化

MNIST训练数据可视化
MNIST - M训练数据可视化
通过 函数对MNIST和MNIST - M的训练数据进行可视化展示,我们可以直观地看到两个数据集之间的差异,这也体现了域适应问题的挑战性,即不同域之间的数据分布存在明显差异。

构建模型
 
  1.  
     
  2.  
    # 特征提取器 - CNN模型
  3.  
     
  4.  
    b_conv1 = bias_variable([48])
  5.  
    h_conv1 = tf.nn.relu(conv2d(h_pool0, W_conv1) + b_conv1)
  6.  
    h_pool1 = max_pool_2x2(h_conv1)
  7.  
    self.feature = tf.reshape(h_pool1, [-1, 7 * 7 * 48])
  8.  
    # 标签预测器 - MLP模型
  9.  
    with tf.variable_scope('label_predictor'):
  10.  
     
  11.  
    W_fc2 = weight_variable([100, 10])
  12.  
    b_fc2 = bias_variable([10])
  13.  
    logits = tf.matmul(h_fc1, W_fc2) + b_fc2
  14.  
    self.pred = tf.nn.softmax(logits)
  15.  
    self.pred_loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=self.classify_labels)
  16.  
    # 域预测器 - 小MLP模型,带有对抗损失
  17.  
     
  18.  
    d_b_fc1 = bias_variable([2])
  19.  
    d_logits = tf.matmul(d_h_fc0, d_W_fc1) + d_b_fc1
  20.  
    self.domain_pred = tf.nn.softmax(d_logits)
  21.  
    self.domain_loss = tf.nn.softmax_cross_entropy_with_logits(logits=d_logits, labels=self.domain)
 

该模型主要由三个部分组成:特征提取器、标签预测器和域预测器。特征提取器使用卷积神经网络(CNN)从输入图像中提取特征;标签预测器是一个多层感知机(MLP),用于对图像的类别进行预测;域预测器同样是一个MLP,用于判断输入数据来自源域还是目标域。在域预测器中,使用了梯度反转层 flip_gradient 来实现对抗训练,使得特征提取器学习到的特征能够在不同域之间具有不变性。

模型训练与评估

上述代码实现了两种训练模式:仅在源域上训练(source)和使用域对抗训练(dann)。在训练过程中,根据论文中的方法动态调整适应参数 l 和学习率 lr
运行结果如下:

从结果可以看出,仅在源域上训练时,模型在源域(MNIST)上有较高的准确率,但在目标域(MNIST - M)上的准确率较低,说明模型对目标域的适应能力较差。而使用域对抗训练后,虽然源域的准确率略有下降,但目标域的准确率有了显著提升,表明域对抗训练有效地提高了模型在不同域之间的泛化能力。

特征可视化
 
  1.  
     
  2.  
    plot_embedding(dann_tsne
 



通过t - 分布随机邻域嵌入(t - SNE)方法将高维特征映射到二维空间进行可视化。从可视化结果可以直观地看到,仅在源域上训练时,源域和目标域的数据在特征空间中分离明显,说明模型没有学习到域不变的特征。而使用域对抗训练后,源域和目标域的数据在特征空间中更加接近,表明模型学习到了更具泛化性的特征,能够更好地适应不同的域。

结论

本文详细介绍了基于TensorFlow实现的神经网络对抗训练域适应方法。通过梯度反转层和域对抗训练,模型能够学习到域不变的特征,从而提高在目标域上的分类性能。在MNIST和Blobs数据集上的实验结果表明,域对抗训练相比于仅在源域上训练,能够显著提升模型在目标域上的准确率。同时,通过特征可视化可以直观地观察到域对抗训练对特征分布的影响,进一步验证了该方法的有效性。未来的研究可以考虑在更复杂的数据集和任务上应用该方法,以及探索如何进一步优化域对抗训练的效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/888046.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【专题】2025年我国机器人产业发展形势展望:人形机器人量产及商业化关键挑战报告汇总PDF洞察(附原数据表)

原文链接:https://tecdat.cn/?p=39668 机器人已广泛融入我们生活的方方面面。在工业领域,它们宛如不知疲倦的工匠,精准地完成打磨、焊接等精细工作,极大提升了生产效率和产品质量;在日常生活里,它们是贴心的助手,扫地机器人默默清扫房间,陪伴机器人给予老人孩子温暖陪…

vba主动着色

原来的条件格式效率太低,改为主动方式着色 Sub SetColor() On Error Resume Next Dim hang As Integer 行数 Dim lie As Integer Dim IsBuy As Boolean Dim IsSell As Boolean hang = ActiveSheet.UsedRange.Rows.Count With ActiveSheet …

2025省选模拟13

2025省选模拟13\(T1\) P1025. Easy Problem \(40pts\)部分分\(40pts\)设 \(f_{i,j}\) 表示 \(p_{3j}=i\) 时 \([1,i]\) 对答案的贡献,状态转移方程为 \(f_{i,j}=\max\limits_{k=3(j-1)}^{i-3} \{ f_{k,j-1}+w(k+1,i) \}\) ,其中 \(w(k+1,i)\) 表示 \([k+1,i]\) 的次大值。 设…

installerX还你一个清爽的安装

相信大家都有被手机自带的软件安装器折磨的情况,各种禁止安装,这种验证和识别,不开启安全模式和开了没区别,针对这种情况有没有什么办法绕过呢? 我们可以使用开源软件installerX,这款软件使用拥有这类原生的安装体验,安装速度也不差,并且简洁高效,还可以进行降级安装。…

[Paper Writting] 论文画图指南

目录Motivation方法概念图新老对比类方法简图类实物示意图效果示意图Architecture Motivation 方法概念图 HPT新老对比类 OSXMOTRUniADMulti-modal 3D Human Pose Estimation方法简图类 MoCoconformerBEVFormerDETRDriveVLM实物示意图 emg2pose效果示意图 umetracktransmvshoid…

不到24小时,AOne让全员用上DeepSeek的秘诀是……

DeepSeek引发新一轮AI浪潮,面对企业数字化智能升级与数据安全红线的急迫需求,IT负责人的压力山大!如何在24小时内实现全员AI落地,同时为后续安全部署铺平道路?Step1:一键开启全员智能时代 基于国产大模型领军者DeepSeek(671B满血版&70B版),天翼云AOne搭载智能引擎…

Unity Addresable打包总结第一弹

前言 使用AB包很久了,一直没有机会做一个系统的总结,趁现在准备离职,时间空闲比较多,将项目内的Addresable使用经验大致的分析总结一下,以作日后备用。 使用介绍 下方的引用链接中,发哥已经总结的很详细了,但我这里还是稍微介绍一下基本流程。 基本流程在Package Manage…

AutoCAD 逆向工程中 Shx 字体文件解析

数据格式相关的文章代码实现 https://blog.csdn.net/qq_29830577/article/details/78604983#####愿你一寸一寸地攻城略地,一点一点地焕然一新#####

golang学习笔记——gorm

gen是gorm官方推出的一个GORM代码生成工具 官方文档:https://gorm.io/zh_CN/gen/ 1.使用gen框架生成model和dao 安装gorm gengo get -u gorm.io/gen假设有如下用户表CREATE TABLE user (`id` bigint unsigned NOT NULL AUTO_INCREMENT COMMENT 主键,`username` varchar(1…

原神

oj.hailiangedu.com/file/22/dragon.gif

平衡树从启蒙到入土

首先得承认伊德利拉美貌盖世无双将数列改成数后处理起来更舒服 什么是平衡树 更广泛的定义:左右子树高度不超过 1 的 如果将这东西和二叉搜索树结合,便是平衡树搜索树 平衡树分类:treap 随机 splay 贪心 fhq 合并 分裂fhq 实现 合并 给出两个树,根分别为 a、b,如果我们将 …