【PyTorch】(二)加载数据集,秘密结社鹰之爪(pytorch加载自己的数据集)
0evadmin
编程语言
3
文件名:【PyTorch】(二)加载数据集,秘密结社鹰之爪
【PyTorch】(二)加载数据集
文章目录 1. 通用方法 1. 通用方法 创建数据集 主要是将数据集读入内存,并用Dataset类封装。直接继承Dataset类的自定义数据集必须要重写__getitem__方法,用于根据索引获得相应样本数据。必要时还可以重写__len__方法,用于返回数据集的大小。加载数据集 使用DataLoader类将Dataset封装的数据集分成批次并进行迭代,以便于模型训练。DataLoader常用参数如下: dataset 要加载的数据集。batch_size 每个数据批次中包含的样本数。默认为1。shuffle 是否打乱数据集。默认为False。num_workers 使用几个进程来加载数据。默认为0,即在主进程中加载数据。drop_last 当数据集样本数不能被batch_size整除时,是否舍弃最后一个不完整的batch。默认为False。 将数据转移到GPU 可以使用方法:变量.to(device)。可以使用方法:变量.cuda(0)。 import torchfrom torch.utils.data import Dataset, DataLoaderimport numpy as npclass BostonHousingDataset(Dataset):"""定义波士顿房价数据集"""def __init__(self):self.data = np.load('../dataset/boston_housing/boston_housing.npz')def __getitem__(self, index):return self.data['x'][index], self.data['y'][index]def __len__(self):return self.data['x'].shape[0]dataset = BostonHousingDataset()dataloader = DataLoader(dataset, batch_size=16, shuffle=True)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")for X,y in dataloader:# 将数据转移到GPUX = X.to(device)y = y.to(device)# 也可以X = X.cuda(0)y = y.cuda(0)
同类推荐