目录

注意力

2024年6月更新:移除DataPipes和DataLoader V2

我们正在重新聚焦torchdata仓库,使其成为torch.utils.data.DataLoader的迭代增强。我们不计划继续开发或维护[DataPipes]和[DataLoaderV2]解决方案,并且它们将从torchdata仓库中移除。我们还将重新审视pytorch/pytorch中的DataPipes引用。在发布torchdata==0.8.0(2024年7月)版本中,它们将被标记为已弃用,并在0.10.0(2024年末)版本中被删除。现有用户建议固定到torchdata<=0.9.0或更早版本,直到他们能够迁移为止。后续版本将不再包含DataPipes或DataLoaderV2。如果您有任何建议或意见,请通过此问题反馈。

torchdata.nodes (beta)

class torchdata.nodes.BaseNode(*args, **kwargs)

Bases: Iterator[T]

基础节点是创建可组合数据加载DAG的基础类在torchdata.nodes

大多数最终用户不会直接迭代 BaseNode 实例,而是将其包装在一个 torchdata.nodes.Loader 中,该实例将 DAG 转换为更熟悉的可迭代对象。

node = MyBaseNodeImpl()
loader = Loader(node)
# loader supports state_dict() and load_state_dict()

for epoch in range(5):
    for idx, batch in enumerate(loader):
        ...

# or if using node directly:
node = MyBaseNodeImpl()
for epoch in range(5):
    node.reset()
    for idx, batch in enumerate(loader):
        ...
get_state() Dict[str, Any]

子类必须实现此方法,而不是state_dict()。仅在BaseNode中调用。

Returns:

Dict[str, Any] - 一个状态字典,可能在将来某个时间点传递给 reset()

next() T

子类必须实现此方法,而不是__next__。仅在BaseNode中调用。

Returns:

T - 序列中的下一个值,或者抛出 StopIteration

reset(initial_state: Optional[dict] = None)

将迭代器重置到开始位置,或者重置为初始状态传递的值。

重置是一个放置昂贵初始化的好地方,因为它将在调用 next()state_dict() 时懒惰地被调用。 子类必须调用 super().reset(initial_state)

Parameters:

初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。

state_dict() Dict[str, Any]

获取这个基节点的状态字典。

Returns:

Dict[str, Any] - 一个状态字典,可能在将来某个时刻传递给 reset()

class torchdata.nodes.Batcher(source: BaseNode[T], batch_size: int, drop_last: bool = True)

Bases: BaseNode[List[T]]

批次节点将从源节点的数据批量为 batch_size 大小的批次。 如果源节点耗尽,它将返回批次或抛出 StopIteration 异常。 如果 drop_last 为 True,则在批次大小小于 batch_size 的情况下丢弃最后一个批次。 如果 drop_last 为 False,则即使批次大小小于 batch_size,也会返回最后一个批次。

Parameters:
  • (BaseNode[T]) – 用于从其批量获取数据的源节点。

  • 批量大小 (整数) – 批次的大小。

  • drop_last (bool) – 是否在批次大小小于 batch_size 时丢弃最后一个批次。默认值为 True。

get_state() Dict[str, Any]

子类必须实现此方法,而不是state_dict()。仅在BaseNode中调用。

Returns:

Dict[str, Any] - 一个状态字典,可能在将来某个时间点传递给 reset()

next() List[T]

子类必须实现此方法,而不是__next__。仅在BaseNode中调用。

Returns:

T - 序列中的下一个值,或者抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开始位置,或者重置为初始状态传递的值。

重置是一个放置昂贵初始化的好地方,因为它将在调用 next()state_dict() 时懒惰地被调用。 子类必须调用 super().reset(initial_state)

Parameters:

初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。

class torchdata.nodes.IterableWrapper(iterable: Iterable[T])

Bases: BaseNode[T]

一个薄的包装器,将任何可迭代(包括torch.utils.data.IterableDataset)转换为BaseNode。

如果可迭代对象实现了 Stateful 协议,它将通过其 state_dict/load_state_dict 方法进行保存和恢复。

Parameters:

可迭代 (可迭代[T]) – 可迭代转换为 BaseNode。IterableWrapper 调用 iter() 在它上面。

Warning:

注意区分在 Iterable 上定义的 state_dict/load_state_dict 与 Iterator。 只有 Iterable 的 state_dict/load_state_dict 被使用。

