【kaggle代码】Plant Seedlings Classification (使用Resnet-18完成分类任务)

比赛地址:植物种子分类

注意的点:

  1. 使用datasets.ImageFolder读取数据,并且制作数据集。分类任务与图像分割任务不同。分类任务的数据是:【图片,标签(字符串类型)】,所以两者的数据读取方式不同。在分割任务中,常常需要重写Dataset便于图像预处理,而在该分类任务中,不需要重写Dataset,在datasets.ImageFolder中,可以接收transform参数对读入的图像进行处理,而不对标签(字符串)处理,且会将标签自动转为标签索引形式。关于datasets.ImageFolde

  2. torch的Dataloader接受的是(data, labels)的元组形式,在 PyTorch 的 DataLoader 中,元组列表中元素的数据类型要求相对较松。每个元组的第一个元素通常是输入数据,第二个元素是对应的标签。这两个元素可以是任何 PyTorch 支持的数据类型,例如张量(torch.Tensor)、NumPy 数组、PIL 图像等。

  3. 对于使用预训练好的Resnet-18,可以通过更改网络最后一层,来适应该分类任务。对于很多模型,model.fc 是最后一层的全连接层。
    在这里插入图片描述

  4. 在这个比赛中,最初得分总是很低。最后发现原因是:在提交submission中,图片名称是按照顺序读入的,但是在使用Dataloader读入测试集数据时,使用了shuffle=True,导致读入的顺序被打乱,从而使得图片名称和预测标签不对应,导致得分很低。改为shuffle=Flase问题解决。

代码,按照ipynb顺序排列:

# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to loadimport numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session
import numpy as np
import os
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from torch import optim
from torch import nn
import cv2 
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.data import random_split
from tqdm import tqdm
import imageio
from torchvision import datasets
from PIL import Image # Image模块是在Python PIL图像处理中常见的模块,对图像进行基础操作的功能基本都包含于此模块内。
work_dir = '/kaggle/input/plant-seedlings-classification'
os.listdir(work_dir)
`import glob
#读取数据,用于后续制作数据集
train_path = os.path.join(work_dir,'train')# 使用glob列出train文件夹下的所有文件夹
folders = glob.glob(os.path.join(train_path, '*'))print(f'总的类别数量:{len(folders)}')`````python
# values from ImageNet, recommended by PyTorch
transform_mean = [0.485, 0.456, 0.406]
transform_std = [0.229, 0.224, 0.225]transforms = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=transform_mean, std=transform_std),
])dataset = datasets.ImageFolder(root=train_path,transform=transforms)# self.classes:用一个 list 保存类别名称
# self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
# self.imgs:保存(img-path, class) tuple的 list
#查看有多少个样例和多少个类别
print('samples',len(dataset))
print('classes',len(dataset.classes))
print(dataset)
print(dataset.classes)
print(dataset.class_to_idx)
print(dataset.imgs[0])
label_counts = []# 遍历每个类别的文件夹
for d in glob.glob(os.path.join(train_path, '*')):
#glob.glob 返回一个包含匹配指定模式的所有文件或文件夹的列表。在这里,它返回了所有子文件夹的路径列表。    # 获取类别名称label = os.path.basename(d)# 计算该类别中图像的数量count = len(glob.glob(os.path.join(d, '*')))# 将类别名称和图像数量添加到列表中label_counts.append({'label': label, 'count': count})# 创建一个 Pandas DataFrame
label_counts_df = pd.DataFrame(label_counts)# 打印 DataFrame
print(label_counts_df)
## 划分训练集和验证集
train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [3750, 1000])
print(len(train_dataset))
print(len(valid_dataset))
#DataLoader 返回的是一个迭代器(iterator),每次迭代都会产生一个包含小批次数据的元组。
#这个元组的内容取决于你在创建 DataLoader 时指定的数据集的格式。
#通常情况下,这个元组包含两个元素,分别是输入数据和对应的标签。
#例如,如果你的数据集是一个 TensorDataset,那么每个小批次的元组就是 (inputs, targets)。# 这里是分类任务,和分割任务不同。
#dataset使用ImageFolder就对image已经进行了transform,而label使用的是索引(0,1...),所以不需要重写Dataloadertrain_loader = DataLoader(train_dataset,batch_size=16,shuffle=True,num_workers=4)
valid_loader = DataLoader(valid_dataset,batch_size=16,shuffle=True,num_workers=4)
train_features_batch, train_labels_batch = next(iter(train_loader))print(train_features_batch.shape, train_labels_batch.shape)
print(train_features_batch[0])a
print(train_labels_batch)
import torchvision.models as models#修改最后一层
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, len(dataset.classes), device=device)
# 模型中的关键部分:
# model.features:# 这是模型的特征提取部分,通常包含卷积层和池化层。
# model.avgpool:# 模型中的平均池化层,用于对特征进行全局平均池化。
# model.classifier:# 这是模型的分类部分,通常包含全连接层。# model.fc:# 对于很多模型,model.fc 是最后一层的全连接层。
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=1e-4)
# 假设输入大小为 (batch_size, num_channels, height, width)
# 这里使用随机数据,你需要根据你的模型和数据进行适当的调整
batch_size = 1
num_channels = 3  # 通常是3,表示RGB图像
height, width = 224, 224  # 这可能需要根据你的数据集进行调整# 创建随机输入数据
random_input = torch.randn(batch_size, num_channels, height, width)# 将数据移动到设备(GPU或CPU)
random_input = random_input.to(device)# 模型推断with torch.no_grad():output = model(random_input)# 打印模型输出
print("Model Output Shape:", output.shape)
print("Model Output Values:", output)
# 训练
epochs=10
train_loss_all = [] #定义一个列表用于保存总的训练集loss,方便后续打印
val_loss_all =[]   #定义一个列表用于保存总的验证集loss,方便后续打印
best_loss = 1e10   #记录最佳的loss
for epoch in range(epochs):train_loss =0val_loss =0train_num=0val_num =0correct = 0model.train()loop = tqdm(train_loader)for idx,(image,label) in enumerate(loop):image = image.to(device)label = label.to(device)optimizer.zero_grad()pre_lab = torch.argmax(output,1)output = model(image)loss = loss_fn(output ,label)loss.backward()optimizer.step() #梯度更新train_loss +=loss.item()train_num +=1train_loss_all.append(train_loss / train_num)print('{} *****Train Loss:{:.4f}'.format(epoch,train_loss_all[-1]))with torch.no_grad():loop=tqdm(valid_loader)for idx,(image,label) in enumerate(loop):image = image.to(device)label = label.to(device)output = model(image)pre_lab = torch.argmax(output,1)loss = loss_fn(output,label)val_loss += loss.item()val_num +=1correct += (pre_lab == label.data).sum().item()correct /= len(valid_loader.dataset)val_loss_all.append(val_loss / val_num)print(f'{epoch} *****Valid Loss:{val_loss_all[-1]:.4f}  Accuracy={(100 * correct):>0.1f}%')##保存模型if val_loss_all[-1] < best_loss :best_loss = val_loss_all[-1]check_points = model.state_dict()torch.save(check_points, '/kaggle/working/BestSave.pt')
#可视化模型训练过程中的loss曲线
epochs = list(range(1, 11))  # 或者任何你实际的 epochs 数量plt.figure(figsize=(10,6))
plt.plot(epochs,train_loss_all,"ro-",label = "Train Loss")
plt.plot(epochs,val_loss_all,"bs-",label = "Valid Loss")
plt.legend()
plt.xlabel("epoch")
plt.ylabel("Loss")
plt.show()
#单张图像查看from PIL import Image # Image模块是在Python PIL图像处理中常见的模块,对图像进行基础操作的功能基本都包含于此模块内。#读取数据,用于后续制作数据集
tem_path = os.path.join(train_path,'Black-grass')# 使用glob列出train文件夹下的所有文件夹
fold = glob.glob(os.path.join(tem_path, '*'))print(fold[0])input_image = Image.open(fold[0])print(input_image.size)input_tensor = transforms(input_image).unsqueeze(0).to(device)  # 添加 batch 维度
print(input_tensor.shape)model = model.to(device)
model.load_state_dict(torch.load('/kaggle/working/BestSave.pt'))
input_tensor = input_tensor.to(device)
with torch.no_grad():output = model(input_tensor)
print(output)
print(torch.argmax(output,1))
#Dataloader默认是返回(输入数据,标签),但是测试集中没有标签,故重写一个Dataloaderclass TestDataset(Dataset):def __init__(self,test_path,transform=None):self.test_path = test_pathself.test_images = os.listdir(self.test_path)self.transform = transformdef __len__(self):return len(self.test_images)def __getitem__(self,idx):self.image_path = os.path.join(self.test_path,os.listdir(self.test_path)[idx])img = Image.open(self.image_path)if self.transform is not None:img = self.transform(img)return img
#单张图像进行验证
import glob
#读取数据,用于后续制作数据集
test_path = os.path.join(work_dir,'test')# 使用glob列出train文件夹下的所有文件夹
folders = glob.glob(os.path.join(test_path, '*'))
print(folders[:2])
print(folders[1])

