使用Pytorch从零开始实现CLIP

生成式建模知识回顾:
[1] 生成式建模概述
[2] Transformer I,Transformer II
[3] 变分自编码器
[4] 生成对抗网络,高级生成对抗网络 I,高级生成对抗网络 II
[5] 自回归模型
[6] 归一化流模型
[7] 基于能量的模型
[8] 扩散模型 I, 扩散模型 II
在这里插入图片描述

引言

2021 年 1 月,OpenAI 宣布了两种新模型:DALL-E 和 CLIP,这两种模型都是以某种方式连接文本和图像的多模态模型。在本文中,我们将在PyTorch中从零开始实现 CLIP 模型。OpenAI 开源了一些与 CLIP 模型相关的代码,但我发现它令人生畏,而且并不简洁。

CLIP 有什么作用?为什么有趣?

在《Learning Transferable Visual Models From Natural Language Supervision》论文中,OpenAI 介绍了他们的新模型,称为CLIP,用于Contrastive Language-Image Pre-training。简而言之,该模型学习整个句子与其描述的图像之间的关系;从某种意义上说,当训练模型时,给定一个输入句子,它将能够检索与该句子相对应的最相关的图像。这里重要的是,它是在完整的句子上进行训练,而不是像car、dog等单一类别一样。直觉上,当在整个句子上进行训练时,模型可以学习更多的东西,并找到图像和文本之间的一些模式。

他们还表明,当该模型在巨大的图像数据集及其相应文本上进行训练时,它也可以充当分类器。我鼓励你研究论文原文,以更多地了解这个令人兴奋的模型及其在基准数据集上的惊人结果。仅举一例,使用此策略训练的 CLIP 模型对 ImageNet 的分类效果比在 ImageNet 本身上训练的 SOTA 模型更好,该 SOTA 模型专门针对单一分类任务进行了优化!

先跳过过程,让我们看看我们将在本文中从头开始构建的最终模型能够实现什么功能:给出诸如“一个男孩用滑板跳跃”或“一个女孩从秋千上跳跃”这样的查询,模型将检索最相关的图像:
在这里插入图片描述

开始

让我们直接看它的 PyTorch 实现。首先,我们需要一个包含图像和一些描述它们的文本的数据集。坦率地说,网上有很多可用的。我们将使用Flickr 8k 数据集(您可以使用更大的 30k 版本,最终模型的性能会更好),该数据集主要用于图像字幕任务。但是, 我们也可以用它来训练 CLIP 模型。

以下代码将下载 8k(如果取消注释最后几行,则下载 30k)并解压缩它们。Kaggle数据集之下载可参考前文。

!pip install kaggle --upgrade
import os
os.environ['KAGGLE_USERNAME'] = "XXXXXX"
os.environ['KAGGLE_KEY'] = "XXXXXXXXXXXXXXXXXXXXXX" # Enter your Kaggle key here# For Flickr 8k
!kaggle datasets download -d adityajn105/flickr8k
!unzip flickr8k.zip
dataset = "8k"# For Flickr 30k
# !kaggle datasets download -d hsankesara/flickr-image-dataset
# !unzip flickr-image-dataset.zip
# dataset = "30k"

关于此数据集需要注意的一件事是: 每张图像都有 5 个标题。后面写损失函数的时候再讲这个!

数据集

正如您在本文的标题图片中看到的,我们需要对图像及其描述文本进行编码。因此,数据集需要返回图像和文本。当然,我们不会将原始文本提供给我们的文本编码器!我们将使用HuggingFace库中的DistilBERT模型(它比 BERT 小,但性能几乎与 BERT 一样)作为我们的文本编码器;因此,我们需要使用 DistilBERT 分词器对句子(标题)进行分词,然后将分词 id (input_ids) 和注意力掩码提供给 DistilBERT。因此,数据集也需要处理标记化。您可以在下面看到数据集的代码。下面我将解释代码中发生的最重要的事情。

