目录

注意力

2024年6月更新:移除DataPipes和DataLoader V2

我们正在重新聚焦torchdata仓库,使其成为torch.utils.data.DataLoader的迭代增强。我们不计划继续开发或维护[DataPipes]和[DataLoaderV2]解决方案,并且它们将从torchdata仓库中移除。我们还将重新审视pytorch/pytorch中的DataPipes引用。在发布torchdata==0.8.0(2024年7月)版本中,它们将被标记为已弃用,并在0.10.0(2024年末)版本中被删除。现有用户建议固定到torchdata<=0.9.0或更早版本,直到他们能够迁移为止。后续版本将不再包含DataPipes或DataLoaderV2。如果您有任何建议或意见,请通过此问题反馈。

状态型数据加载器

StatefulDataLoader 是一个可以替换 torch.utils.data.DataLoader 的降级替代品,它提供了 state_dict / load_state_dict 种方法来处理中间epoch的检查点,这些方法在从数据加载器请求的上一个/下一个迭代器上操作(分别)。

默认情况下,状态包括生成的批次数量,并使用此来简单地向前推进采样器(映射式)或数据集(迭代式)。然而,如果采样器和/或数据集包含 state_dict / load_state_dict 种方法,则会在其自身的 state_dict / load_state_dict 次调用中调用它们。在幕后,StatefulDataLoader 处理多进程工作者之间的状态聚合和分布(但不跨队列)。

class torchdata.stateful_dataloader.StatefulDataLoader(dataset: Dataset[_T_co], batch_size: Optional[int] = 1, shuffle: Optional[bool] = None, sampler: Optional[Union[Sampler, Iterable]] = None, batch_sampler: Optional[Union[Sampler[List], Iterable[List]]] = None, num_workers: int = 0, collate_fn: Optional[Callable[[List[_T]], Any]] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable[[int], None]] = None, multiprocessing_context=None, generator=None, *, prefetch_factor: Optional[int] = None, persistent_workers: bool = False, pin_memory_device: str = '', snapshot_every_n_steps: Optional[int] = 1)

这是对 torch.utils.data.DataLoader 的直接替换, 实现了 state_dict 和 load_state_dict 方法,支持中途检查点保存。

所有参数都与 torch.utils.data.DataLoader 相同,新增一个 kwarg:snapshot_every_n_steps

Parameters:
  • dataset (Dataset) – 从其中加载数据的数据集。

  • 批量大小 (整数, 可选) – 要加载的样本数量(默认:1)。

  • 打乱 (bool, optional) – 设置为True以在每个epoch重新打乱数据(默认:False)。

  • 采样器 (采样器可迭代, 可选) – 定义从数据集中抽样的策略。可以是任何实现Iterable__len__。如果指定,shuffle不能被指定。

  • batch_sampler (Sampler or Iterable, optional) – 像 sampler,但 返回一个批次的索引一次。与 batch_size, shuffle, sampler, 和 drop_last 互斥。

  • num_workers (int, optional) – 使用多少个子进程来加载数据。0 表示数据将在主进程中加载。(默认:0)

  • collate_fn (Callable, 可选) – 合并一个样本列表以形成一个 张量的小批量。当使用来自映射式数据集的批量加载时使用。

  • pin_memory (bool, optional) – 如果 True,数据加载器将在返回它们之前将张量复制到设备/CUDA绑定内存中。如果你的数据元素是自定义类型,或者你的 collate_fn 返回的是一个自定义类型的批量,参见下面的示例。

  • drop_last (bool, optional) – 设置为True以丢弃最后一个不完整的批次, 如果数据集大小不能被批量大小整除。如果False且 数据集大小不能被批量大小整除,则最后一个批次将较小。(默认: False)

  • timeout (数字类型, 可选) – 如果为正数,则是从工作进程中收集一个批次数据的超时值。该值应始终为非负数。(默认: 0)

  • worker_init_fn (Callable, optional) – 如果不是 None,则在每个 工作子进程上会使用工作器 ID(一个在 [0, num_workers - 1] 中的 int)作为 输入,在设置随机种子之后和加载数据之前调用。 (默认值: None)

  • multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – 如果 None,将使用操作系统默认的 multiprocessing context。 (默认: None)

  • 生成器 (torch.Generator, 可选) – 如果不是None,这个RNG将被RandomSampler用于生成随机索引和多进程以生成base_seed为工作者。 (默认:None)

  • prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).

  • 持久化工作者 (bool, optional) – 如果 True,数据加载器在一次消耗完数据集后不会关闭工作进程。这允许保持工作者 Dataset 个实例存活。(默认:False)

  • pin_memory_device (str, optional) – the device to pin_memory to if pin_memory is True.

  • 每几步保存一次快照 (整数, 可选) – 定义状态从数据加载器工作者转移到数据加载器的频率。默认情况下,它被设置为1,即状态每一步都转移。如果状态较大,可以增加这个值(并最好将其设置为训练检查点频率)以减少每步转移状态的开销。

警告

如果使用spawn启动方法,worker_init_fn 不能是不可序列化的对象,例如lambda函数。有关PyTorch中多进程处理的更多详细信息,请参阅 multiprocessing-best-practices

警告

len(dataloader) 启发式算法基于所使用的采样器的长度。 当 dataset 是一个 IterableDataset 时, 它会根据 len(dataset) / batch_size 返回一个估计值,并根据 drop_last 进行适当的舍入,而不论多进程加载配置如何。这代表了 PyTorch 能做出的最佳猜测,因为 PyTorch 相信用户编写的 dataset 代码能够正确处理多进程加载以避免重复数据。

然而,如果分片导致多个工作器拥有不完整的最后一批数据, 这个估计仍然可能不准确,因为 (1) 一个原本完整的批次可以 被拆分成多个批次,以及 (2) 当设置 drop_last 时,可能会丢弃超过一个批次的数据。不幸的是,PyTorch 无法在一般情况下检测到这些情况。

参见数据集类型,了解更多关于这两种数据集的信息以及 IterableDataset 如何与 多进程数据加载 交互。

警告

参见可重复性,以及数据加载器工作线程随机种子,以及 数据加载随机性说明以了解与随机种子相关的问题。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源