基于lora技术微调Gemma(2B)代码实践

一、前置条件

获得模型访问权,选择Colab运行时,配置训练环境。

先在Kaggle上注册,然后获得Gemma 2B 的访问权;

然后在Google colab 配置环境,主要是GPU的选择,免费的是T4,建议采用付费的A100(为了节省时间,微调训练耗时T4需要30分钟左右,A100只需要2分钟左右)

最后 在Kaggle 上的account上生成令牌文件(主要是usename 和 API Key),并将令牌文件配置到colab环境。

二、微调步骤

三、源码

# -*- coding: utf-8 -*-
"""gemma-lora微调.ipynbAutomatically generated by Colaboratory.Original file is located athttps://colab.research.google.com/drive/1_uEbuYP-vk0tCO0EA7IQ1t85jJ6ne7Qi
"""from google.colab import files
uploaded = files.upload()!mkdir ~/.kaggle
!mv kaggle.json ~/.kaggle/!chmod 600 ~/.kaggle/kaggle.json# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U keras>=3import osos.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"import keras
import keras_nlp!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonlimport json
data = []
with open("databricks-dolly-15k.jsonl") as file:for line in file:features = json.loads(line)# Filter out examples with context, to keep it simple.if features["context"]:continue# Format the entire example as a single string.template = "Instruction:\n{instruction}\n\nResponse:\n{response}"data.append(template.format(**features))# Only use 1000 training examples, to keep it fast.
data = data[:1000]gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()prompt = template.format(instruction="What should I do on a trip to Europe?",response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))prompt = template.format(instruction="Explain the process of photosynthesis in a child could understand.",response="",
)
print(gemma_lm.generate(prompt,max_length=256))# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(learning_rate=5e-5,weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])gemma_lm.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),optimizer=optimizer,weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)prompt = template.format(instruction="What should I do on a trip to Europe?",response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))prompt = template.format(instruction="Explain the process of photosynthesis in a way that a child could understand.",response="",
)
print(gemma_lm.generate(prompt, max_length=256))

微调关键代码--A100

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/589628.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++——list类及其模拟实现

前言&#xff1a;这篇文章我们继续进行C容器类的分享——list&#xff0c;也就是数据结构中的链表&#xff0c;而且是带头双向循环链表。 一.基本框架 namespace Mylist {template<class T>//定义节点struct ListNode{ListNode<T>* _next;ListNode<T>* _pre…

【重学C语言】三、C语言最简单的程序

【重学C语言】三、C语言最简单的程序 最简单的程序头文件使用尖括号 < >使用双引号 ""区别与注意事项示例 主函数认识三个错误 常量和变量常量ASCII 码表转义字符 关键字数据类型关键字存储类关键字修饰符关键字控制流程关键字函数相关关键字其他关键字 变量变…

Linux 恶意软件“Migo”针对 Redis 进行加密劫持攻击

安全研究人员遇到了一种新的加密劫持活动&#xff0c;该活动使用一种名为 Migo 的新恶意软件&#xff0c;该恶意软件针对 Linux 主机上的 Redis 服务器。在 Cado Security 研究人员注意到在野外利用 Redis 系统的新命令后&#xff0c;该活动曝光了。 初始访问 根据 Cado secu…

RUST语言基本数据类型认识

1.RUST的基本数据类型参考: 2.使用RUST数据类型声明变量并赋值: let a:i8=1;//8位有符号整数let a1:u8=2;//8位无符号整数let b:i16=1;//16位有符号整数let b1:u16=2;//16位无符号整数let c:i32=1;//32位有符号整数let c1:u32=2;//32位无符号整数let d:i64=1;//64位有符号整数l…

公众号爆文策略与实践:揭秘千万阅读量的秘密

1. 引言 介绍公众号爆文的重要性&#xff0c;以及分享个人通过每天投入半小时赚到30倍门票的经验。强调跟上大佬步伐&#xff0c;提升认知的重要性。 2. 爆文的底层逻辑 2.1 推荐的底层逻辑 内容分发机制的变化&#xff0c;从仅限于直接关注到通过搜索、浏览推荐等多种方式…

【项目实战经验】DataKit迁移MySQL到openGauss(上)

前言 本文将分享DataKit迁移MySQL到openGauss的项目实战&#xff0c;供广大openGauss爱好者参考。 1. 下载操作系统 https://www.openeuler.org/zh/download https://support.huawei.com/enterprise/zh/doc/EDOC1100332931/1a643956 https://support.huawei.com/enterprise…

封装一个vue3的公共组件

在Vue 3中&#xff0c;封装公共组件的场景包括但不限于以下几种情况&#xff1a; 重复使用的组件&#xff1a;如果你发现某个组件在多个地方重复使用&#xff0c;那么将其封装成公共组件是很有意义的。比如&#xff0c;页面中的各种表单控件&#xff08;输入框、下拉框、日期选…

硬件开发文档规范

本文出发点&#xff1a; 一般来说&#xff0c;越是大公司越注重开发文档的规范性&#xff0c;因为这样最大的好处是能够保证开发的连贯性&#xff0c;也就是即使有员工离职了&#xff0c;只要开发文档是齐全的&#xff0c;新员工入职后&#xff0c;就能够很快接手工作&#xf…

Linux驱动学习:从Linux主机nfs共享文件到uboot

第一步&#xff1a;在Linux主机上开启NFS服务&#xff0c;使用如下命令安装NFS服务&#xff1a; sudo apt-get install nfs-kernel-server rpcbind 第二步&#xff1a;创建一个文件夹用于共享&#xff0c;直接以nfs命名就行&#xff1a; 第三步&#xff1a;打开nfs服务配置文…

arm裸机(1)、点灯|按键

芯片是S3C2440 首先看原理图&#xff0c;led_1234分别对应引脚GPB 5678 设置引脚为输出 向寄存器相应位写入 #define GPBCON (*(volatile unsigned long *)0x56000010) //p5 6 7 8 void led_init(void) {GPBCON & ~(0x3 << 10);GPBCON | (0x1 <<…

LeetCode-19. 删除链表的倒数第 N 个结点【链表 双指针】

LeetCode-19. 删除链表的倒数第 N 个结点【链表 双指针】 题目描述&#xff1a;解题思路一&#xff1a;双指针解题思路二&#xff1a;优化解题思路三&#xff1a;0 题目描述&#xff1a; 给你一个链表&#xff0c;删除链表的倒数第 n 个结点&#xff0c;并且返回链表的头结点。…

MTU/TCPMSS/VLAN/ACCESS/TRUNK/HYBRID

MTU RFC标准定义以太网的默认MTU值为1500 最小64字节是为了保证最极端的冲突能被检测到&#xff0c;64字节是能被检测到的最小值&#xff1b;最大不超过1518字节是为了防止过长的帧传输时间过长而占用共享链路太长时间导致其他业务阻塞。所以规定以太网帧大小为64~1518字节&am…