关于配置和CFG的说明:我用 python 脚本编写了代码,然后将其转换为 Jupyter Notebook。因此,对于 python 脚本,config 是一个普通的 python 文件,我在其中放置所有超参数,对于 Jupyter Notebook,它是在笔记本开头定义的一个类,用于保留所有超参数。查看GitHub 存储库或笔记本以查看所有超参数。

import os
import cv2
import torch
import albumentations as Aimport config as CFGclass CLIPDataset(torch.utils.data.Dataset):def __init__(self, image_filenames, captions, tokenizer, transforms):"""image_filenames and cpations must have the same length; so, if there aremultiple captions for each image, the image_filenames must have repetitivefile names """self.image_filenames = image_filenamesself.captions = list(captions)self.encoded_captions = tokenizer(list(captions), padding=True, truncation=True, max_length=CFG.max_length)self.transforms = transformsdef __getitem__(self, idx):item = {key: torch.tensor(values[idx])for key, values in self.encoded_captions.items()}image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)image = self.transforms(image=image)['image']item['image'] = torch.tensor(image).permute(2, 0, 1).float()item['caption'] = self.captions[idx]return itemdef __len__(self):return len(self.captions)def get_transforms(mode="train"):if mode == "train":return A.Compose([A.Resize(CFG.size, CFG.size, always_apply=True),A.Normalize(max_pixel_value=255.0, always_apply=True),])else:return A.Compose([A.Resize(CFG.size, CFG.size, always_apply=True),A.Normalize(max_pixel_value=255.0, always_apply=True),])

init 中,我们收到一个 tokenizer 对象,它实际上是一个 HuggingFace tokinzer;运行模型时将加载此标记生成器。我们将字幕填充并截断为指定的 max_length。在 getitem 中,我们将首先加载一个编码的标题,它是一个带有 input_ids 和 Attention_mask 键的字典,根据其值创建张量,然后我们将加载相应的图像,对其进行变换和增强(如果有的话!),然后我们将其设为张量并将其放入以“image”为键的字典中。最后,我们将带有“caption”键的标题的原始文本放入字典中,仅用于可视化目的。

我没有使用额外的数据增强,但如果您想提高模型的性能,可以添加它们。

图像编码器

图像编码器代码很简单。我在这里使用 PyTorch 图像模型库 (timm),它提供了从 ResNets 到 EfficientNets 等许多不同的图像模型。这里我们将使用ResNet50作为我们的图像编码器。如果您不想安装新的库,您可以轻松地使用 torchvision 库来使用 ResNets。

class ImageEncoder(nn.Module):"""Encode images to a fixed size vector"""def __init__(self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable):super().__init__()self.model = timm.create_model(model_name, pretrained, num_classes=0, global_pool="avg")for p in self.model.parameters():p.requires_grad = trainabledef forward(self, x):return self.model(x)

该代码将每个图像编码为固定大小的向量,其大小与模型输出通道的大小相同(在 ResNet50 的情况下,向量大小将为2048)。这是 nn.AdaptiveAvgPool2d() 层之后的输出。

文本编码器

正如我之前提到的,我将使用 DistilBERT 作为文本编码器。与它的大哥 BERT 一样,两个特殊的标记将被添加到实际的输入标记中:CLS和SEP,它们标记句子的开始和结束。为了获取句子的完整表示(正如相关的 BERT 和 DistilBERT 论文所指出的那样),我们使用 CLS 标记的最终表示,并且我们希望该表示能够捕获句子(标题)的整体含义。这样想的话,就类似于我们对图像的处理,将其转换为固定大小的向量。

from transformers import DistilBertModel, DistilBertConfigclass TextEncoder(nn.Module):def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):super().__init__()if pretrained:self.model = DistilBertModel.from_pretrained(model_name)else:self.model = DistilBertModel(config=DistilBertConfig())for p in self.model.parameters():p.requires_grad = trainable# we are using the CLS token hidden representation as the sentence's embeddingself.target_token_idx = 0def forward(self, input_ids, attention_mask):output = self.model(input_ids=input_ids, attention_mask=attention_mask)last_hidden_state = output.last_hidden_statereturn last_hidden_state[:, self.target_token_idx, :]

