Generative AI 新世界 | 文生图领域动手实践:预训练模型的微调

在上期文章,我们探讨了预训练模型的部署和推理,包括运行环境准备、角色权限配置、支持的主要推理参数、图像的压缩输出、提示工程 (Prompt Engineering)、反向提示 (Negative Prompting) 等内容。

亚马逊云科技开发者社区为开发者们提供全球的开发技术资源。这里有技术文档、开发案例、技术专栏、培训视频、活动与竞赛等。帮助中国开发者对接世界最前沿技术,观点,和项目,并将中国优秀开发者或技术推荐给全球云社区。如果你还没有关注/收藏,看到这里请一定不要匆匆划过,点这里让它成为你的技术宝库!

本期文章,我们将探讨如何在自定义数据集上来微调(fine-tuned)模型,该模型可以针对任何图像数据集进行微调。即使你手上只有几张自定义的图像提供做训练,模型也能输出比较理想的结果。

首先,让我们通过一篇论文的概括解读,来了解这种文生图模型的微调 (fine-tuned),背后的工作原理和理论基础知识。

DreamBooth 论文概述

这种文生图模型的微调(fine-tuned)理论基础来自于 DreamBooth 论文,如下图所示:

image.png

DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-DrivenGeneration

https://arxiv.org/pdf/2208.12242.pdf?trk=cndc-detail

在论文的开头,作者提出一个挑战性的问题:

虽然当时的文生图模型已经可以根据给定的 **prompt **生成高质量的图片,但是这些模型并不能模仿给定参考图片中的物体要素,在不同情景中来生成新的图片。

举个例子。

我家里有一只叫做“小花”的可爱加菲猫,如下图:

image.png

我想让加菲猫“小花”带上一顶礼帽,如下图:

image.png

或者带上一副很酷炫的墨镜,如下图:

image.png

甚至想象下她刷牙的魔幻景象,如下图:

image.png

事实上,上面的这些加菲猫“小花”的照片(戴礼帽、戴墨镜、刷牙),都是大模型使用 DreamBooth 做微调后生成的。很有趣吧?在文末会提供生成这些魔幻照片的全部代码。

我们先看下 DreamBooth 论文阐述的背后原理。

DreamBooth 论文提出一个新颖的方法:将输入图片中的物体与一个特殊标识符绑定在一起,即用这个特殊标记符来表示输入图片中的物体。因此论文为微调模型设计了一种 prompt 格式:a [identifier] [class noun],即将所有输入图片的 prompt 都设置成这种形式,其中 identifier 是一个与输入图片中物体相关联的特殊标记符,class noun 是对物体的类别描述。

这里之所以在 prompt 中加入类别,是因为想利用预训练模型中关于该类别物品的先验知识,并将先验知识与特殊标记符相关信息进行融合,这样就可以在不同场景下生成不同姿势的目标物体。

简单来说就是:不要学了新的知识,就忘了旧的知识

论文提出的方法,大致如下图所示,即仅仅通过 3 到 5 张图片去微调文生图模型,使得模型能将输入图片中特定的物品和 prompt 中的特殊标记符关联起来了。

image.png

Source: https://dreambooth.github.io\?trk=cndc-detail

关于特殊标记符的选择,论文提出通过在词表中选择罕见词来作为特殊标记符,这样避免了预训练模型对特殊标记符有很强烈的先验知识。

DreamBooth 论文提出一个新的微调方法:**通过预先生成的一些图像,来保留先验损失权重;以此来解决过拟合与语言漂移问题。**用模型自己生成的样本来监督模型,以便在 few-shot(小样本)微调开始后保留先验知识,如以下论文中提供的解释图所示:

image.png

Source: https://dreambooth.github.io/?trk=cndc-detail

给定大约 3-5 张拍摄对象的图像,我们分两步微调文本到图像的扩散:

  1. 使用输入图像与包含唯一标识符和主题所属类名称(例如:“A photo of a [T] dog”)的文本提示配对;同时,我们应用特定于类的预先保存损失,它利用了模型之前的语义通过在文本提示中注入类名,来鼓励它生成属于受试者类的各种实例提示(例如:“A photo of a dog”)。
  2. 使用从我们的输入图像集中拍摄的低分辨率和高分辨率图像,对超分辨率组件进行微调,这使我们能够保持对拍摄对象小细节的高保真度。

