EMA训练微调

就是取前几个epoch的weight的平均值,可以缓解微调时的灾难性遗忘(因为新数据引导,模型权重逐渐,偏离训练时学到的数据分布,忘记之前学好的先验知识)
在这里插入图片描述

class EMA():def __init__(self, model, decay):self.model = modelself.decay = decay  # decay rateself.shadow = {}  # old weightself.backup = {}  # new weightdef register(self):  # deep copy weight for initfor name, param in self.model.named_parameters():if param.requires_grad:self.shadow[name] = param.data.clone()def update(self):  # ema:average weight for trainfor name, param in self.model.named_parameters():if param.requires_grad:assert name in self.shadownew_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]self.shadow[name] = new_average.clone()def apply_shadow(self):  # load old weight for eval beginfor name, param in self.model.named_parameters():if param.requires_grad:assert name in self.shadowself.backup[name] = param.dataparam.data = self.shadow[name]def restore(self):  # load new weight for eval endfor name, param in self.model.named_parameters():if param.requires_grad:assert name in self.backupparam.data = self.backup[name]self.backup = {}# 初始化
ema = EMA(model, 0.999)
ema.register()# 训练过程中,更新完参数后,同步update shadow weights
def train():optimizer.step()ema.update()# eval前,apply shadow weights;eval之后,恢复原来模型的参数
def evaluate():ema.apply_shadow()# evaluateema.restore()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/232197.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TikTok区块链实践:数字社交媒体的去中心化未来

随着区块链技术的日渐成熟,数字社交媒体行业也在探索如何整合区块链,以推动去中心化发展。在这一潮流中,TikTok作为全球领先的短视频平台,积极实践区块链技术,探索数字社交媒体的未来。本文将深入探讨TikTok的区块链实…

功能全面又强大的同步备份软件,你找到了吗?

随着企业规模的不断扩大,许多企业都会拥有自己的数据中心。因此每日员工都需要在服务器与服务中心之间调取文件,同时还需要对每日新增的业务数据进行实时同步。如果量比较小,一般问题不大;一旦数据比较大,量也比较大&a…

箭头函数与普通函数:谁更胜一筹?

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

酷开科技:让体育迷的热情释放,让运动精神传递

在繁忙的生活节奏中,我们总是被各种琐事所困扰,很难抽出时间去享受运动带来的快乐,甚至很少有时间去观看一场体育赛事。而一场好的体育赛事带给体育爱好者的快乐往往来自于两方面,一是线下参与,感受现场带来的震撼&…

docker搭建node环境开发服务器

docker搭建node环境开发服务器 本文章是我自己搭建node环境开发服务器的过程记录,不一定完全适用所有人。根据个人情况,按需取用。 命名项目路径 为了方便cd到项目路径,将项目路径重命名,方便输入。 vim /etc/profile # 修改p…

R语言gWQS包在加权分位数和回归模型的应用

在流行病学研究中,相较于单一因素的暴露,多因素同时暴露的情况更为常见。传统模型在评价多因素联合暴露时存在数据维度高、多重共线性等问题. WQS 回归模型的基本原理是通过分位数间距及加权的方法,将多种研究因素的效应综合成为一个指数&…

7-22 龟兔赛跑

import java.util.Scanner; class Main {public static void main(String[] args) {Scanner scnew Scanner(System.in);int timesc.nextInt();sc.close();int wugui 0;//乌龟里程int tuzi 0;//兔子里程int tuzi_run0;int tuzi_rest0;int is_rest0;//是否需要休息:…

visual studio 2022 更改字体和大小

工具--->选项 文本编辑器 输出窗口

【Hadoop】集群资源管理器 YARN

一、yarn 简介 Apache YARN (Yet Another Resource Negotiator) 是 hadoop 2.x 引入的分布式资源管理系统。主要用于解决 hadoop 1.x 架构中集群资源管理和数据计算耦合在一起,导致维护成本越来越高的问题。 yarn主要负责管理集群中的CPU和内存 用户可以将各种服…

JS获取字符串里最长的回文字符串

方法一 使用双指针配合枚举 /*** param {string} s* return {string}*/ const longestPalindrome s > {const LEN s.lengthif (LEN < 2) {return s}let maxStr /*** param left * param right * returns */const findPalindrome (left, right) > {while (left &…

服务器数据恢复—V7000存储raid5崩溃导致上层卷无法使用的数据恢复案例

服务器数据恢复环境&#xff1a; 某品牌V7000存储中有一组由几十块硬盘组建的raid5阵列。上层操作系统为windows server&#xff0c;NTFS分区。 服务器故障&#xff1a; 有一块硬盘出现故障离线&#xff0c;热备盘自动上线替换离线硬盘。在热备盘上线同步数据的过程&#xff0c…

视频智能分析国标GB28181云平台EasyCVR加密机授权异常是什么原因?

国标GB28181视频汇聚/视频云存储/集中存储/视频监控管理平台EasyCVR能在复杂的网络环境中&#xff0c;将分散的各类视频资源进行统一汇聚、整合、集中管理&#xff0c;实现视频资源的鉴权管理、按需调阅、全网分发、云存储、智能分析等。 近期有用户选择使用加密机进行EasyCVR授…