对于 DistilBERT(以及 BERT),每个标记的输出隐藏表示是一个大小为768的向量。因此,整个标题将被编码为大小为 768 的 CLS 令牌表示形式。

Projection Head

我使用Keras 代码示例实现在 PyTorch 中编写了以下内容。

现在我们已经将图像和文本编码为固定大小的向量(图像为 2048,文本为 768),我们需要将它们带(投影)到一个图像和文本具有相似尺寸的新世界(!),以便能够对它们进行比较,将不相关的图像和文本分开,并将匹配的图像和文本放在一起。因此,以下代码将把 2048 和 768 维向量带入 256 (projection_dim) 维世界,我们可以在其中比较它们:

import torch
from torch import nnclass ProjectionHead(nn.Module):def __init__(self,embedding_dim,projection_dim=CFG.projection_dim,dropout=CFG.dropout):super().__init__()self.projection = nn.Linear(embedding_dim, projection_dim)self.gelu = nn.GELU()self.fc = nn.Linear(projection_dim, projection_dim)self.dropout = nn.Dropout(dropout)self.layer_norm = nn.LayerNorm(projection_dim)def forward(self, x):projected = self.projection(x)x = self.gelu(projected)x = self.fc(x)x = self.dropout(x)x = x + projectedx = self.layer_norm(x)return x

“embedding_dim”是输入向量的大小(图像为 2048,文本为 768),“projection_dim”是输出向量的大小,在我们的例子中为 256。要了解这部分的详细信息,您可以参考CLIP 论文。

CLIP模型

这部分是最有趣的!这里我还要讲一下损失函数。我将Keras 代码示例中的一些代码翻译成 PyTorch 来编写这部分。查看代码,然后阅读该代码块下面的说明。

import torch
from torch import nn
import torch.nn.functional as Fimport config as CFG
from modules import ImageEncoder, TextEncoder, ProjectionHeadclass CLIPModel(nn.Module):def __init__(self,temperature=CFG.temperature,image_embedding=CFG.image_embedding,text_embedding=CFG.text_embedding,):super().__init__()self.image_encoder = ImageEncoder()self.text_encoder = TextEncoder()self.image_projection = ProjectionHead(embedding_dim=image_embedding)self.text_projection = ProjectionHead(embedding_dim=text_embedding)self.temperature = temperaturedef forward(self, batch):# Getting Image and Text Featuresimage_features = self.image_encoder(batch["image"])text_features = self.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])# Getting Image and Text Embeddings (with same dimension)image_embeddings = self.image_projection(image_features)text_embeddings = self.text_projection(text_features)# Calculating the Losslogits = (text_embeddings @ image_embeddings.T) / self.temperatureimages_similarity = image_embeddings @ image_embeddings.Ttexts_similarity = text_embeddings @ text_embeddings.Ttargets = F.softmax((images_similarity + texts_similarity) / 2 * self.temperature, dim=-1)texts_loss = cross_entropy(logits, targets, reduction='none')images_loss = cross_entropy(logits.T, targets.T, reduction='none')loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)return loss.mean()def cross_entropy(preds, targets, reduction='none'):log_softmax = nn.LogSoftmax(dim=-1)loss = (-targets * log_softmax(preds)).sum(1)if reduction == "none":return losselif reduction == "mean":return loss.mean()

在这里,我们将使用之前构建的模块来实现主模型。init 函数是不言自明的。在前向函数中,我们首先将图像和文本分别编码为固定大小的向量(具有不同的维度)。之后,使用单独的投影模块,我们将它们投影到我之前谈到的共享世界(空间)。这里的编码将变得相似的形状(在我们的例子中是 256)。之后我们将计算损失。我再次建议阅读 CLIP 论文以使其更好,但我会尽力解释这部分。

