[跑代码]BK-SDM: A Lightweight, Fast, and Cheap Version of Stable Diffusion

Installation(下载代码-装环境)

conda create -n bk-sdm python=3.8
conda activate bk-sdm
git clone https://github.com/Nota-NetsPresso/BK-SDM.git
cd BK-SDM
pip install -r requirements.txt
Note on the torch versions we've used
  • torch 1.13.1 for MS-COCO evaluation & DreamBooth finetuning on a single 24GB RTX3090
     

  • torch 2.0.1 for KD pretraining on a single 80GB A10
    火炬2.0.1在单个80GB A100上进行KD预训练

    • 如果A100上总批大小为256的预训练导致gpu内存不足,请检查torch版本并考虑升级到torch>2.0.0。
      我的版本也是torch2.0.1 单个A100(80G)理论上吃的下256batch

小的例子

PNDM采样器 50步去噪声

等效代码(仅修改SD-v1.4的U-Net,同时保留其文本编码器和图像解码器):

Distillation Pretraining

Our code was based on train_text_to_image.py of Diffusers 0.15.0.dev0. To access the latest version, use this link.
BK-SDM的diffusers版本0.15
我的diffusers版本比较高0.24.0

检测是否能够训练(先下载数据集get_laion_data.sh再运行代码kd_train_toy.sh)

1 一个玩具数据集(11K的img-txt对)下载到。

bash scripts/get_laion_data.sh preprocessed_11k

/data/laion_aes/preprocessed_11k (1.7GB in tar.gz;1.8GB数据文件夹)。
get_laion_data.sh

需要修改,实际就是下载这三个数据集,我自行下载

# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_11k.tar.gz
# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_212k.tar.gz
# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_2256k.tar.gz

我修改后下载文件名 https://... .../preprocessed_11k.tar.gz直接粘贴到网址里面也可以下载
wget $S3_URL -0 $FILe_PATH
$S3_URL 就是这个网址
$FILe_PATH 就是下载路径./data/laion_aes/preprocessed_11k

DATA_TYPE=$"preprocessed_11k"  # {preprocessed_11k, preprocessed_212k, preprocessed_2256k}
FILE_NAME="${DATA_TYPE}.tar.gz"DATA_DIR="./data/laion_aes/"
FILE_UNZIP_DIR="${DATA_DIR}${DATA_TYPE}"
FILE_PATH="${DATA_DIR}${FILE_NAME}"if [ "$DATA_TYPE" = "preprocessed_11k" ] || [ "$DATA_TYPE" = "preprocessed_212k" ]; thenecho "-> preprocessed_11k or 212k"S3_URL="https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/${FILE_NAME}"
elif [ "$DATA_TYPE" = "preprocessed_2256k" ]; thenS3_URL="https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.25plus/${FILE_NAME}"
elseecho "Something wrong in data folder name"exit
fiwget $S3_URL -O $FILE_PATH
tar -xvzf $FILE_PATH -C $DATA_DIR
echo "downloaded to ${FILE_UNZIP_DIR}"

2 一个小脚本可以用来验证代码的可执行性,并找到与你的GPU匹配的批处理大小。
批量大小为8 (=4×2),训练BK-SDM-Base 20次迭代大约需要5分钟和22GB的GPU内存。

bash scripts/kd_train_toy.sh
MODEL_NAME="CompVis/stable-diffusion-v1-4"
TRAIN_DATA_DIR="./data/laion_aes/preprocessed_11k" # please adjust it if needed
UNET_CONFIG_PATH="./src/unet_config"UNET_NAME="bk_small" # option: ["bk_base", "bk_small", "bk_tiny"]
OUTPUT_DIR="./results/toy_"$UNET_NAME # please adjust it if neededBATCH_SIZE=2
GRAD_ACCUMULATION=4StartTime=$(date +%s)CUDA_VISIBLE_DEVICES=1 accelerate launch src/kd_train_text_to_image.py \--pretrained_model_name_or_path $MODEL_NAME \--train_data_dir $TRAIN_DATA_DIR\--use_ema \--resolution 512 --center_crop --random_flip \--train_batch_size $BATCH_SIZE \--gradient_checkpointing \--mixed_precision="fp16" \--learning_rate 5e-05 \--max_grad_norm 1 \--lr_scheduler="constant" --lr_warmup_steps=0 \--report_to="all" \--max_train_steps=20 \--seed 1234 \--gradient_accumulation_steps $GRAD_ACCUMULATION \--checkpointing_steps 5 \--valid_steps 5 \--lambda_sd 1.0 --lambda_kd_output 1.0 --lambda_kd_feat 1.0 \--use_copy_weight_from_teacher \--unet_config_path $UNET_CONFIG_PATH --unet_config_name $UNET_NAME \--output_dir $OUTPUT_DIREndTime=$(date +%s)
echo "** KD training takes $(($EndTime - $StartTime)) seconds."