get_state() Dict[str, Any]

子类必须实现此方法,而不是state_dict()。仅在BaseNode中调用。

Returns:

Dict[str, Any] - 一个状态字典,可能在将来某个时间点传递给 reset()

next() T

子类必须实现此方法,而不是__next__。仅在BaseNode中调用。

Returns:

T - 序列中的下一个值,或者抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开始位置,或者重置为初始状态传递的值。

重置是一个放置昂贵初始化的好地方,因为它将在调用 next()state_dict() 时懒惰地被调用。 子类必须调用 super().reset(initial_state)

Parameters:

初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。

class torchdata.nodes.Loader(root: BaseNode[T], restart_on_stop_iteration: bool = True)

Bases: Generic[T]

将根节点(迭代器)包裹起来,并提供一个状态化的可迭代接口。

返回迭代器的最后状态由state_dict()方法返回,并且可以使用load_state_dict()方法加载。

Parameters:
  • (BaseNode[T]) – 数据管道的根节点。

  • 重启在停止迭代时 (布尔值) – 是否在到达末尾时重新启动迭代器。默认值为True

load_state_dict(state_dict: Dict[str, Any])

加载一个state_dict,它将用于初始化从这个加载器请求的下一个iter()。

Parameters:

state_dict (Dict[str, Any]) – 从 state_dict() 调用生成的状态字典。

state_dict() Dict[str, Any]

返回一个状态字典,可以在将来通过传递给 load_state_dict() 来恢复迭代。

状态字典将来自最近一次调用 iter() 返回的迭代器。 如果尚未创建迭代器,将创建一个新的迭代器并返回其状态字典。

torchdata.nodes.MapStyleWrapper(map_dataset: Mapping[K, T], sampler: Sampler[K]) BaseNode[T]

薄的包装器,将任何 MapDataset 转换为 torchdata.node 如果你想实现并行性,请复制这个,并将 Mapper 替换为 ParallelMapper。

Parameters:
  • map_dataset (Mapping[K, T]) –

    • 将 map_dataset.__getitem__ 应用于采样器的输出。

  • 采样器 (采样器[K]) –

torchdata.nodes.Mapper(source: BaseNode[X], map_fn: Callable[[X], T]) ParallelMapper[T]

返回一个 ParallelMapper 节点,其中 num_workers=0,将在当前进程/线程中执行 map_fn。

Parameters:
  • (BaseNode[X]) – 要映射的源节点。

  • map_fn (Callable[[X], T]) – 应用到每个源节点项上的函数。

class torchdata.nodes.MultiNodeWeightedSampler(source_nodes: Mapping[str, BaseNode[T]], weights: Dict[str, float], stop_criteria: str = 'CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED', rank: Optional[int] = None, world_size: Optional[int] = None, seed: int = 0)

Bases: BaseNode[T]

一个从多个数据集采样并带有权重的节点。

此节点期望接收一个源节点的字典和一个权重的字典。 源节点和权重的键必须相同。权重用于从源节点中采样。 我们使用 torch.multinomial 从源节点中采样,请参阅 https://pytorch.org/docs/stable/generated/torch.multinomial.html 以了解如何使用权重进行采样。 seed 用于初始化随机数生成器。

该节点使用以下键来实现状态:

  • DATASET_NODE_STATES_KEY: 每个源节点的状态字典。

  • 数据集耗尽键:一个字典,其中包含布尔值,表示每个源节点是否耗尽。

  • EPOCH_KEY: 一个用于初始化随机数生成器的迭代计数器。

  • 生成的键:生成的项的数量。

  • 加权采样器状态键:加权采样器的状态。

我们支持多种停止条件:

  • 循环直到所有数据集耗尽:循环通过源节点,直到所有数据集耗尽。这是默认行为。

  • 当第一个数据集耗尽时停止。

  • 所有数据集耗尽:当所有数据集耗尽时停止。

当源节点完全耗尽时,该节点将引发 StopIteration 异常。

Parameters:
  • source_nodes (Mapping[str, BaseNode[T]]) – 一个源节点的字典。

  • 权重 (Dict[字符串, 浮点数]) – 一个包含每个源节点权重的字典。

  • 停止标准 (字符串) – 停止标准。默认值为CYCLE_UNTIL_ALL_DATASETS_EXHAUST

  • 排名 (整数) – 当前进程的排名。默认值为None,在这种情况下,将从分布式环境中获取排名。

  • world_size (int) – 世界环境的分布式大小。默认值为None,在这种情况下,将从分布式环境中获取世界大小。

  • 种子 (整数) – 随机数生成器的种子。默认值为0。