引入了先验损失的 loss 公式,如下所示:

image.png

通过这种 DreamBooth 方法,使得:输入训练集 + 提示词 [v] dog,然后还有用模型本身自己生成的 dog 图像,训练完成后得到了一个特殊标记符:[v]。通过这个特殊标记符 [v],就把这次训练的 dog 和其他本身学过的 dog 分开了。

最后得到惊艳的结果,比如给一只小熊带上太阳镜,如下图所示:

image.png

Source: https://dreambooth.github.io/?trk=cndc-detail

接下来,我们将完整用代码演示,如何给我家的加菲猫“小花”带上眼镜和礼帽。

Fine-tune 预训练模型在自有数据集上的微调

我们使用 Amazon SageMaker Studio 来实现在自有数据上的模型微调。

我首先将为我家的加菲猫“小花”拍摄几张照片,然后用这几张照片来微调模型;完成模型微调后,我们将使用 “a picture of Garfield cat with glasses” 这样的提示词,来直接为我家的加菲猫“小花”带上眼镜。

1 实例和环境准备

这个 Notebook 在带有 Python 3(Data Science)内核的 SageMaker Studio 中,使用 ml.t3.medium 实例上进行了测试。要对数据集的模型进行微调,您需要在账户中提供 ml.g4dn.2xlarge 实例类型。

完整的示例代码,可参考以下 GitHub 文档链接,从 “Fine-tune the pre-trained model on a custom dataset” 这一部分开始阅读代码:

https://github.com/aws/studio-lab-examples/blob/main/generative-deep-learning/stable-diffusion-finetune/Amazon_JumpStart_Text_To_Image.ipynb?trk=cndc-detail

你存放自定义照片的 s3 路径,应该看起来像这样:s3://bucket_name/input_directory/

请注意,后面的“/”为必填项。

以下是训练数据的示例格式:

input_directory|---instance_image_1.png|---instance_image_2.png|---instance_image_3.png|---instance_image_4.png|---instance_image_5.png|---dataset_info.json|---class_data_dir|---class_image_1.png|---class_image_2.png|---class_image_3.png|---class_image_4.png

 

预先保存、实例提示和类提示(Prior preservation, instance prompt and class prompt):预先保存是一种使用我们正在尝试训练的同一个类的其他图像的技术。例如,如果训练数据由特定狗的图像组成,并事先保存,则我们会合并普通犬的类别图像。它试图通过在为特定狗训练时显示不同狗的图像来避免过度拟合。类提示中缺少表示实例提示中存在的特定狗的标签。

例如,实例提示可能是 “A photo of a Garfield cat”,类提示可能是 “A photo of a cat”。

您可以通过将超参数设置为 _prior_preservation = True 来启用预先保存。

以下为使用我家加菲猫“小花”的照片的 dataset_info.json 的文件示例:

$ cat dataset_info.json
{"instance_prompt": "A photo of a Garfield cat","class_prompt": "A photo of a cat"
}

 

以下是我为了微调模型,而拍摄的我家加菲猫“小花”的照片。我只用了下面这六张照片,就实现了模型的微调。

image.png

我存放照片(即为微调模型提供的自定义训练图片)的 S3 桶参考路径如下:s3://sagemaker-us-east-1-xxxxxxxxxxxx/haowen-datasets/cat_finetuning/

其中 “sagemaker-us-east-1-xxxxxxxxxxxx” 需要更新为你自己定义的桶名。

最终完成微调后,模型存放的 S3 桶参考路径如下:s3://sagemaker-us-east-1-xxxxxxxxxxxx/jumpstart-example-sd-training/output

其中 “sagemaker-us-east-1-xxxxxxxxxxxx” 需要更新为你自己定义的桶名。

2 检索训练数据的 Artifacts

在这里,我们检索训练 docker 容器、训练算法源和预先训练的基础模型。

请注意,model_version= “*” 获取的是最新的模型版本号。以下代码选择了 Stable Diffusion V2.1 Base 的文生图大模型。

