【Image captioning】ruotianluo/self-critical.pytorch之1—数据集的加载与使用

【Image captioning】ruotianluo/self-critical.pytorch之1—数据集的加载与使用

作者:安静到无声 个人主页

数据加载程序示意图

数据集程序

使用方法

示例代码

#%%from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriterimport numpy as npimport time
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' ##from six.moves import cPickle
import traceback
from collections import defaultdictimport captioning.utils.opts as opts
import captioning.models as models
from captioning.data.dataloader import *
import skimage.io
import captioning.utils.eval_utils as eval_utils
import captioning.utils.misc as utils
from captioning.utils.rewards import init_scorer, get_self_critical_reward
from captioning.modules.loss_wrapper import LossWrapperimport sys
sys.path.append("..")
import time
#%%opt = opts.parse_opt()
opt.input_json = '/home/lihuanyu/code/07ImageCaptioning/data/cocotalk.json'
opt.input_label_h5 = '/home/lihuanyu/code/07ImageCaptioning/data/cocotalk_label.h5'
opt.input_fc_dir = '/home/lihuanyu/code/07ImageCaptioning/data/cocotalk_fc'
opt.input_att_dir = '/home/lihuanyu/code/07ImageCaptioning/data/cocotalk_att'
opt.batch_size = 1
opt.train_only = 1opt.use_att = True
opt.use_att = True
opt.use_box = 0#%%
print(opt.input_json)
print(opt.batch_size)  #批量化为16
loader = DataLoader(opt)  # 数据加载
#打印字内容
#print(loader.get_vocab())  #返回字典
for i in range(2):data = loader.get_batch('train')print('———————————————————※输入的信息特征※——————————————————')  #[1,2048] 全连接特征print('全连接特征【fc_feats】的形状:', data['fc_feats'].shape)  #[1,2048] 全连接特征print('全连接特征【att_feats】的形状:', data['att_feats'].shape)  #[1,2048] 注意力特征print('att_masks', data['att_masks'])print('含有的信息infos:', data['infos'])  #infos [{'ix': 117986, 'id': 495956, 'file_path': 'train2014/COCO_train2014_000000495956.jpg'}]print('———————————————————※标签信息※——————————————————')  #[1,2048] 全连接特征print('labels', data['labels'])     #添加了一些0print('gts:', data['gts'])          #没有添加的原始句子print('masks', data['masks'])print('———————————————————※记录遍历的位置※——————————————————')  #[1,2048] 全连接特征print('bounds', data['bounds'])time.sleep(1)print(data.keys())

输出结果:

Hugginface transformers not installed; please visit https://github.com/huggingface/transformers
meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`
Warning: coco-caption not available
cider or coco-caption missing
/home/lihuanyu/code/07ImageCaptioning/data/cocotalk.json
1
是否使用【注意力特征[use_fc]: True
是否使用【注意力特征[use_att]: True
是否在注意力特征中使用【检测框特征[use_box]: 0
DataLoader loading json file:  /home/lihuanyu/code/07ImageCaptioning/data/cocotalk.json
vocab size is  9487
DataLoader loading h5 file:  /home/lihuanyu/code/07ImageCaptioning/data/cocotalk_fc /home/lihuanyu/code/07ImageCaptioning/data/cocotalk_att data/cocotalk_box /home/lihuanyu/code/07ImageCaptioning/data/cocotalk_label.h5
max sequence length in data is 16
read 123287 image features
assigned 82783 images to split train(训练集有多少图片)
assigned 5000 images to split val(验证集有多少图片)
assigned 5000 images to split test(测试集有多少图片)
———————————————————※输入的信息特征※——————————————————
全连接特征【fc_feats】的形状: torch.Size([1, 2048])
全连接特征【att_feats】的形状: torch.Size([1, 196, 2048])
att_masks None
含有的信息infos: [{'ix': 60494, 'id': 46065, 'file_path': 'train2014/COCO_train2014_000000046065.jpg'}]
———————————————————※标签信息※——————————————————
labels tensor([[[   0,    1,  271,   17, 7068,   35,   98,    6,    1,  102,    3,912,    0,    0,    0,    0,    0,    0],[   0,  995, 2309, 2308,  609,    6,    1,  271,  119,  912,    0,0,    0,    0,    0,    0,    0,    0],[   0, 2309, 9487,  179,   98,    6,    1,   46,  271,    0,    0,0,    0,    0,    0,    0,    0,    0],[   0,  182,   35,  995, 7068,    6,    1,  271,    3,   60,  678,32,   14,   29,    0,    0,    0,    0],[   0,  995,  915,   17, 2309, 3130,    6,    1,   46,  271,    0,0,    0,    0,    0,    0,    0,    0]]])
gts: [array([[   1,  271,   17, 7068,   35,   98,    6,    1,  102,    3,  912,0,    0,    0,    0,    0],[ 995, 2309, 2308,  609,    6,    1,  271,  119,  912,    0,    0,0,    0,    0,    0,    0],[2309, 9487,  179,   98,    6,    1,   46,  271,    0,    0,    0,0,    0,    0,    0,    0],[ 182,   35,  995, 7068,    6,    1,  271,    3,   60,  678,   32,14,   29,    0,    0,    0],[ 995,  915,   17, 2309, 3130,    6,    1,   46,  271,    0,    0,0,    0,    0,    0,    0]], dtype=uint32)]
masks tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,0.]]])
———————————————————※记录遍历的位置※——————————————————
bounds {'it_pos_now': 1, 'it_max': 82783, 'wrapped': False}
dict_keys(['fc_feats', 'att_feats', 'att_masks', 'labels', 'masks', 'gts', 'bounds', 'infos'])
———————————————————※输入的信息特征※——————————————————
全连接特征【fc_feats】的形状: torch.Size([1, 2048])
全连接特征【att_feats】的形状: torch.Size([1, 196, 2048])
att_masks None
含有的信息infos: [{'ix': 106440, 'id': 151264, 'file_path': 'train2014/COCO_train2014_000000151264.jpg'}]
———————————————————※标签信息※——————————————————
labels tensor([[[   0,    1,  230,    6,   14,  230,  237,   32, 1086,  627,    0,0,    0,    0,    0,    0,    0,    0],[   0,    1, 6035,  230,   35,  274,  127,  225, 1598,  335,    1,940,    0,    0,    0,    0,    0,    0],[   0,    1,  230,   35,  900,   32,  307,  756,   61,  607,    0,0,    0,    0,    0,    0,    0,    0],[   0,    1,  230,   35,   98,   79,    1,  230,  224,    0,    0,0,    0,    0,    0,    0,    0,    0],[   0,    1,   46, 1109,  230, 1596,  245,    1,  224,    0,    0,0,    0,    0,    0,    0,    0,    0]]])
gts: [array([[   1,  230,    6,   14,  230,  237,   32, 1086,  627,    0,    0,0,    0,    0,    0,    0],[   1, 6035,  230,   35,  274,  127,  225, 1598,  335,    1,  940,0,    0,    0,    0,    0],[   1,  230,   35,  900,   32,  307,  756,   61,  607,    0,    0,0,    0,    0,    0,    0],[   1,  230,   35,   98,   79,    1,  230,  224,    0,    0,    0,0,    0,    0,    0,    0],[   1,   46, 1109,  230, 1596,  245,    1,  224,    0,    0,    0,0,    0,    0,    0,    0]], dtype=uint32)]
masks tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,0.]]])
———————————————————※记录遍历的位置※——————————————————
bounds {'it_pos_now': 2, 'it_max': 82783, 'wrapped': False}
dict_keys(['fc_feats', 'att_feats', 'att_masks', 'labels', 'masks', 'gts', 'bounds', 'infos'])

推荐专栏

🔥 手把手实现Image captioning

💯CNN模型压缩

💖模式识别与人工智能(程序与算法)

🔥FPGA—Verilog与Hls学习与实践

💯基于Pytorch的自然语言处理入门与实践

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/62822.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL SUBSTRING_INDEX() 函数的详细介绍

MySQL SUBSTRING_INDEX() 从给定字符串中返回指定数量的分隔符出现之前的子字符串。 当指定数字为正数时从最终分隔符的左侧返回子字符串,当指定数字为负数时从最终分隔符的右侧返回子字符串。 如果指定的次数大于分隔符的出现次数,则返回的子字符串将…

从零开始搭建个人博客网站(hexo框架)

1.工具及环境搭建 1)注册GitHub并且新建一个repositories 2)下载node.js以及Git 下载链接: 检验安装是否成功: 【注】:MacOS自带Git,可以直接在终端输入git --version进行检验 3)新建一个…

基于rsesnet网络架构的图像分类模型

数据预处理部分: 数据增强:torchvision中transforms模块自带功能,比较实用数据预处理:torchvision中transforms也帮我们实现好了,直接调用即可DataLoader模块直接读取batch数据 网络模块设置: 加载预训练…

AtCoder Beginner Contest 313D题题解

文章目录 [ Odd or Even](https://atcoder.jp/contests/abc313/tasks/abc313_d)问题建模问题分析1.分析每次查询的作用2.利用异或运算的性质设计查询方法 Odd or Even 问题建模 有n个数,每个数为0或者1,最多可以进行n次询问,每次询问选择k个…

UE4/5 GAS技能系统入门3 - GameplayEffect

阅读本文需要上一篇AttributeSet的基础知识: https://blog.csdn.net/grayrail/article/details/132148492 本文也并非教程性质文章,主要讲解学习记录为主。 这篇开始讲AttributeSet配置好后,GameplayEffect的使用。 1.将GE配置至Ability Co…

关于新手学习STM32开发应该如何入门?

对于新手来说,学习STM32开发可能会感到困惑,尤其是在拿到开发板后该如何入门。在这里有嵌入式学习路线,毕设,各种项目,需要留个6。以下是部分内容概述:硬件介绍:了解STM32开发板的基本硬件组成和…

平替 Docker - 玩转容器新利器 Podman Desktop (视频)

《OpenShift 4.x HOL教程汇总》 在 podman-desktop 1.2.1 podman 4.4 环境中验证。 文章目录 什么是 podman 和 podman-desktop安装 podman 和 podman-desktop 基本环境Image、Container 和 Pod 的基本操作拉取 Image运行 Container 将 Pod 部署到 Kubernetes安装 Kind 扩展插…

mysql高级三:sql性能优化+索引优化+慢查询日志

内容介绍 单表索引失效案例 0、思考题:如果把100万数据插入MYSQL ,如何提高插入效率 (1)关闭自动提交,只手动提交一次 (2)删除除主键索引外其他索引 (3)拼写mysql可以执…

Blazor 简单组件(2):B_row/B_col 12分隔布局 简单开发

文章目录 前言12分隔布局开发B_col.razorB_col.razor.cssB_row.razorB_row.razor.css 使用案例 前言 Blazor 简单组件(0)&#xff1a;简单介绍 12分隔布局开发 B_col.razor if (Offset ! "0") {<div style" grid-column-start: span (Offset)">&l…

系统架构设计师-软件架构设计(7)

目录 大型网站系统架构演化 一、第一阶段&#xff1a;单体架构 到 第二阶段&#xff1a;垂直架构 二、第三阶段&#xff1a;使用缓存改善网站性能 1、缓存与数据库的数据一致性问题 2、缓存技术对比【MemCache与Redis】 3、Redis分布式存储方案 4、Redis集群切片的常见方式 …

【css】textarea-通过resize:none 禁止拖动设置大小

使用 resize 属性可防止调整 textareas 的大小&#xff08;禁用右下角的“抓取器”&#xff09;&#xff1a; 没有设置resize:none 代码&#xff1a; <!DOCTYPE html> <html> <head> <style> textarea {width: 100%;height: 150px;padding: 12px 20p…