注意力
2024 年 6 月状态更新:删除 DataPipes 和 DataLoader V2
我们将 torchdata 存储库重新调整为torch.utils.data.DataLoader的迭代增强。我们不打算 继续开发或维护 [DataPipes] 和 [DataLoaderV2] 解决方案,它们将从 torchdata 存储库。我们还将重新访问 pytorch/pytorch 中的 DataPipes 引用。在 torchdata==0.8.0(2024 年 7 月)版本中,它们将被标记为已弃用,而在 0.10.0(2024 年末)中,它们将被删除。现存 建议用户固定到 torchdata<=0.9.0 或更早版本,直到他们能够迁移出去。随后的 版本将不包含 DataPipes 或 DataLoaderV2。 如果您有建议或评论,请联系我们(请使用此问题进行反馈)
有状态 DataLoader¶
StatefulDataLoader 是 torch.utils.data.DataLoader 的直接替代品,它提供用于处理中期检查点的方法,这些方法对从数据加载器请求的上一个/下一个迭代器进行操作(或)。state_dict
load_state_dict
默认情况下,状态包括生成的批次数,并使用它来天真地快进采样器(map-style)或数据集(iterable-style)。但是,如果采样器和/或数据集包含 / 方法,则它将在自己的 / 调用期间调用它们。在后台,处理跨多进程 worker(但不能跨等级)的状态聚合和分配。
state_dict
load_state_dict
state_dict
load_state_dict
- torchdata.stateful_dataloader 类。StatefulDataLoader(dataset: Dataset[_T_co], batch_size: Optional[int] = 1, shuffle: 可选[bool] = 无, sampler: 可选[Union[Sampler, Iterable]] = None, batch_sampler: optional[Union[Sampler[List], Iterable[列表]]] = 无,num_workers:int = 0,collate_fn: 可选[Callable[[List[_T]], any]] = None, pin_memory: bool = False,drop_last:bool = False,超时:float = 0, worker_init_fn: 可选[Callable[[int], None]] = None, multiprocessing_context=无,生成器=无,*, prefetch_factor: 可选[int] = 无, persistent_workers: bool = False, pin_memory_device: str = '', snapshot_every_n_steps: 可选[int] = 1)¶
这是实现 state_dict 和 load_state_dict 方法的直接替代品,可实现 epoch 中期 检查点。
torch.utils.data.DataLoader
所有参数都与 , 相同,其中 新的 kwarg: .
torch.utils.data.DataLoader
snapshot_every_n_steps
- 参数
dataset (Dataset) – 从中加载数据的数据集。
batch_size (int, optional) – 每批要加载的样本数 (默认值:)。
1
shuffle (bool, optional) – 设置为重新洗牌数据 在每个 epoch (默认值: )。
True
False
sampler (Sampler 或 Iterable,可选) – 定义要绘制的策略 数据集中的样本。可以是任何已实施的。如果指定,则不得指定。
Iterable
__len__
shuffle
batch_sampler (Sampler 或 Iterable,可选) – like ,但 一次返回一批索引。与 、 、 互斥 和。
sampler
batch_size
shuffle
sampler
drop_last
num_workers (int, optional) – 用于数据的子进程数 装载。 表示数据将在主进程中加载。 (默认:
0
0
)collate_fn (Callable, optional) – 合并样本列表以形成 小批量的 Tensor 中。当使用 batch loading from 地图样式数据集。
pin_memory (bool, optional) – 如果 ,数据加载器将复制张量 放入 device/CUDA 固定内存中。如果您的数据元素 是自定义类型,或者您返回的批次是自定义类型, 请参阅下面的示例。
True
collate_fn
drop_last (bool, optional) – 设置为 以删除最后一个未完成的批次, 如果数据集大小不能被批量大小整除。If 和 数据集的大小不能被批次大小整除,然后是最后一个批次 会更小。(默认:
True
False
False
)timeout (numeric, optional) – 如果为正数,则为收集批次的超时值 从工人。应始终为非负数。(默认:
0
)worker_init_fn (Callable, optional) – 如果不是 ,则将在每个 worker 子进程,其中 worker id ( int in ) 为 input、seeding 之后和 data loading 之前。(默认:
None
[0, num_workers - 1]
None
)multiprocessing_context (str 或 multiprocessing.context.BaseContext,可选) – 如果 ,则操作系统的默认多处理上下文将 被使用。(默认:
None
None
)发电机 (Torch.生成器,可选) – 如果没有,将使用此 RNG 通过 RandomSampler 生成随机索引,并通过 multiprocessing 为 worker 生成。(默认:
None
base_seed
None
)prefetch_factor (int, optional, keyword-only arg) – 加载的批次数 由每个 worker 提前完成。 表示总共会有 2 * num_workers 个批次,在所有工作程序中预取。(默认值取决于 在 num_workers 的 Set 值上。如果值 num_workers=0,则默认值为 。 否则,如果 default 的值为 )。
2
None
num_workers > 0
2
persistent_workers (bool, optional) – 如果 ,则数据加载程序不会关闭 工作程序在 dataset 被使用一次后进行处理。这允许 保持 worker Dataset 实例处于活动状态。(默认:
True
False
)pin_memory_device (str, optional) – 如果设备为 .
pin_memory
pin_memory
True
snapshot_every_n_steps (int, optional) – 定义状态的频率 从 DataLoader 工作程序传输到 DataLoader。默认情况下,它设置为 ,即每一步都传输状态。如果 state 很大,则可以增加此值(最好设置为训练 checkpoint 的频率),以减少每一步传输 state 的开销。
1
警告
如果使用 start 方法,则不能是不可封存的对象,例如 lambda 函数。请参阅 multiprocessing-best-practices 了解更多相关详细信息 添加到 PyTorch 中的 multiprocessing 中。
spawn
worker_init_fn
警告
len(dataloader)
启发式 (heuristic) 基于所使用的采样器的长度。 当 为 时 , 相反,它返回基于 的估计值,并使用适当的 舍入取决于 ,而不考虑多进程加载 配置。这代表了 PyTorch 可以做出的最佳猜测,因为 PyTorch 信任用户代码正确处理多进程 loading 以避免重复数据。dataset
IterableDataset
len(dataset) / batch_size
drop_last
dataset
但是,如果分片导致多个 worker 具有不完整的最后一批,则 此估计仍然可能不准确,因为 (1) 否则完整的批次可能 被分成多个 1 和 (2) 多个批次的样品可以是 set 时丢弃。不幸的是,PyTorch 无法检测到此类 一般情况。
drop_last
警告
有关随机种子相关问题,请参阅可重复性、 Dataloader-workers-random-seed 和 Data-loading-randomness 注释。