当前位置:首页 >> 编程语言 >> 【PyTorch】数据集,小米2手机

【PyTorch】数据集,小米2手机

0evadmin 编程语言 3
文件名:【PyTorch】数据集,小米2手机 【PyTorch】数据集

文章目录 1. 创建数据集1.1. 直接继承Dataset类1.2. 使用TensorDataset类 2. 数据集的划分3. 加载数据集4. 将数据转移到GPU

1. 创建数据集

主要是将数据集读入内存,并用Dataset类封装。

1.1. 直接继承Dataset类

必须要重写__getitem__方法,用于根据索引获得相应样本数据。必要时还可以重写__len__方法,用于返回数据集的大小。

from torch.utils.data import Datasetclass BostonHousingDataset(Dataset):"""定义波士顿房价数据集"""def __init__(self):self.data = np.load('../dataset/boston_housing/boston_housing.npz')def __getitem__(self, index):return self.data['x'][index], self.data['y'][index]def __len__(self):return self.data['x'].shape[0] 1.2. 使用TensorDataset类

将多个张量组合成一个数据集,要保证所有张量的第一个维度相等,保证每批样本数据格式相同。

import torchfrom torch.utils.data import TensorDatasetdata = np.load('../dataset/boston_housing/boston_housing.npz')X = torch.tensor(data['x'])y = torch.tensor(data['y'])dataset = TensorDataset(X, y) 2. 数据集的划分

数据集可以划分为训练集、验证集和测试集。

训练集:用于模型拟合的数据样本集合。验证集:通常被用来调整模型的参数,以找出效果最佳的模型。测试集:用于训练好的模型性能评估的数据样本集合。 from torch.utils.data import random_splittrain_size = int(0.8 * len(dataset))test_size = len(dataset) - train_sizetrain_dataset, test_dataset = random_split(dataset, [train_size, test_size]) 3. 加载数据集

使用DataLoader类将Dataset封装的数据集分成批次并进行迭代,以便于模型训练。DataLoader常用参数如下:

dataset 要加载的数据集。batch_size 每个数据批次中包含的样本数。默认为1。shuffle 是否打乱数据集。默认为False。num_workers 使用几个进程来加载数据。默认为0,即在主进程中加载数据。drop_last 当数据集样本数不能被batch_size整除时,是否舍弃最后一个不完整的batch。默认为False。 from torch.utils.data import DataLoaderdataloader = DataLoader(dataset, batch_size=16, shuffle=True) 4. 将数据转移到GPU

一般在要运算时才将数据转移到GPU,有以下两种方法:

var.to(device)var.cuda() import torchdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")for X,y in dataloader:# 将数据转移到GPUX = X.to(device)y = y.to(device)# 也可以X = X.cuda()y = y.cuda()
协助本站SEO优化一下,谢谢!
关键词不能为空
同类推荐
«    2025年12月    »
1234567
891011121314
15161718192021
22232425262728
293031
控制面板
您好,欢迎到访网站!
  查看权限
网站分类
搜索
最新留言
文章归档
网站收藏
友情链接