get_state() Dict[str, Any]

子类必须实现此方法,而不是state_dict()。仅在BaseNode中调用。

Returns:

Dict[str, Any] - 一个状态字典,可能在将来某个时间点传递给 reset()

next() T

子类必须实现此方法,而不是__next__。仅在BaseNode中调用。

Returns:

T - 序列中的下一个值,或者抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开始位置,或者重置为初始状态传递的值。

重置是一个放置昂贵初始化的好地方,因为它将在调用 next()state_dict() 时懒惰地被调用。 子类必须调用 super().reset(initial_state)

Parameters:

初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。

class torchdata.nodes.ParallelMapper(source: BaseNode[X], map_fn: Callable[[X], T], num_workers: int, in_order: bool = True, method: Literal['thread', 'process'] = 'thread', multiprocessing_context: Optional[str] = None, max_concurrent: Optional[int] = None, snapshot_frequency: int = 1, prebatch: Optional[int] = None)

Bases: BaseNode[T]

ParallelMapper 并行执行 map_fn,要么在 num_workers 个线程中运行,要么在进程中运行。对于进程,multiprocessing_context 可以是 spawn、forkserver、fork 或 None(选择操作系统默认值)。最多 max_concurrent 个任务将被处理或存放在迭代器的输出队列中,以限制 CPU 和内存使用量。如果为 None(默认),则该值将设置为 2 * num_workers。

最多只有一个 iter() 从源创建,且最多只有一个线程同时调用 next()。

如果在_order 是 true,迭代器将按它们到达源迭代器的顺序返回项目,即使其他项目可用也可能阻塞。

Parameters:
  • (BaseNode[X]) – 要映射的源节点。

  • map_fn (Callable[[X], T]) – 应用到每个源节点项上的函数。

  • num_workers (int) – 使用进行并行处理的工人数量。

  • 按顺序 (bool) – 是否返回从到达顺序中获取的项目。默认值为True。

  • 方法 (Literal["线程", "进程"]) – 使用进行并行处理的方法。默认是“线程”。

  • multiprocessing_context (Optional[str]) – 使用进行并行处理的多进程上下文。默认值为None。

  • 最大并发 (可选[整数]) – 一次处理的最大项目数量。默认值为None。

  • 快照频率 (整数) – 源节点状态的快照频率。默认值为1。

  • 预批量 (可选[整数]) – 可选地在映射之前对源中的项目进行预批量处理。 对于较小的项目,这可能会以牺牲峰值内存为代价来提高吞吐量。

get_state() Dict[str, Any]

子类必须实现此方法,而不是state_dict()。仅在BaseNode中调用。

Returns:

Dict[str, Any] - 一个状态字典,可能在将来某个时间点传递给 reset()

next() T

子类必须实现此方法,而不是__next__。仅在BaseNode中调用。

Returns:

T - 序列中的下一个值,或者抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开始位置,或者重置为初始状态传递的值。

重置是一个放置昂贵初始化的好地方,因为它将在调用 next()state_dict() 时懒惰地被调用。 子类必须调用 super().reset(initial_state)

Parameters:

初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。

class torchdata.nodes.PinMemory(source: BaseNode[T], pin_memory_device: str = '', snapshot_frequency: int = 1)

Bases: BaseNode[T]

将底层节点的数据绑定到设备上。这由 torch.utils.data._utils.pin_memory._pin_memory_loop 支持。

Parameters:
  • (BaseNode[T]) – 要固定数据的源节点。

  • pin_memory_device (str) – 将数据绑定到的设备。默认值为空。

  • 快照频率 (整数) – 源节点状态的快照频率。默认值为 1,这意味着源节点的状态将在每个项目后进行快照。如果设置为更高的值,源节点的状态将每 snapshot_frequency 个项目进行快照。

get_state() Dict[str, Any]

子类必须实现此方法,而不是state_dict()。仅在BaseNode中调用。

Returns:

Dict[str, Any] - 一个状态字典,可能在将来某个时间点传递给 reset()

next()

