CLIP在Github上的使用教程

CLIP的github链接:https://github.com/openai/CLIP

CLIP

Blog,Paper,Model Card,Colab
CLIP(对比语言-图像预训练)是一个在各种(图像、文本)对上进行训练的神经网络。可以用自然语言指示它在给定图像的情况下预测最相关的文本片段,而无需直接对任务进行优化,这与 GPT-2 和 3 的零镜头功能类似。我们发现,CLIP 无需使用任何 128 万个原始标注示例,就能在 ImageNet "零拍摄 "上达到原始 ResNet50 的性能,克服了计算机视觉领域的几大挑战。

Usage用法

首先,安装 PyTorch 1.7.1(或更高版本)和 torchvision,以及少量其他依赖项,然后将此 repo 作为 Python 软件包安装。在 CUDA GPU 机器上,完成以下步骤即可:

conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
pip install ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git

将上面的 cudatoolkit=11.0 替换为机器上相应的 CUDA 版本,如果在没有 GPU 的机器上安装,则替换为 cpuonly

import torch
import clip
from PIL import Imagedevice = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)with torch.no_grad():image_features = model.encode_image(image)text_features = model.encode_text(text)logits_per_image, logits_per_text = model(image, text)probs = logits_per_image.softmax(dim=-1).cpu().numpy()print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

API

CLIP 模块提供以下方法:

clip.available_models()

返回可用 CLIP 模型的名称。例如下面就是我执行的结果。
在这里插入图片描述

clip.load(name, device=..., jit=False)

返回模型和模型所需的 TorchVision 变换(由 clip.available_models() 返回的模型名称指定)。它将根据需要下载模型。name参数也可以是本地检查点的路径。
可以选择指定运行模型的设备,默认情况下,如果有第一个 CUDA 设备,则使用该设备,否则使用 CPU。当 jitFalse 时,将加载模型的非 JIT 版本。

clip.tokenize(text: Union[str, List[str]], context_length=77)

返回包含给定文本输入的标记化序列的 LongTensor。这可用作模型的输入。

clip.load() 返回的模型支持以下方法:

model.encode_image(image: Tensor)

给定一批图像,返回 CLIP 模型视觉部分编码的图像特征。

model.encode_text(text: Tensor)

给定一批文本标记,返回 CLIP 模型语言部分编码的文本特征。

model(image: Tensor, text: Tensor)

给定一批图像和一批文本标记,返回两个张量,其中包含与每张图像和每个文本输入相对应的 logit 分数。这些值是相应图像和文本特征之间的余弦相似度乘以 100。

More Examples更多实例

Zero-Shot预测

下面的代码使用 CLIP 执行零点预测,如论文附录 B 所示。该示例从 CIFAR-100 数据集中获取一张图片,并预测数据集中 100 个文本标签中最有可能出现的标签。

import os
import clip
import torch
from torchvision.datasets import CIFAR100# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)# Calculate features
with torch.no_grad():image_features = model.encode_image(image_input)text_features = model.encode_text(text_inputs)# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

输出结果如下(具体数字可能因计算设备而略有不同):

Top predictions:snake: 65.31%turtle: 12.29%sweet_pepper: 3.83%lizard: 1.88%crocodile: 1.75%

请注意,本示例使用的 encode_image()encode_text() 方法可返回给定输入的编码特征。

Linear-probe evaluation线性探针评估

下面的示例使用 scikit-learn 对图像特征进行逻辑回归。

import os
import clip
import torchimport numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)def get_features(dataset):all_features = []all_labels = []with torch.no_grad():for images, labels in tqdm(DataLoader(dataset, batch_size=100)):features = model.encode_image(images.to(device))all_features.append(features)all_labels.append(labels)return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

请注意,C 值应通过使用验证分割进行超参数扫描来确定。

See Also

OpenCLIP:包括更大的、独立训练的 CLIP 模型,最高可达 ViT-G/14
Hugging Face implementation of CLIP:更易于与高频生态系统集成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/266026.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【快速应用开发】看看RedwoodJS

自我介绍 做一个简单介绍,酒架年近48 ,有20多年IT工作经历,目前在一家500强做企业架构.因为工作需要,另外也因为兴趣涉猎比较广,为了自己学习建立了三个博客,分别是【全球IT瞭望】,【…

