手把手写深度学习(23):视频扩散模型之Video DataLoader

手把手写深度学习(0):专栏文章导航

前言:训练自己的视频扩散模型的第一步就是准备数据集,而且这个数据集是text-video或者image-video的多模态数据集,这篇博客手把手教读者如何写一个这样扩散模型的的Video DataLoader。

目录

准备工作

下载数据集

视频数据打标签

代码讲解

纯视频文件夹+txt描述prompt 读取方式

CSV描述文件读取方式


准备工作

下载数据集

一般会去下载webvid数据集,但是这个数据集非常大,如果读者不做预训练的话不建议下载。

《Animating Pictures with Eulerian Motion Fields》提供了一个比较小的测试数据集:Animating Pictures with Eulerian Motion Fields

大概一个GB左右,谷歌云盘的链接如下:

https://drive.google.com/file/d/1-MKuNxO1mjopgY6UoEVGDVt5I_QvVeDn/view

下载之后的.pth文件我们暂时不用管,可以先删除掉,只保留.mp4文件。

视频数据打标签

很多数据集是没有一个比较好的文字描述的,如果我们要训练text-to-video的任务,第一步要做的事情是对视频数据打上文字标签。

如果有,那么就算了,主打一个淘气(不是)

还是下一讲专门讲一下如何用V-BLIP给视频数据打上text标签吧

代码讲解

纯视频文件夹+txt描述prompt 读取方式

第一个DataLoader只需要输入视频的文件夹路径,prompt要么是全部指定成相同的(那肯定不行),要么从同名的txt文件中读取:

        if os.path.exists(self.video_files[index].replace(".mp4", ".txt")):with open(self.video_files[index].replace(".mp4", ".txt"), "r") as f:prompt = f.read()else:prompt = self.fallback_prompt

注意这里的text我们直接用预训练的tokenizer编码了,如果不想要的话也可以把这里注释掉:

    def get_prompt_ids(self, prompt):return self.tokenizer(prompt,truncation=True,padding="max_length",max_length=self.tokenizer.model_max_length,return_tensors="pt",).input_ids

获取视频的部分需要特别注意的是,要把"f h w c"转换成"f c h w":

        video = rearrange(video, "f h w c -> f c h w")

完整代码如下:

class VideoFolderDataset(Dataset):def __init__(self,tokenizer=None,width: int = 256,height: int = 256,n_sample_frames: int = 16,fps: int = 8,path: str = "./data",fallback_prompt: str = "",use_bucketing: bool = False,**kwargs):self.tokenizer = tokenizerself.use_bucketing = use_bucketingself.fallback_prompt = fallback_promptself.video_files = glob(f"{path}/*.mp4")self.width = widthself.height = heightself.n_sample_frames = n_sample_framesself.fps = fpsdef get_frame_buckets(self, vr):h, w, c = vr[0].shape        width, height = sensible_buckets(self.width, self.height, w, h)resize = T.transforms.Resize((height, width), antialias=True)return resizedef get_frame_batch(self, vr, resize=None):n_sample_frames = self.n_sample_framesnative_fps = vr.get_avg_fps()every_nth_frame = max(1, round(native_fps / self.fps))every_nth_frame = min(len(vr), every_nth_frame)effective_length = len(vr) // every_nth_frameif effective_length < n_sample_frames:n_sample_frames = effective_lengtheffective_idx = random.randint(0, (effective_length - n_sample_frames))idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)video = vr.get_batch(idxs)video = rearrange(video, "f h w c -> f c h w")if resize is not None: video = resize(video)return video, vrdef process_video_wrapper(self, vid_path):video, vr = process_video(vid_path,self.use_bucketing,self.width, self.height, self.get_frame_buckets, self.get_frame_batch)return video, vrdef get_prompt_ids(self, prompt):return self.tokenizer(prompt,truncation=True,padding="max_length",max_length=self.tokenizer.model_max_length,return_tensors="pt",).input_ids@staticmethoddef __getname__(): return 'folder'def __len__(self):return len(self.video_files)def __getitem__(self, index):video, _ = self.process_video_wrapper(self.video_files[index])if os.path.exists(self.video_files[index].replace(".mp4", ".txt")):with open(self.video_files[index].replace(".mp4", ".txt"), "r") as f:prompt = f.read()else:prompt = self.fallback_promptprompt_ids = self.get_prompt_ids(prompt)return {"pixel_values": normalize_input(video[0]), "prompt_ids": prompt_ids, "text_prompt": prompt, 'dataset': self.__getname__()}