# Select a model 
train_model_id, train_model_version, train_scope = ("model-txt2img-stabilityai-stable-diffusion-v2-1-base","*","training",
)

以下代码选择了微调模型的实例是 ml.g4dn.2xlarge:

training_instance_type = "ml.g4dn.2xlarge"

以下代码获取 Docker Image:

# Retrieve the docker image
train_image_uri = image_uris.retrieve(region=None,framework=None,  # automatically inferred from model_idmodel_id=train_model_id,model_version=train_model_version,image_scope=train_scope,instance_type=training_instance_type,
)

 

以下代码获取训练脚本:

# Retrieve the training script. This contains all the necessary files including data processing, model training etc.
train_source_uri = script_uris.retrieve(model_id=train_model_id, model_version=train_model_version, script_scope=train_scope
)

以下代码获取预训练模型的 tarball 包,用于之后的微调工作:

# Retrieve the pre-trained model tarball to further fine-tune
train_model_uri = model_uris.retrieve(model_id=train_model_id, model_version=train_model_version, model_scope=train_scope
)

3 设置训练参数

现在我们已经完成了所有需要的设置,我们已经准备好微调 Stable Diffusion 模型了。首先,让我们创建一个 sageMaker.estimator.Estimator 对象。该 Estimator 将启动训练作业。

模型的微调训练需要设置两种参数。

第一组参数是训练作业的参数。其中包括:

  1. 训练数据路径,这是存储输入数据的 S3  路径。即之前我们准备的 “s3://sagemaker-us-east-1-xxxxxxxxxxxx/haowen-datasets/cat_finetuning/” 这个路径;
  2.  输出路径,这是存储微调模型训练的输出 s3 路径。即之前我们准备的“s3://sagemaker-us-east-1-xxxxxxxxxxxx/jumpstart-example-sd-training/output” 这个路径;
  3. 训练实例类型,这表示运行模型微调训练的机器类型。我们在上面定义了训练实例类型,以获取正确的 train_image_uri。

第二组参数是特定于算法的训练超参数。对于算法特定的超参数,我们首先获取算法接受的训练超参数的 python 字典及其默认值,然后可以将其改写为自定义值。示例代码如下所示:

from sagemaker import hyperparameters# Retrieve the default hyper-parameters for fine-tuning the model
hyperparameters = hyperparameters.retrieve_default(model_id=train_model_id, model_version=train_model_version
)# [Optional] Override default hyperparameters with custom values
hyperparameters["max_steps"] = "400"
print(hyperparameters)

4 启动模型微调训练

我们首先使用所有必需的 assets 创建 estimator 对象,然后启动训练作业。

from sagemaker.estimator import Estimator
from sagemaker.utils import name_from_base
from sagemaker.tuner import HyperparameterTunertraining_job_name = name_from_base(f"jumpstart-example-{train_model_id}-transfer-learning")# Create SageMaker Estimator instance
sd_estimator = Estimator(role=aws_role,image_uri=train_image_uri,source_dir=train_source_uri,model_uri=train_model_uri,entry_point="transfer_learning.py",  # Entry-point file in source_dir and present in train_source_uri.instance_count=1,instance_type=training_instance_type,max_run=360000,hyperparameters=hyperparameters,output_path=s3_output_location,base_job_name=training_job_name,
)# Launch a SageMaker Training job by passing s3 path of the training data
sd_estimator.fit({"training": training_dataset_s3_path}, logs=True)

模型训练开始后,如果观察 SageMaker 的控制台,会发现:

  1. 训练任务的状态,从 “InProgress” 逐渐变成 “Completed”;
  2. 超参调优的状态,从 “InProgress” 逐渐变成 “Completed”。

如下图所示:

image.png

image.png

image.png

经过那六张照片作为新的输入数据,微调后的模型重新训练完成后,就可以进入以下的模型部署阶段了。

5 微调后模型的部署

我们将遵循上一篇中介绍的模型部署的相同步骤,在训练好的模型上运行推理。我们首先检索用于部署端点的 jumpstart 工件。但是,我们部署的是经过微调的 sd_estimator 估算器,而不是上一篇中使用的 base_predictor 估算器。