【C语言】【数据结构】自定义类型:结构体

引言 这是一篇对结构体的详细介绍,这篇文章对结构体声明、结构体的自引用、结构体的初始化、结构体的内存分布和对齐规则、库函数offsetof、以及进行内存对齐的原因、如何修改默认对齐数、结构体传参进行介绍和说明。 ✨ 猪巴戒:个人主页✨ 所属专栏&am…

# 一些视觉-激光、加速度传感器类的铣削振动测试方法案例

一些视觉-激光类的铣削振动测试方法 1. 基于激光测振仪的振动测试2. 切削加工的 加速度传感器实测信号2.1 x轴向信号2.2 Y轴向信号3. 关于数值频域积分1. 基于激光测振仪的振动测试 【1】舜宇LDV|激光测振—机床铣刀寿命预测 新刀具为100hz主频 旧刀具为800hz主频 方法原理:…

C# OpenCvSharp DNN 部署YOLOV6目标检测

目录 效果 模型信息 项目 代码 下载 C# OpenCvSharp DNN 部署YOLOV6目标检测 效果 模型信息 Inputs ------------------------- name:image_arrays tensor:Float[1, 3, 640, 640] -------------------------------------------------------------…

被忽悠选择那些价格昂贵的知识付费平台?我有才知识服务平台手把手教你如何正确选择!

在当今的知识经济时代,一个高效、便捷的知识服务平台对于企业和个人至关重要。然而,市面上的众多知识服务平台中,许多产品存在高昂的费用、无用功能的堆砌、无法定制化等问题,让用户进退两难,甚至被忽悠掉入使用陷阱。…

Leo赠书活动-13期 【以企业架构为中心的SABOE数字化转型五环法】文末送书

Leo赠书活动-13期 【以企业架构为中心的SABOE数字化转型五环法】文末送书 ✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客…

力扣77. 组合(java 回溯法)

Problem: 77. 组合 文章目录 题目描述思路解题方法复杂度Code 题目描述 思路 题目要求给出1-n中每k个数一组的所有组合,我们可以利用回溯,将其穷举出来,具体的: 1.以数字1-n为回溯的决策阶段,回溯的起始阶段为1 2.回溯…

外汇天眼:掌握这个技巧,你也能成为交易高手

在金融市场这个大潮中,外汇交易因其高杠杆、24小时交易等特点吸引着无数交易者。然而成功的交易并非易事,对于投资者来说,外汇交易市场是一个复杂且多变的市场,要在外汇市场中获得成功就需要扎实的外汇金融基础知识和独特的策略&a…

「实用教程」win32spl.dll文件的作用及修复方法分享

本文将为您详细介绍Win32spl.dll文件的作用、丢失原因以及提供5个修复教程,帮助您解决这一问题。 一、Win32spl.dll文件的作用 Win32spl.dll是一个动态链接库文件,它是Windows操作系统中的一个重要组件。该文件主要负责处理系统启动时的一些操作&#…

【LeetCode刷题】-- 118.杨辉三角

118.杨辉三角 class Solution {public List<List<Integer>> generate(int numRows) {List<List<Integer>> res new ArrayList<List<Integer>>();for(int i 0; i < numRows;i){List<Integer> ret new ArrayList<>();for(…

Docker中部署ElasticSearch 和Kibana,用脚本实现对数据库资源的未授权访问

图未保存&#xff0c;不过文章当中的某一步骤可能会帮助到您&#xff0c;那么&#xff1a;感恩&#xff01; 1、docker中拉取镜像 #拉取镜像 docker pull elasticsearch:7.7.0#启动镜像 docker run --name elasticsearch -d -e ES_JAVA_OPTS"-Xms512m -Xmx512m" -e…

入职字节外包一个月,我离职了。。。

有一种打工人的羡慕&#xff0c;叫做“大厂”。 真是年少不知大厂香&#xff0c;错把青春插稻秧。 但是&#xff0c;在深圳有一群比大厂员工更庞大的群体&#xff0c;他们顶着大厂的“名”&#xff0c;做着大厂的工作&#xff0c;还可以享受大厂的伙食&#xff0c;却没有大厂…