注意力
2024 年 6 月状态更新:删除 DataPipes 和 DataLoader V2
我们将 torchdata 存储库重新调整为torch.utils.data.DataLoader的迭代增强。我们不打算 继续开发或维护 [DataPipes] 和 [DataLoaderV2] 解决方案,它们将从 torchdata 存储库。我们还将重新访问 pytorch/pytorch 中的 DataPipes 引用。在 torchdata==0.8.0(2024 年 7 月)版本中,它们将被标记为已弃用,而在 0.10.0(2024 年末)中,它们将被删除。现存 建议用户固定到 torchdata<=0.9.0 或更早版本,直到他们能够迁移出去。随后的 版本将不包含 DataPipes 或 DataLoaderV2。 如果您有建议或评论,请联系我们(请使用此问题进行反馈)
torchdata.nodes
(测试版)¶
- 类 torchdata.nodes 中。BaseNode(*args, **kwargs)¶
基地:
Iterator
[T
]BaseNode 是用于创建可组合数据加载 DAG 的基类。
torchdata.nodes
大多数最终用户不会直接迭代 BaseNode 实例,而是 将其包装在 a
中,这会将 DAG 转换为更熟悉的 Iterable。
node = MyBaseNodeImpl() loader = Loader(node) # loader supports state_dict() and load_state_dict() for epoch in range(5): for idx, batch in enumerate(loader): ... # or if using node directly: node = MyBaseNodeImpl() for epoch in range(5): node.reset() for idx, batch in enumerate(loader): ...
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。只能由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可能在将来的某个时候传递给 reset()
- next() T ¶
子类必须实现此方法,而不是 .只能由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
__next
- reset(initial_state: 可选[dict] = 无)¶
将迭代器重置为开头,或重置为 initial_state 传入的状态。
reset 是放置昂贵初始化的好地方,因为当 next() 或 state_dict() 被调用时,它将被延迟调用。 子类必须调用 .
super().reset(initial_state)
- 参数
initial_state – Optional[dict] - 要传递给节点的状态 dict。如果为 None,则重置为开头。
- state_dict() Dict[str, Any] ¶
获取此 BaseNode 的 state_dict。 :return: Dict[str, Any] - 一个状态字典,可能在将来的某个时候传递给 reset() 。
- 类 torchdata.nodes 中。Batcher(来源: BaseNode[T], batch_size: int, drop_last: bool = 真)¶
-
Batcher 节点将源节点中的数据批处理为大小为 batch_size 的批次。 如果源节点已用尽,它将返回批处理或引发 StopIteration。 如果 drop_last 为 True,则如果最后一个批次小于 batch_size,则将丢弃最后一个批次。 如果 drop_last 为 False,则即使最后一批小于 batch_size 也会返回该批次。
- 参数
source (BaseNode[T]) – 要从中批处理数据的源节点。
batch_size (int) – 批处理的大小。
drop_last (bool) – 如果最后一批小于 batch_size,是否丢弃它。默认值为 True。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。只能由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可能在将来的某个时候传递给 reset()
- next() List[T] ¶
子类必须实现此方法,而不是 .只能由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
__next
- reset(initial_state: 可选[Dict[str, Any]] = 无)¶
将迭代器重置为开头,或重置为 initial_state 传入的状态。
reset 是放置昂贵初始化的好地方,因为当 next() 或 state_dict() 被调用时,它将被延迟调用。 子类必须调用 .
super().reset(initial_state)
- 参数
initial_state – Optional[dict] - 要传递给节点的状态 dict。如果为 None,则重置为开头。
- 类 torchdata.nodes 中。IterableWrapper(iterable: Iterable[T])¶
-
Thin Wrapper 将任何 Iterable(包括 torch.utils.data.IterableDataset)添加到 BaseNode 中。
如果 iterable 实现了 Stateful Protocol,它将被保存并恢复为其 state_dict/load_state_dict 方法。
- 参数
iterable (Iterable[T]) - 可迭代转换为 BaseNode。IterableWrapper 对它调用 iter()。
- 警告:
请注意 Iterable 上定义的 state_dict/load_state_dict 与 Iterator 之间的区别。 仅使用 Iterable 的 state_dict/load_state_dict。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。只能由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可能在将来的某个时候传递给 reset()
- next() T ¶
子类必须实现此方法,而不是 .只能由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
__next
- reset(initial_state: 可选[Dict[str, Any]] = 无)¶
将迭代器重置为开头,或重置为 initial_state 传入的状态。
reset 是放置昂贵初始化的好地方,因为当 next() 或 state_dict() 被调用时,它将被延迟调用。 子类必须调用 .
super().reset(initial_state)
- 参数
initial_state – Optional[dict] - 要传递给节点的状态 dict。如果为 None,则重置为开头。
- 类 torchdata.nodes 中。Loader(根: BaseNode[T], restart_on_stop_iteration: bool = True)¶
基地:
Generic
[T
]包装根 BaseNode (迭代器) 并提供有状态的可迭代接口。
最后返回的迭代器的状态由 state_dict() 方法返回,并且可以是 使用 load_state_dict() 方法加载。
- 参数
root (BaseNode[T]) – 数据管道的根节点。
restart_on_stop_iteration (bool) – 是否在迭代器到达末尾时重新启动迭代器。默认值为 True
- load_state_dict(state_dict: dict[str, any])¶
加载一个 state_dict,该 将用于初始化请求的下一个 iter() 从这个加载程序。
- 参数
state_dict (Dict[str, Any]) – 要加载的state_dict。应通过调用 state_dict() 生成。
- state_dict() Dict[str, Any] ¶
返回一个 state_dict,该 将来可以传递给 load_state_dict() resume 迭代。
state_dict将来自最近一次调用 iter() 返回的迭代器。 如果未创建迭代器,则将创建一个新的迭代器,并从中返回state_dict。
- torchdata.nodes 中。MapStyleWrapper (map_dataset: 映射 [K, T], 采样器: 采样器 [K]) 基节点[T] ¶
将任何 MapDataset 转换为 torchdata.node 的 Thin Wrapper 如果需要并行性,请复制此文件并将 Mapper 替换为 ParallelMapper。
- 参数
map_dataset (映射 [K, T]) –
将 map_dataset.__getitem__ 应用于 sampler 的输出。
采样器 (Sampler[K]) –
- torchdata.nodes 中。Mapper(来源: BaseNode[X], map_fn: Callable[[X], T]) ParallelMapper[T] ¶
返回 num_workers=0 的节点,该节点将在当前进程/线程中执行 map_fn。
- 参数
source (BaseNode[X]) – 要映射的源节点。
map_fn (Callable[[X], T]) – 要应用于源节点中每个项目的函数。
- 类 torchdata.nodes 中。MultiNodeWeightedSampler(source_nodes: 映射[str, BaseNode[T]], 权重: Dict[str, float], stop_criteria: str = 'CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED', rank: 可选[int] = 无,world_size:可选[int] = 无,种子:int = 0)¶
-
一个从具有权重的多个数据集中采样的节点。
此节点需要接收源节点的字典和权重的字典。 源节点和权重的 key 必须相同。权重用于采样 从源节点。我们使用 torch.multinomial 从源节点中采样,请 有关如何使用的信息,请参阅 https://pytorch.org/docs/stable/generated/torch.multinomial.html weights 进行采样。seed 用于初始化随机数生成器。
节点使用以下键实现状态: - DATASET_NODE_STATES_KEY:每个源节点的状态字典。 - DATASETS_EXHAUSTED_KEY:一个布尔值字典,指示每个源节点是否已用尽。 - EPOCH_KEY:用于初始化随机数生成器的纪元计数器。 - NUM_YIELDED_KEY:生成的项数。 - WEIGHTED_SAMPLER_STATE_KEY:加权采样器的状态。
我们支持多个停止标准: - CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED:循环访问源节点,直到所有数据集都用完。这是默认行为。 - FIRST_DATASET_EXHAUSTED:当第一个数据集用完时停止。 - ALL_DATASETS_EXHAUSTED:当所有数据集都用完时停止。
在源节点完全耗尽时,该节点将引发 StopIteration。
- 参数
source_nodes (Mapping[str, BaseNode[T]]) – 源节点的字典。
weights (Dict[str, float]) – 每个源节点的权重字典。
stop_criteria (str) – 停止条件。默认值为 CYCLE_UNTIL_ALL_DATASETS_EXHAUST
rank (int) (排名) – 当前进程的排名。默认值为 None,在这种情况下,排名 将从分布式环境中获取。
world_size (int) – 分布式环境的世界大小。默认值为 None,在 在这种情况下,将从分布式环境中获取世界大小。
seed (int) – 随机数生成器的种子。默认值为 0。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。只能由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可能在将来的某个时候传递给 reset()
- next() T ¶
子类必须实现此方法,而不是 .只能由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
__next
- reset(initial_state: 可选[Dict[str, Any]] = 无)¶
将迭代器重置为开头,或重置为 initial_state 传入的状态。
reset 是放置昂贵初始化的好地方,因为当 next() 或 state_dict() 被调用时,它将被延迟调用。 子类必须调用 .
super().reset(initial_state)
- 参数
initial_state – Optional[dict] - 要传递给节点的状态 dict。如果为 None,则重置为开头。
- 类 torchdata.nodes 中。ParallelMapper(来源: BaseNode[X], map_fn: Callable[[X], T], num_workers: int, in_order: bool = True, method: Literal['thread'、 'process'] = 'thread', multiprocessing_context: 可选[str] = 无, max_concurrent: 可选 [int] = 无,snapshot_frequency: int = 1)¶
-
ParallelMapper 在 num_workers 线程中并行执行 map_fn 或 过程。对于进程,multiprocessing_context可以是 spawn、forkserver、fork、 或 None(选择操作系统默认值)。最多将处理 max_concurrent 项 或在迭代器的输出队列中,以限制 CPU 和内存利用率。如果没有 (默认)值为 2 * num_workers。
最多一个 iter() 是从源创建的,最多一个线程将调用 next() 的 intent 值。
如果 in_order 为 true,则迭代器将按项目到达的顺序返回项目 from source 的迭代器,即使其他项目可用,也可能阻止。
- 参数
source (BaseNode[X]) – 要映射的源节点。
map_fn (Callable[[X], T]) – 要应用于源节点中每个项目的函数。
num_workers (int) – 用于并行处理的工作线程数。
in_order (bool) – 是否按项目到达的顺序返回项目。默认值为 True。
method (Literal[“thread”, “process”]) – 用于并行处理的方法。默认值为 “thread”。
multiprocessing_context (Optional[str]) – 用于并行处理的多处理上下文。默认值为 None。
max_concurrent (Optional[int]) – 一次要处理的最大项目数。默认值为 None。
snapshot_frequency (int) – 对源节点的状态进行快照的频率。默认值为 1。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。只能由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可能在将来的某个时候传递给 reset()
- 下一个()¶
子类必须实现此方法,而不是 .只能由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
__next
- reset(initial_state: 可选[Dict[str, Any]] = 无)¶
将迭代器重置为开头,或重置为 initial_state 传入的状态。
reset 是放置昂贵初始化的好地方,因为当 next() 或 state_dict() 被调用时,它将被延迟调用。 子类必须调用 .
super().reset(initial_state)
- 参数
initial_state – Optional[dict] - 要传递给节点的状态 dict。如果为 None,则重置为开头。
- 类 torchdata.nodes 中。PinMemory(来源:BaseNode[T],pin_memory_device:str = '',snapshot_frequency: int = 1)¶
-
将底层节点的数据固定到设备。这由 torch.utils.data._utils.pin_memory._pin_memory_loop 提供支持。
- 参数
source (BaseNode[T]) – 要从中固定数据的源节点。
pin_memory_device (str) – 要将数据固定到的设备。默认值为 “”。
snapshot_frequency (int) – 对源节点的状态进行快照的频率。默认值为 1,这意味着源节点的状态将在每个项目之后进行快照。如果已设置 设置为更高的值,则源节点的状态将在每 snapshot_frequency 后进行快照 项目。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。只能由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可能在将来的某个时候传递给 reset()
- 下一个()¶
子类必须实现此方法,而不是 .只能由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
__next
- reset(initial_state: 可选[Dict[str, Any]] = 无)¶
将迭代器重置为开头,或重置为 initial_state 传入的状态。
reset 是放置昂贵初始化的好地方,因为当 next() 或 state_dict() 被调用时,它将被延迟调用。 子类必须调用 .
super().reset(initial_state)
- 参数
initial_state – Optional[dict] - 要传递给节点的状态 dict。如果为 None,则重置为开头。
- 类 torchdata.nodes 中。预取器(源:BaseNode[T],prefetch_factor:int,snapshot_frequency:int = 1)¶
-
从源节点预取数据并将其存储在队列中。
- 参数
source (BaseNode[T]) – 要从中预取数据的源节点。
prefetch_factor (int) – 要提前预取的项目数。
snapshot_frequency (int) – 对源节点的状态进行快照的频率。默认值为 1,这意味着源节点的状态将在每个项目之后进行快照。如果已设置 设置为更高的值,则源节点的状态将在每 snapshot_frequency 后进行快照 项目。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。只能由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可能在将来的某个时候传递给 reset()
- 下一个()¶
子类必须实现此方法,而不是 .只能由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
__next
- reset(initial_state: 可选[Dict[str, Any]] = 无)¶
将迭代器重置为开头,或重置为 initial_state 传入的状态。
reset 是放置昂贵初始化的好地方,因为当 next() 或 state_dict() 被调用时,它将被延迟调用。 子类必须调用 .
super().reset(initial_state)
- 参数
initial_state – Optional[dict] - 要传递给节点的状态 dict。如果为 None,则重置为开头。
- 类 torchdata.nodes 中。SamplerWrapper(sampler: Sampler[T], initial_epoch: int = 0, epoch_updater: 可选[Callable[[int], int]] = 无)¶
-
将采样器转换为 BaseNode。这与 IterableWrapper 的 Wrapper 中,但它包含一个用于在采样器上调用 set_epoch 的钩子, 如果它支持它。
- 参数
sampler (Sampler) – 要包装的采样器。
initial_epoch (int) – 采样器上设置的初始纪元
epoch_updater (Optional[Callable[[int], int]] = None) – 在新迭代开始时更新 epoch 的回调。它在每个迭代器请求的开头调用,第一个请求除外。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。只能由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可能在将来的某个时候传递给 reset()
- next() T ¶
子类必须实现此方法,而不是 .只能由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
__next
- reset(initial_state: 可选[Dict[str, Any]] = 无)¶
将迭代器重置为开头,或重置为 initial_state 传入的状态。
reset 是放置昂贵初始化的好地方,因为当 next() 或 state_dict() 被调用时,它将被延迟调用。 子类必须调用 .
super().reset(initial_state)
- 参数
initial_state – Optional[dict] - 要传递给节点的状态 dict。如果为 None,则重置为开头。
- 类 torchdata.nodes 中。Stateful(*args, **kwargs)¶
基地:
Protocol
实现 和 的对象协议
state_dict()
load_state_dict(state_dict: Dict[str, Any])
- 类 torchdata.nodes 中。StopCriteria¶
基地:
object
数据集采样器的停止条件。
CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED:在最后一个未看到的数据集用尽后停止。 所有数据集至少可见一次。在某些情况下,某些数据集可能是 当仍有未穷尽的数据集时,会多次看到。
ALL_DATASETS_EXHAUSTED:所有数据集都用完后停止。每 数据集只看到一次。不会执行环绕或重新启动。
FIRST_DATASET_EXHAUSTED:当第一个数据集用完时停止。