inference_instance_type = "ml.g4dn.2xlarge"# Retrieve the inference docker container uri
deploy_image_uri = image_uris.retrieve(region=None,framework=None,  # automatically inferred from model_idimage_scope="inference",model_id=train_model_id,model_version=train_model_version,instance_type=inference_instance_type,
)
# Retrieve the inference script uri. This includes scripts for model loading, inference handling etc.
deploy_source_uri = script_uris.retrieve(model_id=train_model_id, model_version=train_model_version, script_scope="inference"
)endpoint_name = name_from_base(f"jumpstart-example-FT-{train_model_id}-")# Use the estimator from the previous step to deploy to a SageMaker endpoint
finetuned_predictor = sd_estimator.deploy(initial_instance_count=1,instance_type=inference_instance_type,entry_point="inference.py",  # entry point file in source_dir and present in deploy_source_uriimage_uri=deploy_image_uri,source_dir=deploy_source_uri,endpoint_name=endpoint_name,
)

在等待新模型部署的过程中,可以回到 SageMaker 的控制台,在 Endpoints 项中刷新检查模型部署的情况。当 Status 从 “Creating” 变成 “Completed”,就表示微调后的新模型已经部署完成可以开始进行推理了。如下图所示:

image.png

6 微调后模型的推理

下面进入激动人心的时刻,我们在微调后的模型上进行推理。

我输入的提示词是:“a photo of a Garfield cat with a hat”(一只带帽子的加菲猫)。

text = " a photo of a Garfield cat with a hat"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

模型的魔幻输出如下图所示。我们成功地给加菲猫“小花”带上礼帽了!

image.png

接着我们给加菲猫“小花”带上眼镜,我输入的提示词是:“a picture of Garfield cat with glasses”:

text = " a picture of Garfield cat with glasses"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

模型的输出如下:

image.png

最后让加菲猫“小花”像人类一样去刷牙,我输入的提示词是:“a picture of Garfield cat brushing her teeth”:

text = " a picture of Garfield cat brushing her teeth"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

image.png

神奇吧?加菲猫“小花”会自己刷牙了!

7 计算资源删除和清理

和以前一样,实验完成后别忘记清除相关的 endpoint 资源,以避免产生不必要的费用:

# Delete the SageMaker endpoint
finetuned_predictor.delete_model()
finetuned_predictor.delete_endpoint()

总结

本文我们学习了如何使用 Amazon SageMaker JumpStart 方便地微调文生图的 Stable Diffusion 模型。

Amazon SageMaker JumpStart 为预训练的模型提供了微调功能,本文的例子中,你只需使用六张训练图像即可根据自己的用例调整模型。这在创建个性化艺术品、独特的徽标、企业的 LOGO、或者其他需要自定义设计的场景时非常有用。

下一期的文章,我们将重新回到文本生成的大模型场景,探讨如何在 Amazon SageMaker JumpStart 上部署当今炙手可热的开源大语言模型。我们将以 Falcon 40B 开源大模型为例,逐行代码轻松部署高达 400 亿参数的这个大型语言模型。敬请期待。

请持续关注 Build On Cloud 专栏,了解更多面向开发者的技术分享和云开发动态!

 

作者 黄浩文

亚马逊云科技资深开发者布道师,专注于 AI/ML、Data Science 等。拥有 20 多年电信、移动互联网以及云计算等行业架构设计、技术及创业管理等丰富经验,曾就职于 Microsoft、Sun Microsystems、中国电信等企业,专注为游戏、电商、媒体和广告等企业客户提供 AI/ML、数据分析和企业数字化转型等解决方案咨询服务。

文章来源:https://dev.amazoncloud.cn/column/article/64cb87265306fa4a7fa3a3c9?sc_medium=regulartraffic&sc_campaign=crossplatform&sc_channel=CSDN

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/132007.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Lumen/Laravel - 数据库读写分离原理 - 探究

1.应用场景 主要用于学习与探究Lumen/Laravel的数据库读写分离原理。 2.学习/操作 1.文档阅读 chatgpt & 其他资料 数据库入门 | 数据库操作 | Laravel 8 中文文档 入门篇(一):数据库连接配置和读写分离 | 数据库与 Eloquent 模型 | La…