单GPU训练BK-SDM{Base, Small, Tiny}-0.22M数据训练
 

bash scripts/get_laion_data.sh preprocessed_212k
bash scripts/kd_train.sh

1 下载数据集preprocessed_212k
2 训练kd_train.sh
(256batch 训练BD-SM-Base 50K轮次需要300hours/53G单卡)
(64batch 训练BD-SM-Base 50K轮次需要60hours/28G单卡) 不理解?
 

单GPU训练BK-SDM{Base, Small, Tiny}-2.3M数据训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/235622.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

行业风云突变,毫末发内部信:明年要打赢智驾生死之战

岁末智能车行业风云突变。 广州车展前夕,小米汽车首车曝光;车展现场,各路车企上演旗舰之战;而车展才刚落幕,华为突然宣布车BU业务拆分。 外资、自主、合资、跨界车企,各路势力提前决战2025。智能新能源车…

TCP解帧解码、并发送有效数据到FPGA

TCP解帧解码、并发送有效数据到FPGA 工程的功能:使用TCP协议接收到网络调试助手发来的指令,将指令进行解帧,提取出帧头、有限数据、帧尾;再将有效数据发送到FPGA端的BRAM上,实现信息传递。 参考:正点原子启…

LeetCode(40)同构字符串【哈希表】【简单】

目录 1.题目2.答案3.提交结果截图 链接: 同构字符串 1.题目 给定两个字符串 s 和 t ,判断它们是否是同构的。 如果 s 中的字符可以按某种映射关系替换得到 t ,那么这两个字符串是同构的。 每个出现的字符都应当映射到另一个字符&#xff0…

在idea中写sql语句,向数据库添加数据时,添加的字符串却显示???,解决方法

这是字符编码的问题 如何解决: 在idea的配置数据库的地方修改下边:mysql8版本和5版本差距不大。 在URL后加?useUnicodetrue&characterEncodingUTF8 例如 原来:String url “jdbc:mysql://localhost:3306/stu”; 改变后:St…

ssh连接docker容器处理备忘

1、查看容器ip,记下来之后要用 docker inspect elastic | grep IPAddress 2、使用root进入docker容器 docker exec -it -u root elastic /bin/bash 3、安装openssh #更新apt apt-get update#安装ssh client apt-get install openssh-client#安装ssh server apt-…

ubuntu配置ssh

本教程中的涉及路径的所有命令都是在root用户下的,读者可将路径中的/root更改为/home/用户名 1、重置密码 新安装的系统需要在服务器控制台点击“重置密码”,为root用户设置一个密码 ————————————————————————————————…

分享:大数据方向学生学徒参与条件

学生学徒制的实施旨在解决当前新技术企业招聘技能人才难和青年就业难的结构性矛盾,通过生态链链主企业携手院校共同解决毕业年度学生就业问题,按照学生个人意愿,建立以就业导向的学生学徒制关系,签订学徒培养协议确定学生就业岗位…

每日一练2023.11.30——谁先倒【PTA】

题目链接:谁先倒 题目要求: 划拳是古老中国酒文化的一个有趣的组成部分。酒桌上两人划拳的方法为:每人口中喊出一个数字,同时用手比划出一个数字。如果谁比划出的数字正好等于两人喊出的数字之和,谁就输了&#xff0…

快速排序算法的代码及算法思想

快速排序(Quick Sort)是一种常用的排序算法,他的时间复杂度为O(nlogn) 算法思想: 通过一趟排序将待排序的数据分割成独立的两部分,其中一部分的所有数据都比另一部分的所有数据小,然后再对这两部分数据分别进行快速排…

换抵挡装置(按位运算符的运用)

给出两个长度分别为n1&#xff0c;n2&#xff08;n1 n2 <32)且每列高度只为1或2的长条&#xff08;保证高度为1的地方水平上一致&#xff09;。需要将它们放入一个高度为3的容器长度&#xff0c;问能够容纳它们的最短容器长度 用手画的 本来是n1&#xff0c;n2 < 100的…

[操作系统] 文件管理

文章目录 5.1 磁盘调度算法1. 先来先服务算法( First Come First Served, FCFS) 算法2. 最短寻道时间优先算法( Shortest Seek Time First, SSTF) 算法3. 扫描算法( SCAN ) 算法4. 循环扫描算法( Circular Scan, CSCAN ) 算法5. LOOK 与 CLOOK 算法 5.2 进程写文件时&#xff0…

全汉电源SN生产日期解读

新买了一个全汉的电脑电源&#xff0c;SN&#xff1a;WZ3191900030&#xff0c;看了几次没想明白&#xff0c;最后估计SN是2023年19周这样来记录日期的。问了一下京东全汉客服&#xff0c;果然就是这样的。那大家如果在闲鱼上看到全汉电源&#xff0c;就知道它的生产日期了。