美容行业网站建设方案,高质量的南昌网站建设,专做女鞋的网站代发广州,律师事务所网站建设策划方案更多Python学习内容#xff1a;ipengtao.com 大家好#xff0c;今天为大家分享一个无敌的 Python 库 - torchmetrics。 Github地址#xff1a;https://github.com/Lightning-AI/torchmetrics 在深度学习和机器学习项目中#xff0c;模型评估是一个至关重要的环节。为了准确… 更多Python学习内容ipengtao.com 大家好今天为大家分享一个无敌的 Python 库 - torchmetrics。 Github地址https://github.com/Lightning-AI/torchmetrics 在深度学习和机器学习项目中模型评估是一个至关重要的环节。为了准确地评估模型的性能开发者通常需要计算各种指标metrics如准确率、精确率、召回率、F1 分数等。torchmetrics 是一个用于 PyTorch 的开源库提供了一组方便且高效的评估指标计算工具。本文将详细介绍 torchmetrics 库包括其安装方法、主要特性、基本和高级功能以及实际应用场景帮助全面了解并掌握该库的使用。 安装 要使用 torchmetrics 库首先需要安装它。可以通过 pip 工具方便地进行安装。 以下是安装步骤 pip install torchmetrics 安装完成后可以通过导入 torchmetrics 库来验证是否安装成功 import torchmetrics
print(torchmetrics 库安装成功) 特性 广泛的指标支持提供多种评估指标包括分类、回归、图像处理和生成模型等领域的常用指标。模块化设计指标可以像模块一样轻松集成到 PyTorch Lightning 或任何 PyTorch 项目中。GPU 加速支持 GPU 加速能够高效处理大规模数据。易于扩展用户可以自定义指标并轻松集成到现有项目中。高效计算优化的计算方法确保在训练过程中实时计算指标性能开销最小。 基本功能 计算准确率 使用 torchmetrics 库可以方便地计算分类任务的准确率。 import torch
import torchmetrics# 创建 Accuracy 指标
accuracy torchmetrics.Accuracy()# 模拟预测和真实标签
preds torch.tensor([0, 2, 1, 3])
target torch.tensor([0, 1, 2, 3])# 计算准确率
acc accuracy(preds, target)
print(f准确率{acc}) 计算精确率和召回率 torchmetrics 库可以计算分类任务的精确率和召回率。 import torch
import torchmetrics# 创建 Precision 和 Recall 指标
precision torchmetrics.Precision(num_classes4)
recall torchmetrics.Recall(num_classes4)# 模拟预测和真实标签
preds torch.tensor([0, 2, 1, 3])
target torch.tensor([0, 1, 2, 3])# 计算精确率和召回率
prec precision(preds, target)
rec recall(preds, target)
print(f精确率{prec})
print(f召回率{rec}) 计算 F1 分数 torchmetrics 库还可以计算分类任务的 F1 分数。 import torch
import torchmetrics# 创建 F1 指标
f1 torchmetrics.F1(num_classes4)# 模拟预测和真实标签
preds torch.tensor([0, 2, 1, 3])
target torch.tensor([0, 1, 2, 3])# 计算 F1 分数
f1_score f1(preds, target)
print(fF1 分数{f1_score}) 高级功能 自定义指标 torchmetrics 库允许用户自定义指标以满足特定需求。 import torch
import torchmetricsclass CustomMetric(torchmetrics.Metric):def __init__(self):super().__init__()self.add_state(sum, defaulttorch.tensor(0), dist_reduce_fxsum)self.add_state(count, defaulttorch.tensor(0), dist_reduce_fxsum)def update(self, preds: torch.Tensor, target: torch.Tensor):self.sum torch.sum(preds target)self.count target.numel()def compute(self):return self.sum.float() / self.count# 创建自定义指标
custom_metric CustomMetric()# 模拟预测和真实标签
preds torch.tensor([0, 2, 1, 3])
target torch.tensor([0, 1, 2, 3])# 计算自定义指标
result custom_metric(preds, target)
print(f自定义指标结果{result}) 与 PyTorch Lightning 集成 torchmetrics 库可以无缝集成到 PyTorch Lightning 中简化指标计算流程。 import torch
import torchmetrics
import pytorch_lightning as pl
from torch import nnclass LitModel(pl.LightningModule):def __init__(self):super().__init__()self.model nn.Linear(10, 4)self.accuracy torchmetrics.Accuracy()def forward(self, x):return self.model(x)def training_step(self, batch, batch_idx):x, y batchpreds self(x)loss nn.functional.cross_entropy(preds, y)acc self.accuracy(preds, y)self.log(train_acc, acc)return lossdef configure_optimizers(self):return torch.optim.Adam(self.parameters(), lr0.001)# 示例数据
train_data torch.utils.data.TensorDataset(torch.randn(100, 10), torch.randint(0, 4, (100,)))
train_loader torch.utils.data.DataLoader(train_data, batch_size32)# 训练模型
model LitModel()
trainer pl.Trainer(max_epochs5)
trainer.fit(model, train_loader) GPU 加速 torchmetrics 库支持 GPU 加速可以在 GPU 上高效地计算指标。 import torch
import torchmetrics# 创建 Accuracy 指标并移动到 GPU
accuracy torchmetrics.Accuracy().cuda()# 模拟预测和真实标签并移动到 GPU
preds torch.tensor([0, 2, 1, 3]).cuda()
target torch.tensor([0, 1, 2, 3]).cuda()# 计算准确率
acc accuracy(preds, target)
print(f准确率{acc}) 实际应用场景 图像分类任务中的指标计算 在图像分类任务中需要计算各种评估指标如准确率、精确率、召回率等。 import torch
import torchmetrics
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader# 加载数据
transform transforms.Compose([transforms.ToTensor()])
train_data CIFAR10(root./data, trainTrue, downloadTrue, transformtransform)
train_loader DataLoader(train_data, batch_size32, shuffleTrue)# 创建模型和指标
model models.resnet18(num_classes10)
accuracy torchmetrics.Accuracy()# 训练模型并计算准确率
for inputs, targets in train_loader:outputs model(inputs)acc accuracy(outputs, targets)print(f批次准确率{acc}) 文本分类任务中的指标计算 在文本分类任务中需要计算评估指标如 F1 分数。 import torch
import torchmetrics
from transformers import BertTokenizer, BertForSequenceClassification# 加载模型和分词器
tokenizer BertTokenizer.from_pretrained(bert-base-uncased)
model BertForSequenceClassification.from_pretrained(bert-base-uncased)# 示例数据
texts [I love this!, This is bad.]
labels torch.tensor([1, 0])# 预处理数据
inputs tokenizer(texts, return_tensorspt, paddingTrue, truncationTrue)
outputs model(**inputs)# 创建 F1 指标
f1 torchmetrics.F1(num_classes2)# 计算 F1 分数
preds torch.argmax(outputs.logits, dim1)
f1_score f1(preds, labels)
print(fF1 分数{f1_score}) 生成对抗网络GAN中的指标计算 在生成对抗网络GAN的训练中需要计算生成图片的质量指标如 Frechet Inception DistanceFID。 import torch
import torchmetrics
from torchvision.models import inception_v3
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, TensorDataset# 创建生成对抗网络GAN的生成器模型
class Generator(torch.nn.Module):def __init__(self):super(Generator, self).__init__()self.fc torch.nn.Linear(100, 128 * 7 * 7)self.deconv torch.nn.Sequential(torch.nn.ConvTranspose2d(128, 64, 4, stride2, padding1),torch.nn.BatchNorm2d(64),torch.nn.ReLU(True),torch.nn.ConvTranspose2d(64, 1, 4, stride2, padding1),torch.nn.Tanh())def forward(self, x):x self.fc(x).view(-1, 128, 7, 7)return self.deconv(x)# 创建生成器模型
generator Generator()# 创建 FID 指标
fid torchmetrics.image.fid.FrechetInceptionDistance(feature64)# 模拟生成图片和真实图片
latent_vectors torch.randn(100, 100)
generated_images generator(latent_vectors)
real_images torch.randn(100, 1, 28, 28)# 转换图片为 Inception V3 输入格式
transform transforms.Compose([transforms.Resize((299, 299)),transforms.Normalize(mean[0.5], std[0.5])
])
generated_images transform(generated_images)
real_images transform(real_images)# 创建 DataLoader
generated_loader DataLoader(TensorDataset(generated_images), batch_size32)
real_loader DataLoader(TensorDataset(real_images), batch_size32)# 计算 FID
for gen_batch, real_batch in zip(generated_loader, real_loader):fid.update(real_batch[0], gen_batch[0])fid_value fid.compute()
print(fFID 分数{fid_value}) 总结 torchmetrics 库是一个功能强大且易于使用的评估指标计算工具能够帮助开发者在深度学习和机器学习项目中高效地计算各种评估指标。通过支持广泛的指标、多种计算模式、GPU 加速和自定义扩展torchmetrics 库能够满足各种复杂的评估需求。本文详细介绍了 torchmetrics 库的安装方法、主要特性、基本和高级功能以及实际应用场景。希望本文能帮助大家全面掌握 torchmetrics 库的使用并在实际项目中发挥其优势。 如果你觉得文章还不错请大家 点赞、分享、留言 下因为这将是我持续输出更多优质文章的最强动力 如果想要系统学习Python、Python问题咨询或者考虑做一些工作以外的副业都可以扫描二维码添加微信围观朋友圈一起交流学习。 我们还为大家准备了Python资料和副业项目合集感兴趣的小伙伴快来找我领取一起交流学习哦 往期推荐 历时一个月整理的 Python 爬虫学习手册全集PDF免费开放下载 Python基础学习常见的100个问题.pdf附答案 学习 数据结构与算法这是我见过最友好的教程(PDF免费下载) Python办公自动化完全指南(免费PDF) Python Web 开发常见的100个问题.PDF 肝了一周整理了Python 从0到1学习路线附思维导图和PDF下载