子类必须实现此方法,而不是__next__。仅在BaseNode中调用。

Returns:

T - 序列中的下一个值,或者抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开始位置,或者重置为初始状态传递的值。

重置是一个放置昂贵初始化的好地方,因为它将在调用 next()state_dict() 时懒惰地被调用。 子类必须调用 super().reset(initial_state)

Parameters:

初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。

class torchdata.nodes.Prefetcher(source: BaseNode[T], prefetch_factor: int, snapshot_frequency: int = 1)

Bases: BaseNode[T]

从源节点预取数据并将其存储在队列中。

Parameters:
  • (BaseNode[T]) – 用于预取数据的源节点。

  • prefetch_factor (int) – 预先缓存的项目数量。

  • 快照频率 (整数) – 源节点状态的快照频率。默认值为 1,这意味着源节点的状态将在每个项目后进行快照。如果设置为更高的值,源节点的状态将每 snapshot_frequency 个项目进行快照。

get_state() Dict[str, Any]

子类必须实现此方法,而不是state_dict()。仅在BaseNode中调用。

Returns:

Dict[str, Any] - 一个状态字典,可能在将来某个时间点传递给 reset()

next()

子类必须实现此方法,而不是__next__。仅在BaseNode中调用。

Returns:

T - 序列中的下一个值,或者抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开始位置,或者重置为初始状态传递的值。

重置是一个放置昂贵初始化的好地方,因为它将在调用 next()state_dict() 时懒惰地被调用。 子类必须调用 super().reset(initial_state)

Parameters:

初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。

class torchdata.nodes.SamplerWrapper(sampler: Sampler[T], initial_epoch: int = 0, epoch_updater: Optional[Callable[[int], int]] = None)

Bases: BaseNode[T]

将采样器转换为BaseNode。这几乎与IterableWrapper相同,除了它包括一个钩子,在支持的情况下调用set_epoch方法。

Parameters:
  • 抽样器 (Sampler) – 用于包装的抽样器。

  • 初始轮次 (整数) – 在采样器上设置的初始轮次

  • epoch_updater (Optional[Callable[[int], int]] = None) – callback to update epoch at start of new iteration. It’s called at the beginning of each iterator request, except the first one.

get_state() Dict[str, Any]

子类必须实现此方法,而不是state_dict()。仅在BaseNode中调用。

Returns:

Dict[str, Any] - 一个状态字典,可能在将来某个时间点传递给 reset()

next() T

子类必须实现此方法,而不是__next__。仅在BaseNode中调用。

Returns:

T - 序列中的下一个值,或者抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开始位置,或者重置为初始状态传递的值。

重置是一个放置昂贵初始化的好地方,因为它将在调用 next()state_dict() 时懒惰地被调用。 子类必须调用 super().reset(initial_state)

Parameters:

初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。

class torchdata.nodes.Stateful(*args, **kwargs)

Bases: Protocol

协议对象,既支持 state_dict() 又支持 load_state_dict(state_dict: Dict[str, Any])

class torchdata.nodes.StopCriteria

Bases: object

数据采样器的停止条件。

  1. 直到所有数据集耗尽:在最后一个未见过的数据集耗尽时停止。 所有数据集至少被看到一次。在某些情况下,当还有未耗尽的数据集时,某些数据集可能会被看到多次。

  2. 所有数据集耗尽:当所有数据集都耗尽时停止。每个数据集仅被看到一次。不会进行循环或重新开始。

  3. 当第一个数据集耗尽时停止。

class torchdata.nodes.Unbatcher(source: BaseNode[Sequence[T]])

Bases: BaseNode[T]

去批次化器将从源拉取的批次展平,并在调用 next() 时按顺序输出元素。

Parameters:

(BaseNode[T]) – 从中提取批次的源节点。

get_state() Dict[str, Any]

子类必须实现此方法,而不是state_dict()。仅在BaseNode中调用。

Returns:

Dict[str, Any] - 一个状态字典,可能在将来某个时间点传递给 reset()

next() T

子类必须实现此方法,而不是__next__。仅在BaseNode中调用。

Returns:

T - 序列中的下一个值,或者抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开始位置,或者重置为初始状态传递的值。

重置是一个放置昂贵初始化的好地方,因为它将在调用 next()state_dict() 时懒惰地被调用。 子类必须调用 super().reset(initial_state)

Parameters:

初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源