注意力
2024年6月更新:移除DataPipes和DataLoader V2
我们正在重新聚焦torchdata仓库,使其成为torch.utils.data.DataLoader的迭代增强。我们不计划继续开发或维护[DataPipes]和[DataLoaderV2]解决方案,并且它们将从torchdata仓库中移除。我们还将重新审视pytorch/pytorch中的DataPipes引用。在发布torchdata==0.8.0(2024年7月)版本中,它们将被标记为已弃用,并在0.10.0(2024年末)版本中被删除。现有用户建议固定到torchdata<=0.9.0或更早版本,直到他们能够迁移为止。后续版本将不再包含DataPipes或DataLoaderV2。如果您有任何建议或意见,请通过此问题反馈。
状态型数据加载器¶
StatefulDataLoader 是一个可以替换 torch.utils.data.DataLoader 的降级替代品,它提供了 state_dict / load_state_dict 种方法来处理中间epoch的检查点,这些方法在从数据加载器请求的上一个/下一个迭代器上操作(分别)。
默认情况下,状态包括生成的批次数量,并使用此来简单地向前推进采样器(映射式)或数据集(迭代式)。然而,如果采样器和/或数据集包含 state_dict / load_state_dict 种方法,则会在其自身的 state_dict / load_state_dict 次调用中调用它们。在幕后,StatefulDataLoader 处理多进程工作者之间的状态聚合和分布(但不跨队列)。
- class torchdata.stateful_dataloader.StatefulDataLoader(dataset: Dataset[_T_co], batch_size: Optional[int] = 1, shuffle: Optional[bool] = None, sampler: Optional[Union[Sampler, Iterable]] = None, batch_sampler: Optional[Union[Sampler[List], Iterable[List]]] = None, num_workers: int = 0, collate_fn: Optional[Callable[[List[_T]], Any]] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable[[int], None]] = None, multiprocessing_context=None, generator=None, *, prefetch_factor: Optional[int] = None, persistent_workers: bool = False, pin_memory_device: str = '', snapshot_every_n_steps: Optional[int] = 1)¶
这是对
torch.utils.data.DataLoader的直接替换, 实现了 state_dict 和 load_state_dict 方法,支持中途检查点保存。所有参数都与
torch.utils.data.DataLoader相同,新增一个 kwarg:snapshot_every_n_steps。- Parameters:
dataset (Dataset) – 从其中加载数据的数据集。
批量大小 (整数, 可选) – 要加载的样本数量(默认:
1)。打乱 (bool, optional) – 设置为
True以在每个epoch重新打乱数据(默认:False)。采样器 (采样器 或 可迭代, 可选) – 定义从数据集中抽样的策略。可以是任何实现
Iterable的__len__。如果指定,shuffle不能被指定。batch_sampler (Sampler or Iterable, optional) – 像
sampler,但 返回一个批次的索引一次。与batch_size,shuffle,sampler, 和drop_last互斥。num_workers (int, optional) – 使用多少个子进程来加载数据。
0表示数据将在主进程中加载。(默认:0)collate_fn (Callable, 可选) – 合并一个样本列表以形成一个 张量的小批量。当使用来自映射式数据集的批量加载时使用。
pin_memory (bool, optional) – 如果
True,数据加载器将在返回它们之前将张量复制到设备/CUDA绑定内存中。如果你的数据元素是自定义类型,或者你的collate_fn返回的是一个自定义类型的批量,参见下面的示例。drop_last (bool, optional) – 设置为
True以丢弃最后一个不完整的批次, 如果数据集大小不能被批量大小整除。如果False且 数据集大小不能被批量大小整除,则最后一个批次将较小。(默认:False)timeout (数字类型, 可选) – 如果为正数,则是从工作进程中收集一个批次数据的超时值。该值应始终为非负数。(默认:
0)worker_init_fn (Callable, optional) – 如果不是
None,则在每个 工作子进程上会使用工作器 ID(一个在[0, num_workers - 1]中的 int)作为 输入,在设置随机种子之后和加载数据之前调用。 (默认值:None)multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – 如果
None,将使用操作系统默认的 multiprocessing context。 (默认:None)生成器 (torch.Generator, 可选) – 如果不是
None,这个RNG将被RandomSampler用于生成随机索引和多进程以生成base_seed为工作者。 (默认:None)prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker.
2means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default isNone. Otherwise, if value ofnum_workers > 0default is2).持久化工作者 (bool, optional) – 如果
True,数据加载器在一次消耗完数据集后不会关闭工作进程。这允许保持工作者 Dataset 个实例存活。(默认:False)pin_memory_device (str, optional) – the device to
pin_memoryto ifpin_memoryisTrue.每几步保存一次快照 (整数, 可选) – 定义状态从数据加载器工作者转移到数据加载器的频率。默认情况下,它被设置为
1,即状态每一步都转移。如果状态较大,可以增加这个值(并最好将其设置为训练检查点频率)以减少每步转移状态的开销。
警告
如果使用
spawn启动方法,worker_init_fn不能是不可序列化的对象,例如lambda函数。有关PyTorch中多进程处理的更多详细信息,请参阅 multiprocessing-best-practices。警告
len(dataloader)启发式算法基于所使用的采样器的长度。 当dataset是一个IterableDataset时, 它会根据len(dataset) / batch_size返回一个估计值,并根据drop_last进行适当的舍入,而不论多进程加载配置如何。这代表了 PyTorch 能做出的最佳猜测,因为 PyTorch 相信用户编写的dataset代码能够正确处理多进程加载以避免重复数据。然而,如果分片导致多个工作器拥有不完整的最后一批数据, 这个估计仍然可能不准确,因为 (1) 一个原本完整的批次可以 被拆分成多个批次,以及 (2) 当设置
drop_last时,可能会丢弃超过一个批次的数据。不幸的是,PyTorch 无法在一般情况下检测到这些情况。警告
参见可重复性,以及数据加载器工作线程随机种子,以及 数据加载随机性说明以了解与随机种子相关的问题。