CSV描述文件读取方式

这种方法每次都要打开一个txt文件去读取prompt,很不方便。而且如果读取的量级大了之后IO的开销会很大!

所以建议使用CSV方式的读取,CSV文件中存放着video-prompt的对应关系,样例如下:

video_path,prompt

...

video_path建议写成绝对路径,这样更方便读取。

完整代码如下:

class VideoCSVDataset(Dataset):def __init__(self,tokenizer=None,width: int = 256,height: int = 256,n_sample_frames: int = 16,fps: int = 8,csv_path: str = "./data",use_bucketing: bool = False,**kwargs):self.tokenizer = tokenizerself.use_bucketing = use_bucketingif not os.path.exists(csv_path):raise FileNotFoundError(f"The csv path does not exist: {csv_path}")self.csv_data = pd.read_csv(csv_path)self.width = widthself.height = heightself.n_sample_frames = n_sample_framesself.fps = fpsdef get_frame_buckets(self, vr):h, w, c = vr[0].shape        width, height = sensible_buckets(self.width, self.height, w, h)resize = T.transforms.Resize((height, width), antialias=True)return resizedef get_frame_batch(self, vr, resize=None):n_sample_frames = self.n_sample_framesnative_fps = vr.get_avg_fps()every_nth_frame = max(1, round(native_fps / self.fps))every_nth_frame = min(len(vr), every_nth_frame)effective_length = len(vr) // every_nth_frameif effective_length < n_sample_frames:n_sample_frames = effective_lengtheffective_idx = random.randint(0, (effective_length - n_sample_frames))idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)video = vr.get_batch(idxs)video = rearrange(video, "f h w c -> f c h w")if resize is not None: video = resize(video)return video, vrdef process_video_wrapper(self, vid_path):video, vr = process_video(vid_path,self.use_bucketing,self.width, self.height, self.get_frame_buckets, self.get_frame_batch)return video, vrdef get_prompt_ids(self, prompt):return self.tokenizer(prompt,truncation=True,padding="max_length",max_length=self.tokenizer.model_max_length,return_tensors="pt",).input_ids@staticmethoddef __getname__(): return 'csv'def __len__(self):return len(self.csv_data)def __getitem__(self, index):print(self.csv_data.iloc[index])video_path, prompt = self.csv_data.iloc[index]video, _ = self.process_video_wrapper(video_path)prompt_ids = self.get_prompt_ids(prompt)return {"pixel_values": normalize_input(video[0]), "prompt_ids": prompt_ids, "text_prompt": prompt, 'dataset': self.__getname__()}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/535343.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一文告诉你服务器为什么要托管?

IDC的全称是Internet Data Center&#xff0c;即“互联网数据中心”&#xff0c;现在大家都称作“IDC数据中心” 。 什么是IDC服务器托管服务&#xff1f; 服务器托管是企业用户为提高公司效益、压缩成本&#xff0c;将自身企业的服务器及相关设备放到专业IDC服务商所建设的数…

考研C语言复习进阶(1)

目录 1. 数据类型介绍 1.1 类型的基本归类&#xff1a; 2. 整形在内存中的存储 2.1 原码、反码、补码 2.2 大小端介绍 3. 浮点型在内存中的存储 ​编辑 1. 数据类型介绍 前面我们已经学习了基本的内置类型&#xff1a; char //字符数据类型 short //短整型 int /…

10个调研分析模板,轻松搞定市场调查与分析!