在线性代数中,衡量两个向量是否具有相似特征(它们彼此相似)的一种常见方法是计算它们的点积(将匹配项相乘并取它们的总和);如果最终的数字很大,那么它们是相似的,如果最后的数字很小,那么它们就不相似(相对而言)!

好的!我刚才所说的是理解这个损失函数需要牢记的最重要的事情。我们继续吧。我们讨论了两个向量,但是,我们这里有什么?我们有 image_embeddings,形状为 (batch_size, 256) 的矩阵和形状为 (batch_size, 256) 的 text_embeddings。够简单的!这意味着我们有两组向量而不是两个单个向量。我们如何测量两组向量(两个矩阵)彼此的相似程度?同样,使用点积(在这种情况下,PyTorch 中的 @ 运算符执行点积或矩阵乘法)。为了能够将这两个矩阵相乘,我们转置第二个矩阵。好的,我们得到一个形状为 (batch_size, batch_size) 的矩阵,我们将其称为logits。(在我们的例子中,温度等于 1.0,因此,它没有什么区别。您可以使用它,看看它会产生什么差异。另请参阅论文,了解它为什么在这里!)。

我希望你还在我身边!如果不是也没关系,只需检查代码并检查它们的形状即可。现在我们有了逻辑,我们需要目标。我需要说的是,有一种更直接的方法来获取目标,但我必须为我们的案例这样做(我将在下一段中讨论原因)。

让我们考虑一下我们希望这个模型学习什么:我们希望它学习给定图像和描述它的标题的“相似表示(向量)”。这意味着我们要么给它一个图像,要么给它描述它的文本,我们希望它为两者生成相同的 256 大小的向量。

因此,在最好的情况下,text_embeddings 和 image_embedding 矩阵应该相同,因为它们描述的是相似的事物。现在我们想一下:如果发生这种情况,logits 矩阵会是什么样子?让我们看一个简单的例子!

import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as pltbatch_size = 4
dim = 256
embeddings = torch.randn(batch_size, dim)
out = embeddings @ embeddings.T
print(F.softmax(out, dim=-1))-----------
# tensor([[1., 0., 0., 0.],
#         [0., 1., 0., 0.],
#         [0., 0., 1., 0.],
#         [0., 0., 0., 1.]])

因此,在最好的情况下,logits 将是一个矩阵,如果我们采用其 softmax,对角线中将有 1.0(一个用奇特的词来称呼它的单位矩阵!)。由于损失函数的作用是使模型的预测与目标相似(至少在大多数情况下!),因此我们希望这样的矩阵作为我们的目标。这就是我们在上面的代码块中计算 images_similarity 和 texts_similarity 矩阵的原因。

现在我们已经有了目标矩阵,我们将使用简单的交叉熵来计算实际损失。我已经将交叉熵的完整矩阵形式编写为函数,您可以在代码块的底部看到。好的!我们完了!是不是很简单?!好吧,你可以忽略下一段,但如果你好奇的话,里面有一个重要的注释。

这就是为什么我没有使用更简单的方法:我需要承认在 PyTorch 中有一种更简单的方法来计算这种损失;通过这样做:nn.CrossEntropyLoss()(logits, torch.arange(batch_size))。为什么我这里没有使用它?有两个原因。1- 我们使用的数据集对单个图像有多个标题;因此,批次中可能存在两个具有相似标题的相同图像(这种情况很少见,但可能会发生)。使用这种更简单的方法获取损失将忽略这种可能性,并且模型学会分离实际上相同的两个表示(假设它们不同)。显然,我们不希望这种情况发生,因此我以照顾这些边缘情况的方式计算了整个目标矩阵。2-按照我的方式做,让我更好地理解了这个损失函数中发生的事情;所以,我认为这也会给你更好的直觉!

训练

这是一个训练模型的函数。这里没有发生太多事情;只需加载批次,将它们输入模型并步进优化器和 lr_scheduler。

def train_epoch(model, train_loader, optimizer, lr_scheduler, step):loss_meter = AvgMeter()tqdm_object = tqdm(train_loader, total=len(train_loader))for batch in tqdm_object:batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}loss = model(batch)optimizer.zero_grad()loss.backward()optimizer.step()if step == "batch":lr_scheduler.step()count = batch["image"].size(0)loss_meter.update(loss.item(), count)tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))return loss_meter

