注意力
2024年6月更新:移除DataPipes和DataLoader V2
我们正在重新聚焦torchdata仓库,使其成为torch.utils.data.DataLoader的迭代增强。我们不计划继续开发或维护[DataPipes]和[DataLoaderV2]解决方案,并且它们将从torchdata仓库中移除。我们还将重新审视pytorch/pytorch中的DataPipes引用。在发布torchdata==0.8.0(2024年7月)版本中,它们将被标记为已弃用,并在0.10.0(2024年末)版本中被删除。现有用户建议固定到torchdata<=0.9.0或更早版本,直到他们能够迁移为止。后续版本将不再包含DataPipes或DataLoaderV2。如果您有任何建议或意见,请通过此问题反馈。
什么是 torchdata.nodes (beta)?¶
torchdata.nodes 是一个可组合的迭代器库(而不是可迭代对象!)它允许你将常见的数据加载和预处理操作串联在一起。它遵循流式编程模型,尽管“采样器 + 映射风格”仍然可以配置,如果你愿意的话。
torchdata.nodes 添加了对标准的更多灵活性
torch.utils.data 提供,以及引入了多线程并行计算
除了多进程(在torch.utils.data.DataLoader中唯一支持的方法),
以及通过state_dict/load_state_dict接口提供第一类支持的
中期epoch检查点。
torchdata.nodes 努力包含尽可能多的有用操作符,然而它被设计为可扩展。新节点需要继承 torchdata.nodes.BaseNode(本身又继承 typing.Iterator),并实现 next(), reset(initial_state) 和 get_state() 操作(值得注意的是,不包括 __next__, load_state_dict, 或 state_dict)
参见 开始使用 torchdata.nodes(测试版) 以开始
为什么 torchdata.nodes?¶
我们明白了,torch.utils.data 对于很多很多的使用案例来说都行得通。
然而它确实有很多粗糙的地方:
多进程很糟糕¶
你需要复制存储在Dataset中的内存(因为Python的读取时复制)
IPC 在多进程队列上运行缓慢,并可能导致启动时间变慢。
你被迫在工作者上进行批量处理,而不是在主进程上,以减少IPC开销,增加峰值内存。
在释放GIL和使用自由线程的Python中,多线程可能不像以前那样被GIL绑定。
torchdata.nodes 启用了多线程和多进程,因此您可以根据您特定的设置选择最适合您的选项。并行处理主要配置在Mapper操作器中,为您提供灵活性以确定何时、如何以及对什么进行并行化。
地图样式和随机访问不支持扩展¶
当前的映射数据集方法非常适合内存中可以容纳的数据集,但一旦你的数据集超过内存限制,真正的随机访问性能将不会很好,除非你通过一个特殊采样器跳过一些障碍。
torchdata.nodes 遵循流数据模型,其中操作符是迭代器,可以组合在一起以定义数据加载和预处理管道。仍然支持采样器(请参阅 从 torch.utils.data 迁移到 torchdata.nodes),并且可以与 Mapper 结合使用以生成迭代器
多数据集与当前实现不匹配torch.utils.data¶
当前的采样器(每个数据加载器一个)的概念在尝试结合多个数据集时开始崩溃。对于单个数据集,它们是一个很好的抽象,并将继续得到支持!
对于多数据集,考虑以下场景:
len(dsA): 10. 现在我们想在这两个数据集中进行轮询(或均匀采样)以供我们的训练器使用。仅用一个采样器如何实现该策略?也许是一个发射元组的采样器?如果要与RandomSampler或DistributedSampler互换?sampler.set_epoch将如何工作?
torchdata.nodes 帮助通过只处理迭代器来解决和扩展多数据集的数据加载,从而迫使采样器和数据集在一起,专注于将更小的原始节点组合成一个更复杂的数据加载管道。
IterableDataset + multiprocessing需要额外的数据集分片¶
数据集分片是进行数据并行训练所必需的,这是相当合理的。但是,对于数据加载器工作者之间的分片呢?对于使用映射式数据集,工作分配由主进程处理,它将采样索引分发给工作者。对于IterableDatasets,每个工作者需要通过torch.utils.data.get_worker_info来确定应该返回的数据。
如何 torchdata.nodes 运行?¶
我们展示了一些早期版本的torchdata.nodes在视频解码基准测试中的结果,这些结果是在PyTorch Conf 2024上展示的,我们展示了以下内容:
torchdata.nodes 在使用多进程时表现相当或更好
torch.utils.data.DataLoader(请参阅 从 torch.utils.data 迁移到 torchdata.nodes)在使用GIL的Python中,torchdata节点在多线程环境下表现优于多进程环境,在某些场景下,这使得像GPU预处理这样的功能更容易实现,从而可以提升许多应用场景的吞吐量。
使用无GIL/自由线程的Python(3.13t),我们运行了一个基准测试,从磁盘加载ImageNet数据集,并且在显著更低的CPU利用率下达到了主内存带宽的饱和,相比于多进程工作者(预计2025年初发布博客文章)。请参阅 imagenet_benchmark.py 在您自己的硬件上尝试。
设计选择¶
无基节点生成器¶
查看 https://github.com/pytorch/data/pull/1362 以获取更多见解。
我们做出的一个艰难选择是,在定义新基节点实现时,不允许使用生成器。然而,出于管理状态的考虑,我们放弃了这个选项,并转向了仅支持迭代器的基础架构。
在BaseNode实现中,我们需要明确的状态处理。 生成器将状态隐式地存储在栈上,我们发现需要跳过许多步骤,并编写非常复杂的代码来使基本状态工作。
迭代结束状态字典:迭代器可能感觉更自然,然而在状态管理方面会遇到一些问题。考虑一下迭代结束时的状态字典。如果你将这个状态字典加载到你的迭代器中,它应该代表迭代结束还是下一个迭代的开始?
加载状态:如果你调用 load_state_dict() 在一个可迭代对象上,大多数用户会期望从它请求的下一个迭代器开始时加载的状态。然而,如果在开始迭代之前 iter 被调用了两次呢?
多个活跃迭代器问题:如果你有一个实例的可迭代对象,但有两个活跃的迭代器,那么对这个可迭代对象调用state_dict()意味着什么?在数据加载中,这种情况非常罕见,但我们仍然需要绕过它并做出很多假设。迫使那些正在实现基节点的开发人员考虑这些场景,在我们看来,比不允许生成器和可迭代对象更糟糕。
torchdata.nodes.BaseNode 实现是迭代器。迭代器
定义 next(), get_state(), 和 reset(initial_state | None)。
所有重初始化应在 reset() 中进行,包括在传递特定状态时初始化。
然而,终端用户习惯于处理可迭代对象,例如,
for epoch in range(5):
# Most frameworks and users don't expect to call loader.reset()
for batch in loader:
...
sd = loader.state_dict()
# Loading sd should not throw StopIteration right away, but instead start at the next epoch
为了处理这个问题,我们将所有假设和特殊结束周期处理放在一个名为 Loader 的类中,该类接受任何 BaseNode 并将其转换为 Iterable,处理 reset() 调用和周期状态字典加载。