from torchvision import transforms
test_transforms = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(transform_mean, transform_std)
])test_dataset = TestDataset(test_path, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False,num_workers=4)print('Test:', len(test_dataset), 'samples')
from tqdm import tqdm
labels = []
model = model.to(device)
model.load_state_dict(torch.load('/kaggle/working/BestSave.pt'))
model.eval()with torch.no_grad():loop = tqdm(test_loader)for idx ,(image)in enumerate(loop):image = image.to(device)output = model(image)preds = torch.argmax(output,1)labels.extend(preds.cpu().numpy().tolist())species = [dataset.classes[label] for label in labels]submission = pd.DataFrame({'file': os.listdir(test_path), 'species': species})
submission.to_csv('submission.csv', index=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/526941.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ArrayDeque集合源码分析

ArrayDeque集合源码分析 文章目录 ArrayDeque集合源码分析一、字段分析二、构造函数分析方法、方法分析四、总结 实现了 Deque&#xff0c;说面该数据结构一定是个双端队列&#xff0c;我们知道 LinkedList 也是双端队列&#xff0c;并且是用双向链表 存储结构的。而 ArrayDequ…

【Redis知识点总结】(二)——Redis高性能IO模型剖析

Redis知识点总结&#xff08;二&#xff09;——Redis高性能IO模型及其事件驱动框架剖析 IO多路复用传统的阻塞式IO同步非阻塞IOIO多路复用机制 Redis的IO模型Redis的事件驱动框架 IO多路复用 Redis的高性能的秘密&#xff0c;在于它底层使用了IO多路复用这种高性能的网络IO&a…

2024年装修新潮流,你get到了吗?福州中宅装饰,福州装修

在装修这个行业&#xff0c;每年都会出现一些新的设计理念和流行趋势&#xff0c;同时也存在一些传统的设计理念。今天&#xff0c;我们就来对比一下2024年装修设计的传统与新趋势。 传统设计理念 1. 落地电视柜 在过去&#xff0c;落地电视柜被认为是一种实用的设计&#xf…

React-路由小知识

1.默认路由 说明&#xff1a;当访问的是一级路由时&#xff0c;默认的二级路由组件可以得到渲染&#xff0c;只需要在二级路由的位置去掉path,设置index.属性为true。 2.404路由 说明&#xff1a;当浏览器输入ul的路径在整个路由配置中都找不到对应的pth,为了用户体验&#x…

深入理解React中的useReducer:管理复杂状态逻辑的利器

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

Coggle数据科学 | 小白学数据科学:20个技术和框架(建议收藏!)

本文来源公众号“Coggle数据科学”&#xff0c;仅用于学术分享&#xff0c;侵权删&#xff0c;干货满满。 原文链接&#xff1a;小白学数据科学&#xff1a;20个技术和框架 随着互联网使用率的增长&#xff0c;公司如何利用数据进行创新和获得竞争优势。截至2024年1月&#x…

Qt 实现诈金花的牌面值分析工具

诈金花是很多男人最爱的卡牌游戏 , 每当你拿到三张牌的时候, 生活重新充满了期待和鸟语花香. 那么我们如果判断手中的牌在所有可能出现的牌中占据的百分比位置呢. 这是最终效果: 这是更多的结果: 在此做些简单的说明: 炸弹(有些地方叫豹子) > 同花顺 > 同花 > 顺…

css--浮动

一. 浮动的简介 在最初&#xff0c;浮动是用来实现文字环绕图片效果的&#xff0c;现在浮动是主流的页面布局方式之一。 二. 元素浮动后的特点 &#x1f922;脱离文档流。&#x1f60a;不管浮动前是什么元素&#xff0c;浮动后&#xff1a;默认宽与高都是被内容撑开&#xff0…

[HackMyVM]靶场 Espo

kali:192.168.56.104 主机发现 arp-scan -l # arp-scan -l Interface: eth0, type: EN10MB, MAC: 00:0c:29:d2:e0:49, IPv4: 192.168.56.104 Starting arp-scan 1.10.0 with 256 hosts (https://github.com/royhills/arp-scan) 192.168.56.1 0a:00:27:00:00:05 (Un…

manjaro 安装 wps 教程

内核: Linux 6.6.16.2 wps-office版本&#xff1a; 11.10.11719-1 本文仅作为参考使用, 如果以上版本差别较大不建议参考 安装wps主体 yay -S wps-office 安装wps字体 &#xff08;如果下载未成功看下面的方法&#xff09; yay -S ttf-waps-fonts 安装wps中文语言 yay …

ZJUBCA研报分享 | 《BTC/USDT周内效应研究》

ZJUBCA研报分享 引言 2023 年 11 月 — 2024 年初&#xff0c;浙大链协顺利举办为期 6 周的浙大链协加密创投训练营 &#xff08;ZJUBCA Community Crypto VC Course&#xff09;。在本次训练营中&#xff0c;我们组织了投研比赛&#xff0c;鼓励学员分析感兴趣的 Web3 前沿话题…

kali当中不同的python版本切换(超简单)

kali当中本身就是自带两个python版本的 配置 update-alternatives --install /usr/bin/python python /usr/bin/python2 100 update-alternatives --install /usr/bin/python python /usr/bin/python3 150 切换版本 update-alternatives --config python 0 1 2编号选择一个即可…