Stable Diffusion (version x.x) 文生图模型实践指南

前言:本篇博客记录使用Stable Diffusion模型进行推断时借鉴的相关资料和操作流程。

相关博客:
超详细!DALL · E 文生图模型实践指南
DALL·E 2 文生图模型实践指南

目录

  • 1. 环境搭建和预训练模型准备
    • 环境搭建
    • 预训练模型下载
  • 2. 代码


1. 环境搭建和预训练模型准备

环境搭建

pip install diffusers transformers accelerate scipy safetensors

预训练模型下载

关于 huggingface 网站总是崩溃的情况,找到一个解决办法,就是可以通过脚本来下载

第一步:安装 huggingface_hub,使用命令 pip install huggingface_hub
第二步:下载具体模型,使用命令 python model_download.py --repo_id model_id,其中,model_id 为要下载的模型,比如SD v2.1 版本的model_id可以是 stabilityai/stable-diffusion-2-1;SD v1.5 版本的model_id可以是 runwayml/stable-diffusion-v1-5. model_id 的查找方式是在huggingface 网站直接搜索需要的模型(如下图),得到的「模型来源/版本」的组合即为所需。

在这里插入图片描述

model_download.py文件来自这个链接。

# usage     : python model_download.py --repo_id repo_id
# example   : python model_download.py --repo_id facebook/opt-350m
import argparse
import time
import requests
import json
import os
from huggingface_hub import snapshot_download
import platform
from tqdm import tqdm
from urllib.request import urlretrievedef _log(_repo_id, _type, _msg):date1 = time.strftime('%Y-%m-%d %H:%M:%S')print(date1 + " " + _repo_id + " " + _type + " :" + _msg)def _download_model(_repo_id, _repo_type):if _repo_type == "model":_local_dir = 'dataroot/models/' + _repo_idelse:_local_dir = 'dataroot/datasets/' + _repo_idtry:if _check_Completed(_repo_id, _local_dir):return True, "check_Completed ok"except Exception as e:return False, "check_Complete exception," + str(e)_cache_dir = 'caches/' + _repo_id_local_dir_use_symlinks = Trueif platform.system().lower() == 'windows':_local_dir_use_symlinks = Falsetry:if _repo_type == "model":snapshot_download(repo_id=_repo_id, cache_dir=_cache_dir, local_dir=_local_dir, local_dir_use_symlinks=_local_dir_use_symlinks,resume_download=True, max_workers=4)else:snapshot_download(repo_id=_repo_id, cache_dir=_cache_dir, local_dir=_local_dir, local_dir_use_symlinks=_local_dir_use_symlinks,resume_download=True, max_workers=4, repo_type="dataset")except Exception as e:error_msg = str(e)if ("401 Client Error" in error_msg):return True, error_msgelse:return False, error_msg_removeHintFile(_local_dir)return True, ""def _writeHintFile(_local_dir):file_path = _local_dir + '/~incomplete.txt'if not os.path.exists(file_path):if not os.path.exists(_local_dir):os.makedirs(_local_dir)open(file_path, 'w').close()def _removeHintFile(_local_dir):file_path = _local_dir + '/~incomplete.txt'if os.path.exists(file_path):os.remove(file_path)def _check_Completed(_repo_id, _local_dir):_writeHintFile(_local_dir)url = 'https://huggingface.co/api/models/' + _repo_idresponse = requests.get(url)if response.status_code == 200:data = json.loads(response.text)else:return Falsefor sibling in data["siblings"]:if not os.path.exists(_local_dir + "/" + sibling["rfilename"]):return False_removeHintFile(_local_dir)return Truedef download_model_retry(_repo_id, _repo_type):i = 0flag = Falsemsg = ""while True:flag, msg = _download_model(_repo_id, _repo_type)if flag:_log(_repo_id, "success", msg)breakelse:_log(_repo_id, "fail", msg)if i > 1440:msg = "retry over one day"_log(_repo_id, "fail", msg)breaktimeout = 60time.sleep(timeout)i = i + 1_log(_repo_id, "retry", str(i))return flag, msgdef _fetchFileList(files):_files = []for file in files:if file['type'] == 'dir':filesUrl = 'https://e.aliendao.cn/' + file['path'] + '?json=true'response = requests.get(filesUrl)if response.status_code == 200:data = json.loads(response.text)for file1 in data['data']['files']:if file1['type'] == 'dir':filesUrl = 'https://e.aliendao.cn/' + \file1['path'] + '?json=true'response = requests.get(filesUrl)if response.status_code == 200:data = json.loads(response.text)for file2 in data['data']['files']:_files.append(file2)else:_files.append(file1)else:if file['name'] != '.gitattributes':_files.append(file)return _filesdef _download_file_resumable(url, save_path, i, j, chunk_size=1024*1024):headers = {}r = requests.get(url, headers=headers, stream=True, timeout=(20, 60))if r.status_code == 403:_log(url, "download", '下载资源发生了错误,请使用正确的token')return Falsebar_format = '{desc}{percentage:3.0f}%|{bar}|{n_fmt}M/{total_fmt}M [{elapsed}<{remaining}, {rate_fmt}]'_desc = str(i) + ' of ' + str(j) + '(' + save_path.split('/')[-1] + ')'total_length = int(r.headers.get('content-length'))if os.path.exists(save_path):temp_size = os.path.getsize(save_path)else:temp_size = 0retries = 0if temp_size >= total_length:return True# 小文件显示if total_length < chunk_size:with open(save_path, 'wb') as f:for chunk in r.iter_content(chunk_size=chunk_size):if chunk:f.write(chunk)with tqdm(total=1, desc=_desc, unit='MB', bar_format=bar_format) as pbar:pbar.update(1)else:headers['Range'] = f'bytes={temp_size}-{total_length}'r = requests.get(url, headers=headers, stream=True,verify=False, timeout=(20, 60))data_size = round(total_length / 1024 / 1024)with open(save_path, 'ab') as fd:fd.seek(temp_size)initial = temp_size//chunk_sizefor chunk in tqdm(iterable=r.iter_content(chunk_size=chunk_size), initial=initial, total=data_size, desc=_desc, unit='MB', bar_format=bar_format):if chunk:temp_size += len(chunk)fd.write(chunk)fd.flush()return Truedef _download_model_from_mirror(_repo_id, _repo_type, _token, _e):if _repo_type == "model":filesUrl = 'https://e.aliendao.cn/models/' + _repo_id + '?json=true'else:filesUrl = 'https://e.aliendao.cn/datasets/' + _repo_id + '?json=true'response = requests.get(filesUrl)if response.status_code != 200:_log(_repo_id, "mirror", str(response.status_code))return Falsedata = json.loads(response.text)files = data['data']['files']for file in files:if file['name'] == '~incomplete.txt':_log(_repo_id, "mirror", 'downloading')return Falsefiles = _fetchFileList(files)i = 1for file in files:url = 'http://61.133.217.142:20800/download' + file['path']if _e:url = 'http://61.133.217.139:20800/download' + \file['path'] + "?token=" + _tokenfile_name = 'dataroot/' + file['path']if not os.path.exists(os.path.dirname(file_name)):os.makedirs(os.path.dirname(file_name))i = i + 1if not _download_file_resumable(url, file_name, i, len(files)):return Falsereturn Truedef download_model_from_mirror(_repo_id, _repo_type, _token, _e):if _download_model_from_mirror(_repo_id, _repo_type, _token, _e):returnelse:#return download_model_retry(_repo_id, _repo_type)_log(_repo_id, "download", '下载资源发生了错误,请使用正确的token')if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--repo_id', default=None, type=str, required=True)parser.add_argument('--repo_type', default="model",type=str, required=False)  # models,dataset# --mirror为从aliendao.cn镜像下载,如果aliendao.cn没有镜像,则会转到hf# 默认为Trueparser.add_argument('--mirror', action='store_true',default=True, required=False)parser.add_argument('--token', default="", type=str, required=False)# --e为企业付费版parser.add_argument('--e', action='store_true',default=False, required=False)args = parser.parse_args()if args.mirror:download_model_from_mirror(args.repo_id, args.repo_type, args.token, args.e)else:download_model_retry(args.repo_id, args.repo_type)

