【Pytorch库】自定义数据集相关的类
- 开源代码
- 2025-09-03 22:09:02

torch.utils.data.Dataset 类torch.utils.data.DataLoader 类自定义数据集示例1. 自定义 Dataset 类2. 在其他 .py 文件中引用和使用该自定义 Dataset torch_geometric.data.Dataset 类torch_geometric.data.Dataset VS torch.utils.data.Dataset
详细信息,参阅 torch.utils.data 文档页面 写得很棒的文章:PyTorch加载自己的数据集
在 PyTorch 数据加载工具的核心是 torch.utils.data.DataLoader 类。它表示一个 Python 可迭代对象,用于遍历数据集,并支持以下功能:
映射式和可迭代式数据集自定义数据加载顺序自动批量处理单进程和多进程数据加载自动内存固定这些选项由 DataLoader 构造函数的 参数 配置,如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False)DataLoader 构造函数中最重要的参数是 dataset,它表示一个数据集对象,用于加载数据。
PyTorch 支持两种不同类型的数据集:
映射式数据集(map-style datasets)可迭代式数据集(iterable-style datasets) torch.utils.data.Dataset 类torch.utils.data.Dataset 是 PyTorch 库中的一个标准类,它是 用于自定义数据集的基类。这个类是所有数据集的基础,适用于各种类型的数据加载,包括图像、文本、时间序列等。
作用:Dataset 类的作用是提供一个接口,用于 加载和处理原始数据。它是 PyTorch 的数据加载机制的一部分,通常 与 DataLoader 配合使用。
主要方法:
__len__():返回数据集的大小(即样本的数量)。__getitem__():根据索引返回单个数据样本。DataLoader 会使用该方法来迭代数据。torch.utils.data.Dataset 类 Pytorch 官网文档:
torch.utils.data.Dataset 类是一个抽象类,用于表示一个数据集。
所有表示 从键到数据样本映射的数据集 都应当继承此类。所有子类都应当 重写 __getitem__() 方法,用于 根据给定的键获取数据样本。子类还可以选择性地重写 __len__() 方法,许多 Sampler 实现和 DataLoader 的默认选项都期望返回数据集的大小。子类还可以 选择性地实现 __getitems__() 方法,以加速 批量样本的加载。该方法接受一个包含批次样本索引的列表,并返回一个包含样本的列表。注意:默认情况下,DataLoader 会构造一个索引采样器,该采样器返回整数索引。为了使其与使用非整数索引/键的映射风格数据集兼容,必须提供自定义的采样器。
# Pytorch torch.utils.data.Dataset 类的源码 class Dataset(Generic[_T_co]): r"""An abstract class representing a :class:`Dataset`. All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and the default options of :class:`~torch.utils.data.DataLoader`. Subclasses could also optionally implement :meth:`__getitems__`, for speedup batched samples loading. This method accepts list of indices of samples of batch and returns list of samples. .. note:: :class:`~torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided. """ def __getitem__(self, index) -> _T_co: raise NotImplementedError("Subclasses of Dataset should implement __getitem__.") # def __getitems__(self, indices: List) -> List[_T_co]: # Not implemented to prevent false-positives in fetcher check in # torch.utils.data._utils.fetch._MapDatasetFetcher def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]": return ConcatDataset([self, other]) # No `def __len__(self)` default? # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] # in pytorch/torch/utils/data/sampler.py torch.utils.data.DataLoader 类数据加载器(DataLoader)将数据集(dataset)和采样器(sampler)结合起来,并提供对给定数据集的可迭代访问。
DataLoader 支持单进程或多进程加载的 映射式 和 可迭代式 数据集,支持自定义加载顺序、可选的自动批量处理(拼接)以及内存固定(memory pinning)。
这些参数使得 DataLoader 能够在处理数据时非常灵活,支持不同的 数据加载策略、并行处理方式、内存管理 等。通过合理设置这些参数,可以在训练神经网络时实现高效的数据加载和处理。
dataset (Dataset) – 从中加载数据的数据集。
batch_size (int, 可选) – 每个批次加载多少个样本(默认值:1)。
shuffle (bool, 可选) – 设置为 True 时,每个周期(epoch)都会重新打乱数据(默认值:False)。
sampler (Sampler 或 Iterable, 可选) – 定义从数据集中抽取样本的策略。可以是任何实现了 __len__ 的 Iterable。如果指定了 sampler,则不能同时指定 shuffle。
batch_sampler (Sampler 或 Iterable, 可选) – 类似于 sampler,但一次返回一个批次的索引。与 batch_size、shuffle、sampler 和 drop_last 互斥。
num_workers (int, 可选) – 用于数据加载的子进程数量。设置为 0 时,数据将在主进程中加载。(默认值:0)
collate_fn (Callable, 可选) – 将一个样本列表合并为一个 mini-batch 的 Tensor(s)。在使用映射式数据集进行批量加载时使用。
pin_memory (bool, 可选) – 如果为 True,数据加载器将在返回数据之前将 Tensors 复制到设备/CUDA 固定内存中。如果你的数据元素是自定义类型,或者 collate_fn 返回的是自定义类型的批次,请参见下面的示例。
drop_last (bool, 可选) – 设置为 True 时,如果数据集大小不能被批次大小整除,则丢弃最后一个不完整的批次。如果设置为 False 且数据集大小不能被批次大小整除,那么最后一个批次将会更小。(默认值:False)
timeout (数值型, 可选) – 如果为正,则表示从工作进程收集批次的超时值。应该始终是非负值。(默认值:0)
worker_init_fn (Callable, 可选) – 如果不为 None,则会在每个工作进程中调用该函数,输入为工作进程的 ID(一个整数,在 [0, num_workers - 1] 范围内),在种子设置后、数据加载之前调用。(默认值:None)
multiprocessing_context (str 或 multiprocessing.context.BaseContext, 可选) – 如果为 None,则使用操作系统的默认多进程上下文。(默认值:None)
generator (torch.Generator, 可选) – 如果不为 None,则该随机数生成器将由 RandomSampler 用于生成随机索引,并且在多进程中用于生成工作进程的基准种子。(默认值:None)
prefetch_factor (int, 可选,仅限关键字参数) – 每个工作进程提前加载的批次数。设置为 2 时,所有工作进程将提前加载总计 2 * num_workers 个批次。(默认值取决于 num_workers 设置的值。如果 num_workers=0,默认值为 None;如果 num_workers > 0,默认值为 2)
persistent_workers (bool, 可选) – 如果为 True,数据加载器将在数据集被消费一次后不会关闭工作进程。这样可以保持工作进程中的数据集实例处于活动状态。(默认值:False)
pin_memory_device (str, 可选) – 如果 pin_memory 为 True,则为内存固定的设备指定设备名称。
in_order (bool, 可选) – 如果为 False,数据加载器将不强制按照先进先出(FIFO)的顺序返回批次。仅在 num_workers > 0 时生效。(默认值:True)
注意:
如果使用了 spawn 启动方法,则 worker_init_fn 不能是不可序列化的对象,例如 Lambda 函数。有关 PyTorch 中多进程的更多细节,请参阅“多进程最佳实践”。len(dataloader) 的启发式方法基于所使用的采样器的长度。当数据集是一个 IterableDataset 时,它会根据 len(dataset) / batch_size 来返回一个估计值,并根据 drop_last 设置进行适当的四舍五入,而不管多进程加载配置如何。这是 PyTorch 能做出的 最佳估算,因为 PyTorch 相信用户的数据集代码能够正确处理多进程加载,避免重复数据。 然而,如果数据分片导致多个工作进程的最后一个批次不完整,那么这个估计仍然可能不准确,因为 (1) 一个原本完整的批次可能会被分成多个批次,(2) 当 drop_last 设置为 True 时,可能会丢失多个批次的样本。不幸的是,PyTorch 通常无法检测到这种情况。 有关这两种数据集类型以及 IterableDataset 如何与多进程数据加载交互的更多细节,请参阅“数据集类型”。有关随机种子相关的问题,请参阅“可重现性”,“我的数据加载器工作进程返回相同的随机数”以及“多进程数据加载中的随机性”相关说明。将 in_order 设置为 False 可能会影响可重现性,并且在数据不平衡的情况下,可能会导致传递给训练器的数据分布偏斜。 自定义数据集示例 1. 自定义 Dataset 类首先,需要定义一个自定义的 Dataset 类,继承自 torch.utils.data.Dataset。
# dataset.py import torch from torch.utils.data import Dataset import os from PIL import Image from torchvision import transforms class CustomDataset(Dataset): def __init__(self, data_dir, mode='train', transform=None): """ :param data_dir: 数据集根目录 :param mode: 'train' 或 'test',指定加载训练集或测试集 :param transform: 数据转换(如图像缩放,裁剪,归一化等) """ self.data_dir = data_dir self.mode = mode self.transform = transform # 假设数据集结构是这样的: # data_dir/ # train/ # class1/ # class2/ # test/ # class1/ # class2/ self.image_paths = [] self.labels = [] # 加载数据 self._load_data() def _load_data(self): """根据模式加载训练集或测试集数据""" # 设置数据目录 data_folder = os.path.join(self.data_dir, self.mode) for label, class_name in enumerate(os.listdir(data_folder)): class_folder = os.path.join(data_folder, class_name) if os.path.isdir(class_folder): for img_name in os.listdir(class_folder): img_path = os.path.join(class_folder, img_name) self.image_paths.append(img_path) self.labels.append(label) def __len__(self): """返回数据集的大小""" return len(self.image_paths) def __getitem__(self, idx): """根据索引返回数据和标签""" img_path = self.image_paths[idx] label = self.labels[idx] # 读取图像 image = Image.open(img_path).convert('RGB') # 应用转换 if self.transform: image = self.transform(image) return image, label在自定义数据集 CustomDataset 中,通过 mode='train' 或 mode='test' 来决定加载训练集或测试集的数据。这个 mode 参数可以在创建数据集时传入:
mode='train' 时,加载训练集的数据。mode='test' 时,加载测试集的数据。该示例 假设数据集存储在 data_dir/train 和 data_dir/test 文件夹下,并按类存放在子文件夹中。可以根据自己的数据存储结构修改 _load_data() 方法。
2. 在其他 .py 文件中引用和使用该自定义 Dataset在其他 .py 文件中,可以引用该数据集类,并根据需要加载训练集或测试集。还可以传递数据增强或转换操作,例如使用 torchvision.transforms 来进行图像处理。
# main.py import torch from torch.utils.data import DataLoader from dataset import CustomDataset from torchvision import transforms # 定义数据增强和预处理操作 transform = transforms.Compose([ transforms.Resize((128, 128)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # 创建训练集和测试集 train_dataset = CustomDataset(data_dir='./data', mode='train', transform=transform) test_dataset = CustomDataset(data_dir='./data', mode='test', transform=transform) # 使用 DataLoader 加载数据集 train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # 使用训练集 for images, labels in train_loader: print(images.shape, labels.shape) # 输出每个批次的图像和标签的大小 # 在此进行训练 torch_geometric.data.Dataset 类torch_geometric.data.Dataset 类官方文档
在 PyTorch 中,torch.utils.data.Dataset 和 torch_geometric.data.Dataset 都是用来表示数据集的基类,但它们的作用、设计和用途有一些显著的区别,特别是在图神经网络(GNN)方面,
torch.utils.data.Dataset 更 通用,适用于 各种类型的数据集。torch_geometric.data.Dataset 是 专门为处理图数据(如图结构数据、边、节点特征等)而设计的,torch_geometric.data.Dataset 是 PyTorch Geometric(一个专门为图神经网络设计的扩展库)中的数据集类,继承自 torch.utils.data.Dataset,专门用于处理图数据,如图结构数据、节点特征 node_features、边特征 edge_index 等。
主要方法:
__len__():返回数据集中的图的数量。get():返回单个图的数据。get() 通常是实现单个数据项的加载过程,返回一个 Data 对象,其中包含图的结构信息(如 edge_index)、节点特征(如 x)等。用途:专门用于图神经网络(GNN)任务,适合处理图结构数据,例如社交网络、分子结构、物理网络等。
torch_geometric.data.Dataset VS torch.utils.data.Dataset主要区别:
特性torch.utils.data.Datasettorch_geometric.data.Dataset设计目的用于处理一般的数据集(如图像、文本、时间序列等)。专门为图神经网络设计,处理图结构数据(如图、边、节点特征)。继承关系基本类,PyTorch 数据加载的基类。继承自 torch.utils.data.Dataset,扩展为图数据处理。数据存储存储一般的数据(如图像、文本数据等)。存储图数据结构,包括 edge_index、x(节点特征)等。核心方法__getitem__() 返回单个数据项,通常是一个样本。get() 返回一个图数据对象,通常是 Data。数据类型适用于任何类型的数据集。适用于图结构数据集。多进程支持支持多进程数据加载(通过 num_workers)。同样支持多进程数据加载,但主要针对图数据。批处理支持支持自动批处理(batching),使用 DataLoader。也支持批处理,并且能够自动处理图结构的数据。图数据处理没有内建的图数据处理支持。内建支持图数据结构,如 edge_index、x(节点特征)、y(标签)。关系:
torch_geometric.data.Dataset 是 torch.utils.data.Dataset 的一个扩展,专门为图数据设计。torch_geometric.data.Dataset 在其基础上提供了对图数据的支持,包括图的边结构、节点特征等,使得 PyTorch Geometric 更适用于 图神经网络(GNN) 等图学习任务。【Pytorch库】自定义数据集相关的类由讯客互联开源代码栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“【Pytorch库】自定义数据集相关的类”