目录

注意力

2024 年 6 月状态更新:删除 DataPipes 和 DataLoader V2

我们将 torchdata 存储库重新调整为torch.utils.data.DataLoader的迭代增强。我们不打算 继续开发或维护 [DataPipes] 和 [DataLoaderV2] 解决方案,它们将从 torchdata 存储库。我们还将重新访问 pytorch/pytorch 中的 DataPipes 引用。在 torchdata==0.8.0(2024 年 7 月)版本中,它们将被标记为已弃用,而在 0.10.0(2024 年末)中,它们将被删除。现存 建议用户固定到 torchdata<=0.9.0 或更早版本,直到他们能够迁移出去。随后的 版本将不包含 DataPipes 或 DataLoaderV2。 如果您有建议或评论,请联系我们(请使用此问题进行反馈)

迁移到 从torchdata.nodestorch.utils.data

本指南旨在帮助熟悉 或 的用户 以开始使用 ,并为定义 您自己的数据加载管道。torch.utils.datatorchdata.nodes

我们将演示如何实现最常见的 DataLoader 功能,重用现有的采样器和数据集,以及 和 load/save dataloader 状态。它的性能至少与 和 一样好。 请参阅 torchdata.nodes 的执行方式?DataLoaderStatefulDataLoader

地图样式数据集

让我们看看构造函数 args,然后从那里开始DataLoader

class DataLoader:
    def __init__(
        self,
        dataset: Dataset[_T_co],
        batch_size: Optional[int] = 1,
        shuffle: Optional[bool] = None,
        sampler: Union[Sampler, Iterable, None] = None,
        batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
        num_workers: int = 0,
        collate_fn: Optional[_collate_fn_t] = None,
        pin_memory: bool = False,
        drop_last: bool = False,
        timeout: float = 0,
        worker_init_fn: Optional[_worker_init_fn_t] = None,
        multiprocessing_context=None,
        generator=None,
        *,
        prefetch_factor: Optional[int] = None,
        persistent_workers: bool = False,
        pin_memory_device: str = "",
        in_order: bool = True,
    ):
        ...

作为 referesher,数据加载大致是 中的工作原理:首先从 a 生成索引,然后创建batch_size索引的批次。 如果未提供采样器,则默认创建 RandomSampler 或 SequentialSampler。 将索引传递给 ,然后将 a 应用于批处理 的样本。如果 ,它将使用多进程来创建 子进程,并将索引批次传递给工作进程,然后工作进程将在将批次返回给主进程之前调用和应用。此时,可以应用于批处理中的张量。torch.utils.data.DataLoaderDataLoadersamplerDataset.__getitem__()collate_fnnum_workers > 0Dataset.__getitem__()collate_fnpin_memory

现在让我们看看 DataLoader 的等效实现可能是什么样子,它使用 .torchdata.nodes

from typing import List, Callable
import torchdata.nodes as tn
from torch.utils.data import RandomSampler, SequentialSampler, default_collate, Dataset

class MapAndCollate:
    """A simple transform that takes a batch of indices, maps with dataset, and then applies
    collate.
    TODO: make this a standard utility in torchdata.nodes
    """
    def __init__(self, dataset, collate_fn):
        self.dataset = dataset
        self.collate_fn = collate_fn

    def __call__(self, batch_of_indices: List[int]):
        batch = [self.dataset[i] for i in batch_of_indices]
        return self.collate_fn(batch)

# To keep things simple, let's assume that the following args are provided by the caller
def NodesDataLoader(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool,
    num_workers: int,
    collate_fn: Callable | None,
    pin_memory: bool,
    drop_last: bool,
):
    # Assume we're working with a map-style dataset
    assert hasattr(dataset, "__getitem__") and hasattr(dataset, "__len__")
    # Start with a sampler, since caller did not provide one
    sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    # Sampler wrapper converts a Sampler to a BaseNode
    node = tn.SamplerWrapper(sampler)

    # Now let's batch sampler indices together
    node = tn.Batcher(node, batch_size=batch_size, drop_last=drop_last)

    # Create a Map Function that accepts a list of indices, applies getitem to it, and
    # then collates them
    map_and_collate = MapAndCollate(dataset, collate_fn or default_collate)

    # MapAndCollate is doing most of the heavy lifting, so let's parallelize it. We could
    # choose process or thread workers. Note that if you're not using Free-Threaded
    # Python (eg 3.13t) with -Xgil=0, then multi-threading might result in GIL contention,
    # and slow down training.
    node = tn.ParallelMapper(
        node,
        map_fn=map_and_collate,
        num_workers=num_workers,
        method="process",  # Set this to "thread" for multi-threading
        in_order=True,
    )

    # Optionally apply pin-memory, and we usually do some pre-fetching
    if pin_memory:
        node = tn.PinMemory(node)
    node = tn.Prefetcher(node, prefetch_factor=num_workers * 2)

    # Note that node is an iterator, and once it's exhausted, you'll need to call .reset()
    # on it to start a new Epoch.
    # Insteaad, we wrap the node in a Loader, which is an iterable and handles reset. It
    # also provides state_dict and load_state_dict methods.
    return tn.Loader(node)

现在让我们用一个简单的数据集来测试一下,并演示状态管理是如何工作的。

class SquaredDataset(Dataset):
    def __init__(self, len: int):
        self.len = len
    def __len__(self):
        return self.len
    def __getitem__(self, i: int) -> int:
        return i**2

loader = NodesDataLoader(
    dataset=SquaredDataset(14),
    batch_size=3,
    shuffle=False,
    num_workers=2,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
)

batches = []
for idx, batch in enumerate(loader):
    if idx == 2:
        state_dict = loader.state_dict()
        # Saves the state_dict after batch 2 has been returned
    batches.append(batch)

loader.load_state_dict(state_dict)
batches_after_loading = list(loader)
print(batches[3:])
# [tensor([ 81, 100, 121]), tensor([144, 169])]
print(batches_after_loading)
# [tensor([ 81, 100, 121]), tensor([144, 169])]

让我们也将其与 torch.utils.data.DataLoader 进行比较,作为健全性检查。

loaderv1 = torch.utils.data.DataLoader(
    dataset=SquaredDataset(14),
    batch_size=3,
    shuffle=False,
    num_workers=2,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    persistent_workers=False,  # Coming soon to torchdata.nodes!
)
print(list(loaderv1))
# [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
print(batches)
# [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]

IterableDatasets 数据集

即将推出!虽然您已经可以将 IterableDataset 插入到 中,但目前尚不支持某些函数,例如。然而,我们相信,分片通常在 多进程 worker 实际上不是必需的,你可以在主进程中保留某种索引,而 仅并行化一些较重的转换,类似于上面的 Map 样式 Datasets 的工作方式。tn.IterableWrapperget_worker_info

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源