好的!我们已经完成了模型的训练。现在,我们需要进行推理,在我们的例子中,将给模型一段文本,并希望它从看不见的验证(或测试)集中检索最相关的图像。

获取图像嵌入

在此函数中,我们加载训练后保存的模型,向其提供验证集中的图像,并返回形状为 (valid_set_size, 256) 的 image_embeddings 和模型本身。

def get_image_embeddings(valid_df, model_path):tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)valid_loader = build_loaders(valid_df, tokenizer, mode="valid")model = CLIPModel().to(CFG.device)model.load_state_dict(torch.load(model_path, map_location=CFG.device))model.eval()valid_image_embeddings = []with torch.no_grad():for batch in tqdm(valid_loader):image_features = model.image_encoder(batch["image"].to(CFG.device))image_embeddings = model.image_projection(image_features)valid_image_embeddings.append(image_embeddings)return model, torch.cat(valid_image_embeddings)

寻找匹配项

该函数执行我们希望模型能够完成的最终任务:它获取模型、image_embeddings 和文本查询。它将显示验证集中最相关的图像!是不是很神奇呢?让我们看看它到底表现如何!

def find_matches(model, image_embeddings, query, image_filenames, n=9):tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)encoded_query = tokenizer([query])batch = {key: torch.tensor(values).to(CFG.device)for key, values in encoded_query.items()}with torch.no_grad():text_features = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])text_embeddings = model.text_projection(text_features)image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)dot_similarity = text_embeddings_n @ image_embeddings_n.T# multiplying by 5 to consider that there are 5 captions for a single image# so in indices, the first 5 indices point to a single image, the second 5 indices# to another one and so on.values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)matches = [image_filenames[idx] for idx in indices[::5]]_, axes = plt.subplots(math.sqrt(n), math.sqrt(n), figsize=(10, 10))for match, ax in zip(matches, axes.flatten()):image = cv2.imread(f"{CFG.image_path}/{match}")image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)ax.imshow(image)ax.axis("off")plt.show()

让我们看一些例子!此时,当我看到输出时,我高兴地尖叫起来,并惊讶于该模型实际上正在学习图像和文本之间的这种关系!这种感觉简直难以置信。

find_matches(model, image_embeddings,query="one dog sitting on the grass",image_filenames=valid_df['image'].values,n=9)

这就是我们使用这个函数的方式。结果如下:
在这里插入图片描述
我当时就想:哇!这个模型知道一些东西!当然它并不完美,因为有些图片中有两只狗,但考虑到训练集小和训练时间短,我认为这很棒!

让我们看看其他一些输出。Quert写在每个图像的顶部。
在这里插入图片描述
看!它还可以算数!将此与上一个进行比较。该模型知道“两个”的含义,并提供了有两只狗的图像,与之前的查询形成鲜明对比!这一刻我第二次震惊得尖叫起来:)

文章开头的输出:
在这里插入图片描述
对于下面的示例,模型犯了一些错误,但总的来说,它显然对文本和图像都有很好的理解。
在这里插入图片描述

资源

  • 本文对应的Github代码库

本博文译自Moein Shariatnia的博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/235975.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Security 的使用

一、简介 1.1、Spring Security 相关概念 1.过滤器链(Filter Chain) 基于Servlet过滤器(Filter)处理和拦截请求,进行身份验证、授权等安全操作。过滤器链按顺序执行,每个过滤器负责一个具体的安全功能。 …

【经验分享】openGauss 客户端(Data Studio / DBeaver)连接方式

前言 本篇介绍了openGauss常用的客户端连接工具Data Studio和DBeaver 01 客户端工具 openGauss部署之后,在服务器上提供了在命令行下运行的数据库连接工具gsql。此工具除了具备操作数据库的基本功能,还提供了若干高级特性,便于用户使用。…