2. 代码

Stable Diffusion 完整推断流程如下(from https://huggingface.co/stabilityai/stable-diffusion-2-1):

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepSchedulermodel_id = "/dataroot/models/stabilityai/stable-diffusion-2-1"  # 预训练模型的下载路径# Use the DPMSolverMultistepScheduler (DPM-Solver++) scheduler here instead
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]image.save("astronaut_rides_horse.png")

参考文献

  1. https://aliendao.cn/model_download.py
  2. https://github.com/Stability-AI/stablediffusion

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/179440.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

8.查询数据

一、单表查询 MySQL从数据表中查询数据的基本语为SELECT语。SELECT语的基本格式是: SELECT {* | <字段列名>} [ FROM <表 1>, <表 2>… [WHERE <表达式> [GROUP BY <group by definition> [HAVING <expression> [{<operator>…

POE也收费了

一直通过POE在用chatgpt&#xff0c;今天下午发现要收费了…

想买GPT4会员却只能排队?来看看背后的故事!

文章目录 &#x1f9d0; 为什么要进候选名单&#xff1f;&#x1f50d; 究竟发生了什么&#xff1f;&#x1f62e; IOS端还能买会员&#xff01;&#x1f914; 网页端为啥不能订会员&#xff1f;第一点&#xff1a;防止黑卡消费第二点&#xff1a;当技术巨头遇上资源瓶颈&#…

使用MathType将文献中的数学公式进行转换

mathtype将文献中的数学公式进行转换 文章目录 mathtype将文献中的数学公式进行转换一、截图识别二、MathType下载与设置2.1、MathType介绍2.2、[下载位置](http://www.51xiazai.cn/soft/5975.htm)2.3、设置 三、使用MathType&#xff1a; 一、截图识别 这两个在线网站都可以将…

Linux是什么,Linux系统介绍

很多小伙伴都不是那么了解和知道Linux&#xff0c;到底Linux是什么&#xff1f; 像大家用到的安卓手机&#xff0c;生活中用到的各种智能设备&#xff0c;比如路由器&#xff0c;光猫&#xff0c;智能家具等&#xff0c;很多都是在Linux操作系统上。 Linux是什么&#xff1f;Li…

指针传2

几天没有写博客了&#xff0c;怎么说呢&#xff1f;这让我总感觉缺点什么&#xff0c;心里空落落的&#xff0c;你懂吧&#xff01; 好了&#xff0c;接下来开始我们今天的正题&#xff01; 1. ⼆级指针 我们先来看看代码&#xff1a; 首先创建了一个整型变量a&#xff0c;将…

【开源】基于Vue和SpringBoot的快乐贩卖馆管理系统

项目编号&#xff1a; S 064 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S064&#xff0c;文末获取源码。} 项目编号&#xff1a;S064&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 搞笑视频模块2.3 视…

模拟业务流程+构造各种测试数据,一文带你测试效率提升80%

&#x1f4e2;专注于分享软件测试干货内容&#xff0c;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01;&#x1f4e2;交流讨论&#xff1a;欢迎加入我们一起学习&#xff01;&#x1f4e2;资源分享&#xff1a;耗时200小时精选的「软件测试」资…

Java设计模式-结构型模式-适配器模式

适配器模式 适配器模式应用场景案例类适配器模式对象适配器模式接口适配器模式适配器模式在源码中的使用 适配器模式 如图&#xff1a;国外插座标准和国内不同&#xff0c;要使用国内的充电器&#xff0c;就需要转接插头&#xff0c;转接插头就是起到适配器的作用 适配器模式&…

Django之模版层

文章目录 模版语法传值模版语法传值特性模版语法标签语法格式if模板标签for模板标签with起别名 模版语法过滤器常用过滤器 自定义过滤器、标签、inclusion_tag自定义过滤器自定义标签自定义inclusion_tag 模版导入模版继承 模版语法传值 模板层三种语法{{}}:主要与数据值相关{%…

C语言不可不敲系列:跳水比赛排名问题

目录 1题干&#xff1a; 2解题思路&#xff1a; 3代码: 4运行结果: 5总结: 1题干&#xff1a; 5位运动员参加了10米台跳水比赛&#xff0c;有人让他们预测比赛结果 A选手说&#xff1a;B第二&#xff0c;我第三&#xff1b; B选手说&#xff1a;我第二&#xff0c;E第四&am…

Redhat Linux v8.2 实时内核环境配置及参数调优

Redhat-Linux V8.2 实时内核环境配置及参数调优 -------物理机 & 虚拟机 一、前言 本文档包含有关Redhat Linux for Real Time的基本安装和调试信息。许多行业和组织需要极高性能的计算&#xff0c;并且可能需要低且可预测的延迟&#xff0c;尤其是在金融和电信行业中。延…