市场调查与分析&#xff0c;对于任何一家企业来说&#xff0c;都是不可或缺的一环。对进入市场开展业务的企业而言&#xff0c;不管处于哪个阶段——初创公司&#xff0c;抑或是已经稳定运营的企业&#xff0c;了解市场动态和客户需求总是至关重要的。 但必须承认的是&#xf…

探索C++中的动态数组:实现自己的Vector容器

&#x1f389;个人名片&#xff1a; &#x1f43c;作者简介&#xff1a;一名乐于分享在学习道路上收获的大二在校生 &#x1f648;个人主页&#x1f389;&#xff1a;GOTXX &#x1f43c;个人WeChat&#xff1a;ILXOXVJE &#x1f43c;本文由GOTXX原创&#xff0c;首发CSDN&…

mybatis基础操作(三)

动态sql 通过动态sql实现多条件查询&#xff0c;这里以查询为例&#xff0c;实现动态sql的书写。 创建members表 创建表并插入数据&#xff1a; create table members (member_id int (11),member_nick varchar (60),member_gender char (15),member_age int (11),member_c…

什么时候去检测大数据信用风险比较合适?

什么时候去检测大数据信用风险比较合适?在当今这个数据驱动的时代&#xff0c;大数据信用风险检测已经成为个人的一项重要需求。本文将从贷前检测、信息泄露检测和定期检测三个方面&#xff0c;阐述何时进行大数据信用风险检测较为合适。 一、贷前检测 大数据信用风险检测在贷…

指针【理论知识速成】(3)

一.指针的使用和传值调用&#xff1a; 在了解指针的传址调用前&#xff0c;先来额外了解一下 “传值调用” 1.传值调用&#xff1a; 对于来看这个帖子的你相信代码展示胜过千言万语 #include <stdio.h> #include<assert.h> int convert(int a, int b) {int c 0…

【Python】python实现Apriori算法和FP-growth算法(附源代码)

使用一种你熟悉的程序设计语言&#xff0c;实现&#xff08;1&#xff09;Apriori算法和&#xff08;2&#xff09;FP-growth算法。 目录 1、Apriori算法2、F-Growth算法3、两种算法比较 1、Apriori算法 def item(dataset): # 求第一次扫描数据库后的 候选集&#xff0c;&am…

OpenCASCADE开发指南<四>:OCC 数据类型和句柄

一个软件首先要规定能处理的数据类型&#xff0c; 其次要实现三项最基本的功能——引用管理、内存管理和异常管理。在 OCC 中&#xff0c;这三项功能分别对应基础类中的句柄、内存管理器和异常类。 1 数据类型 在基本概念篇里&#xff0c;已经介绍了 OCC 数据类型的分类&…

Linux:好用的Linux指令

进程的Linux指令 1.查看进程信息 ​​​​ps ajx | head -1 && ps ajx | grep 进程名创建一个进程后输入上述代码&#xff0c;会打印进程信息&#xff0c;当我们在code.exe中写入打印pid&#xff0c;ppid&#xff0c;这里也和进程信息一致。 while :; do ps ajx | he…

Python语言在编程业界的地位——《跟老吕学Python编程》附录资料

Python语言在编程业界的地位——《跟老吕学Python编程》附录资料 ⭐️Python语言在编程业界的地位2024年3月编程语言排行榜&#xff08;TIOBE前十&#xff09; ⭐️Python开发语言开发环境介绍1.**IDLE**2.⭐️PyCharm3.**Anaconda**4.**Jupyter Notebook**5.**Sublime Text** …

机器学习——过拟合问题、正则化解决法

过拟合的基本概念 欠拟合&#xff1a;假设函数没有很好的拟合训练集数据&#xff0c;也称这个假设函数有高偏差&#xff1b; 过拟合&#xff1a;过拟合也称为高方差。在假设函数中添加高阶多项式&#xff0c;让假设函数几乎能完美的拟合每个样本数据点&#xff0c;这看起来很…