【Node.js】笔记整理 5 - Express框架

写在最前:跟着视频学习只是为了在新手期快速入门。想要学习全面、进阶的知识,需要格外注重实战和官方技术文档,文档建议作为手册使用 系列文章 【Node.js】笔记整理 1 - 基础知识【Node.js】笔记整理 2 - 常用模块【Node.js】笔记整理 3 - n…

qt5.15播放音频示例(4种方法)

文章目录 Qt播放音频方法一 QMediaPlayer方法二 QSound方法三 QSoundEffect方法四 QAudioOutput问题1 播放无声问题2 QAudioOutput播放嗡嗡声的问题参考Qt播放音频 在linux系统中,可以通过aplay进行简单的播放音频,如 aplay /opt/Audio/test.wav在图形界面,也可以封装apla…

玩转大数据:3-Hadoop家族的力量与挑战

引言 Hadoop作为一个强大的大数据处理框架,以其分布式计算和存储能力在业界备受关注。然而,Hadoop在应用场景、适用范围、社区支持以及后续持续发展等方面也面临着一些挑战。本文将围绕Hadoop的生态应用,以及来自其他生态的挑战,…

【排序,直接插入排序 折半插入排序 希尔插入排序】

文章目录 排序排序方法的分类插入排序直接插入排序折半插入排序希尔插入排序 排序 将一组杂乱无章的数据按照一定规律排列起来。将无序序列排成一个有序序列。 排序方法的分类 储存介质: 内部排序:数据量不大,数据在内存,无需…

leetcode 18. 四数之和(优质解法)

代码&#xff1a; class Solution {public List<List<Integer>> fourSum(int[] nums, int target) {List<List<Integer>> listsnew ArrayList<>();int lengthnums.length;Arrays.sort(nums);for(int i0;i<length-4;){for(int ji1;j<lengt…

【ArcGIS Pro二次开发】:CC工具箱1.1.4更新_免费_50+工具

CC工具箱1.1.4更新【2023.11.30】 使用环境要求&#xff1a;ArcGIS Pro 3.0 一、下载链接 工具安装文件及使用文档&#xff1a; https://pan.baidu.com/s/1OJmO6IPtMfX_vob3bMtvEg?pwduh5rhttps://pan.baidu.com/s/1OJmO6IPtMfX_vob3bMtvEg?pwduh5r 二、使用方法 1、在下…

抖音本地生活服务商申请条件

抖音的本地生活服务商目前有两种&#xff0c;一种是可以做全国的服务商&#xff0c;我们一般叫抖音本地生活服务商&#xff0c;一种是区域优待服务商&#xff0c;也就是后面出来的服务商&#xff0c;这两种服务商的申请方式大同小异。 相同的地方就是都需要给平台交保证金。抖…

Go语言 值传递

官方说法&#xff0c;Go中只有值传递&#xff0c;没有引用传递 而Go语言中的一些让你觉得它是引用传递的原因&#xff0c;是因为Go语言有值类型和引用类型&#xff0c;但是它们都是值传递。 值类型 有int、float、bool、string、array、sturct等 引用类型有slice&#xff0c…

FlatLaf:干净、优雅、扁平化,基于java swing现代开源跨平台外观

一个很不错的java swing ui库&#xff0c;idea主题风格&#xff0c;还能自定义 FlatLaf是用于JavaSwing 桌面应用程序的现代开源跨平台外观。 它看起来几乎是平的&#xff08;没有阴影或渐变&#xff09;、干净、简单和优雅。FlatLaf带有Light、Dark、IntelliJ和Darcula主题&a…

11 款顶级的免费 iPhone 数据恢复软件

iPhone 拥有巨大的存储容量。您可以在 iPhone 设备上存储图像、文档和视频等数据。有时&#xff0c;您的 iPhone 会发生许多意外事件&#xff0c;例如意外删除&#xff0c;从而导致数据丢失。这里有 11 个最好的免费 iPhone 数据恢复软件&#xff0c;您可以免费下载&#xff0c…