注意力
2024年6月更新:移除DataPipes和DataLoader V2
我们正在重新聚焦torchdata仓库,使其成为torch.utils.data.DataLoader的迭代增强。我们不计划继续开发或维护[DataPipes]和[DataLoaderV2]解决方案,并且它们将从torchdata仓库中移除。我们还将重新审视pytorch/pytorch中的DataPipes引用。在发布torchdata==0.8.0(2024年7月)版本中,它们将被标记为已弃用,并在0.10.0(2024年末)版本中被删除。现有用户建议固定到torchdata<=0.9.0或更早版本,直到他们能够迁移为止。后续版本将不再包含DataPipes或DataLoaderV2。如果您有任何建议或意见,请通过此问题反馈。
开始使用 torchdata.nodes (beta)¶
使用 pip 安装 torchdata。
pip install torchdata>=0.10.0
生成器示例¶
将生成器(或任何可迭代对象)包裹起来,将其转换为 BaseNode,并开始使用。
from torchdata.nodes import IterableWrapper, ParallelMapper, Loader
node = IterableWrapper(range(10))
node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
loader = Loader(node)
result = list(loader)
print(result)
# [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
Sampler 示例¶
采样器仍然受支持,您可以使用现有的
torch.utils.data.Dataset's。有关深入示例,请参阅从 torch.utils.data 迁移到 torchdata.nodes。
import torch.utils.data
from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader
class SquaredDataset(torch.utils.data.Dataset):
def __getitem__(self, i: int) -> int:
return i**2
def __len__(self):
return 10
dataset = SquaredDataset()
sampler = RandomSampler(dataset)
# For fine-grained control of iteration order, define your own sampler
node = SamplerWrapper(sampler)
# Simply apply dataset's __getitem__ as a map function to the indices generated from sampler
node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread")
# Loader is used to convert a node (iterator) into an Iterable that may be reused for multi epochs
loader = Loader(node)
print(list(loader))
# [25, 36, 9, 49, 0, 81, 4, 16, 64, 1]
print(list(loader))
# [0, 4, 1, 64, 49, 25, 9, 16, 81, 36]