基于DETR的人脸伪装检测
- 前言
- 前提条件
- 实验环境
- 项目地址
- Linux
- Windows
- DETR
- 训练自己的数据集
- 修改models/detr.py中的参数
- 进行训练
- 进行预测
- 相关资源免费获取
- 参考
前言
- 本文是个人使用DETR训练自己的COCO格式数据集的应用案例,由于水平有限,难免出现错漏,敬请批评改正。
- 更多精彩内容,可点击进入YOLO系列专栏或我的个人主页查看
- YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
- YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
- YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
- YOLOv7训练自己的数据集(口罩检测)
- YOLOv8训练自己的数据集(足球检测)
- 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
前提条件
- 熟悉Python
实验环境
cython
git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI&egg=pycocotools
submitit
torch>=1.5.0
torchvision>=0.6.0
git+https://github.com/cocodataset/panopticapi.git#egg=panopticapi
scipy
onnx
onnxruntime
项目地址
DETR官方源代码地址:https://github.com/facebookresearch/detr.git
Linux
git clone https://github.com/facebookresearch/detr.git
Cloning into 'yolov8'...
remote: Enumerating objects: 4583, done.
remote: Counting objects: 100% (4583/4583), done.
remote: Compressing objects: 100% (1270/1270), done.
remote: Total 4583 (delta 2981), reused 4576 (delta 2979), pack-reused 0
Receiving objects: 100% (4583/4583), 23.95 MiB | 1.55 MiB/s, done.
Resolving deltas: 100% (2981/2981), done.
Windows
请到
https://github.com/facebookresearch/detr.git
网站下载源代码zip压缩包。
DETR
- DETR是Facebook提出的基于Transformer的端到端目标检测网络。DETR做到了真正没有非最大抑制(NMS)后处理,而且不需要anchor(锚点生成)。但是,训练时间较长,对小目标的检测性能不是很高。建议使用可变形注意模块(deformable attention module)代替原始的多头注意力来关注参考点周围的关键位置。
- DETR论文地址:https://arxiv.org/abs/2005.12872
- DETR官方源代码地址:https://github.com/facebookresearch/detr.git
训练自己的数据集
修改models/detr.py中的参数
- num_class需要设置为max_id+1,
- 比如本文使用的人脸伪装数据集,索引从0到7,那么num_class应该设置为7+1=8,索引为8的类为背景类。
- 又比如,有些数据集,索引从1到20,那么num_class应该设置为20+1=21,索引为21的类为背景类,但是因为索引从1开始,所以把索引为0的类设置为N/A,既不是背景也不是前景,应该是缺失类。
- 作者举例4个类别的索引分别为1,23,24,56,那么num_class应该设置为 56+1 = 57,索引为57的类为背景类。其中缺失索引值:0、2-22、25-55应该用N/A填充,都是缺失类。
# origin
# num_classes = 20 if args.dataset_file != 'coco' else 91
# alter_my [num_classes = (max_obj_id + 1)]
num_classes = 8 if args.dataset_file != 'coco' else 8
进行训练
python main.py --output_dir ./weights --coco_path ../datasets/face_guise_datasets/ --epochs 100 --resume detr_r50_8.pth
进行预测
新建一个pre_img.py,内容如下:
import numpy as np
from models.detr import build
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transformstorch.set_grad_enabled(False)
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
transform_input = transforms.Compose([transforms.Resize(800),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])def box_cxcywh_to_xyxy(x):x_c, y_c, w, h = x.unbind(1)b = [(x_c - 0.5 * w), (y_c - 0.5 * h),(x_c + 0.5 * w), (y_c + 0.5 * h)]return torch.stack(b, dim=1)def rescale_bboxes(out_bbox, size):img_w, img_h = sizeb = box_cxcywh_to_xyxy(out_bbox)b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32, device="cuda")return bdef plot_results(pil_img, prob, boxes, img_save_path):plt.figure(figsize=(16, 10))plt.imshow(pil_img)ax = plt.gca()colors = COLORS * 100for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,fill=False, color=c, linewidth=3))cl = p.argmax()text = f'{CLASSES[cl]}: {p[cl]:0.2f}'ax.text(xmin, ymin, text, fontsize=9,bbox=dict(facecolor='yellow', alpha=0.5))plt.savefig(img_save_path)plt.axis('off')plt.show()def main(chenkpoint_path, img_path, img_save_path):args = torch.load(chenkpoint_path)['args']model = build(args)[0]device = "cuda" if torch.cuda.is_available() else "cpu"model.to(device)# 加载模型参数model_data = torch.load(chenkpoint_path)['model']model.load_state_dict(model_data)model.eval()img = Image.open(img_path).convert('RGB')size = img.sizeinputs = transform_input(img).unsqueeze(0)outputs = model(inputs.to(device))# 这类最后[0, :, :-1]索引其实是把背景类筛选掉了probs = outputs['pred_logits'].softmax(-1)[0, :, :-1]# 可修改阈值,只输出概率大于0.7的物体keep = probs.max(-1).values > 0.7bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], size)# 保存输出结果ori_img = np.array(img)plot_results(ori_img, probs[keep], bboxes_scaled, img_save_path)if __name__ == "__main__":# CLASSES = ['N/A', "aeroplane", "bicycle", "bird", "boat",# "bottle", "bus", "car", "cat", "chair",# "cow", "diningtable", "dog", "horse",# "motorbike", "person", "pottedplant",# "sheep", "sofa", "train", "tvmonitor", "background"]CLASSES = ['glasses', "hat", "nothing", "glasses_hat", "glasses_mask", "hat_mask", "glasses_hat_mask", "mask", "background"]main(chenkpoint_path="weights/checkpoint.pth", img_path="test.jpg",img_save_path="result.jpg")
python pre_img.py
相关资源免费获取
人脸伪装数据集
- 地址:https://download.csdn.net/download/FriendshipTang/88038140
预训练权重:detr_r50_8.pth
- 地址:https://download.csdn.net/download/FriendshipTang/88038804
本文源码
- 地址:https://download.csdn.net/download/FriendshipTang/88038809
注:如资源地址失效,请私信我!
参考
[1] Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko. End-to-End Object Detection with Transformers. 2020
[2] DETR 源代码地址. https://github.com/facebookresearch/detr.git
[3] https://blog.csdn.net/m0_46412065/article/details/128538040
- 更多精彩内容,可点击进入YOLO系列专栏或我的个人主页查看
- YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
- YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
- YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
- YOLOv7训练自己的数据集(口罩检测)
- YOLOv8训练自己的数据集(足球检测)
- 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测