分布式数据库HBase(林子雨慕课课程)

文章目录 4. 分布式数据库HBase4.1 HBase简介4.2 HBase数据模型4.3 HBase的实现原理4.4 HBase运行机制4.5 HBase的应用方案4.6 HBase安装和编程实战 4. 分布式数据库HBase 4.1 HBase简介 HBase是BigTable的开源实现 对于网页搜索主要分为两个阶段 1.建立整个网页索引&#xf…

【JVM】初步认识Java虚拟机

🐌个人主页: 🐌 叶落闲庭 💨我的专栏:💨 c语言 数据结构 javaEE 操作系统 Redis 石可破也,而不可夺坚;丹可磨也,而不可夺赤。 JVM 一、初识JVM1.1 什么是JVM1.2 JVM的功能…

苹果安卓网页的H5封装成App的应用和原生开发的应用有什么不一样?

老哥在么?H5封装的app和原生开发的app有什么不一样?,不懂就要问,我能理解哈,虽然这个问题有点小白,但是我还是得认真回答,以防止我回答的不是很好,所以我科技了一下,所以…

input时间控件选择时禁用某个日期之前或之后

【版权所有,文章允许转载,但须以链接方式注明源地址,否则追究法律责任】【创作不易,点个赞就是对我最大的支持】 前言 仅作为学习笔记,供大家参考 总结的不错的话,记得点赞收藏关注哦! 目录 …

ubuntu20.04安装genymotion3.5.1

下载和安装genymotion https://www.genymotion.com/download/ wget https://dl.genymotion.com/releases/genymotion-3.5.1/genymotion-3.5.1-linux_x64.bin chmod x genymotion-3.5.1-linux_x64.bin sudo ./genymotion-3.5.1-linux_x64.bin默认位置为:/opt/genym…

【SpringCloud】微服务技术栈入门8 - 黑马旅游微服务项目实战笔记

目录 黑马旅游案例分页查询自动补全安装依赖自定义分词器Completion Suggester 聚合数据聚合的分类Bucket 聚合Metrix 聚合RestClient 实现聚合suggest 查询结果 数据同步同步策略mq 同步 eses 搭设集群 黑马旅游案例 分页查询 前端页面以及对应请求接口已经设置完备&#xff…

cdsn目录处理:空行替换2个```,在2个```中间添加“# 空行文本后遇到的第1行文字”?

原标题: python查找替换:查找空行,空行前后添加,中间添加 # 空格 空行后遇到的第1行文字?初始代码 查找空行空行前后添加 中间添加 # 空行后遇到的第1行文字txt 36 96 159 8 72可以使用Python的字符串处理函数来查找…

Vue3中使用tinymce全功能演示,包括开源功能

效果图: 1、下载插件: npm i tinymce npm i tinymce/tinymce-vue 2、在node_modules文件夹中找到tinymce下的skins复制到项目public文件夹中 (可以先创建一个tinymce文件夹): 3、在tinymce官网中下载中文包,并放在刚…

【AntDesign】多环境配置和启动

环境分类,可以分为 本地环境、测试环境、生产环境等,通过对不同环境配置内容,来实现对不同环境做不同的事情。 AntDesign 项目,通过 config.xxx.ts 添加不同的后缀来区分配置文件,启动时候通过后缀启动即可。 config…

如何基于先进视频技术,构建互联网视频监控安全管理平台解决方案

一、建设思路 依托互联网,建设一朵云,实现各类二三类视频资源统一接入,实现天网最后100米、10米、1米的全域覆盖。 依托人工智能与互联网技术,拓展视频资源在政府、社会面等多领域的全面应用;建设与运营模式并存&…

湖州OLED透明拼接屏技术应用引领现代化旅游观光方式

湖州市位于中国浙江省北部,拥有悠久的历史和丰富的文化遗产。湖州市以其美丽的湖泊和秀丽的自然风光而闻名。 作为中国重要的历史文化名城之一,湖州市有着丰富的文化遗产和历史资源,如古城墙、古建筑和古镇等。 这为OLED透明拼接屏技术的应用…