注意力
2024年6月更新:移除DataPipes和DataLoader V2
我们正在重新聚焦torchdata仓库,使其成为torch.utils.data.DataLoader的迭代增强。我们不计划继续开发或维护[DataPipes]和[DataLoaderV2]解决方案,并且它们将从torchdata仓库中移除。我们还将重新审视pytorch/pytorch中的DataPipes引用。在发布torchdata==0.8.0(2024年7月)版本中,它们将被标记为已弃用,并在0.10.0(2024年末)版本中被删除。现有用户建议固定到torchdata<=0.9.0或更早版本,直到他们能够迁移为止。后续版本将不再包含DataPipes或DataLoaderV2。如果您有任何建议或意见,请通过此问题反馈。
torchdata.nodes (beta)¶
- class torchdata.nodes.BaseNode(*args, **kwargs)¶
Bases:
Iterator[T]基础节点是创建可组合数据加载DAG的基础类在
torchdata.nodes。大多数最终用户不会直接迭代 BaseNode 实例,而是将其包装在一个
torchdata.nodes.Loader中,该实例将 DAG 转换为更熟悉的可迭代对象。node = MyBaseNodeImpl() loader = Loader(node) # loader supports state_dict() and load_state_dict() for epoch in range(5): for idx, batch in enumerate(loader): ... # or if using node directly: node = MyBaseNodeImpl() for epoch in range(5): node.reset() for idx, batch in enumerate(loader): ...
- get_state() Dict[str, Any]¶
子类必须实现这个方法,而不是 state_dict()。只应在 BaseNode 中调用。 :return: Dict[str, Any] - 一个可能在将来某个时间点通过 reset() 方法传递的状态字典
- next() T¶
子类必须实现这个方法,而不是
__next。只应由基节点调用。 :return: T - 序列中的下一个值,或者抛出StopIteration
- reset(initial_state: Optional[dict] = None)¶
将迭代器重置到开始位置,或者重置为初始状态传递的值。
重置是一个放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时懒惰地被调用。 子类必须调用
super().reset(initial_state)。- Parameters:
初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。
- state_dict() Dict[str, Any]¶
获取这个基节点的状态字典。 :return: Dict[str, Any] - 一个状态字典,可能在将来某个时间点通过reset()方法传递。
- class torchdata.nodes.Batcher(source: BaseNode[T], batch_size: int, drop_last: bool = True)¶
Bases:
BaseNode[List[T]]批次节点将从源节点的数据批量为 batch_size 大小的批次。 如果源节点耗尽,它将返回批次或抛出 StopIteration 异常。 如果 drop_last 为 True,则在批次大小小于 batch_size 的情况下丢弃最后一个批次。 如果 drop_last 为 False,则即使批次大小小于 batch_size,也会返回最后一个批次。
- Parameters:
源 (BaseNode[T]) – 用于从其批量获取数据的源节点。
批量大小 (整数) – 批次的大小。
drop_last (bool) – 是否在批次大小小于 batch_size 时丢弃最后一个批次。默认值为 True。
- get_state() Dict[str, Any]¶
子类必须实现这个方法,而不是 state_dict()。只应在 BaseNode 中调用。 :return: Dict[str, Any] - 一个可能在将来某个时间点通过 reset() 方法传递的状态字典
- next() List[T]¶
子类必须实现这个方法,而不是
__next。只应由基节点调用。 :return: T - 序列中的下一个值,或者抛出StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开始位置,或者重置为初始状态传递的值。
重置是一个放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时懒惰地被调用。 子类必须调用
super().reset(initial_state)。- Parameters:
初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。
- class torchdata.nodes.IterableWrapper(iterable: Iterable[T])¶
Bases:
BaseNode[T]一个薄的包装器,将任何可迭代(包括torch.utils.data.IterableDataset)转换为BaseNode。
如果可迭代对象实现了 Stateful 协议,它将通过其 state_dict/load_state_dict 方法进行保存和恢复。
- Parameters:
可迭代 (可迭代[T]) – 可迭代转换为 BaseNode。IterableWrapper 调用 iter() 在它上面。
- Warning:
注意区分在 Iterable 上定义的 state_dict/load_state_dict 与 Iterator。 只有 Iterable 的 state_dict/load_state_dict 被使用。
- get_state() Dict[str, Any]¶
子类必须实现这个方法,而不是 state_dict()。只应在 BaseNode 中调用。 :return: Dict[str, Any] - 一个可能在将来某个时间点通过 reset() 方法传递的状态字典
- next() T¶
子类必须实现这个方法,而不是
__next。只应由基节点调用。 :return: T - 序列中的下一个值,或者抛出StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开始位置,或者重置为初始状态传递的值。
重置是一个放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时懒惰地被调用。 子类必须调用
super().reset(initial_state)。- Parameters:
初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。
- class torchdata.nodes.Loader(root: BaseNode[T], restart_on_stop_iteration: bool = True)¶
Bases:
Generic[T]将根节点(迭代器)包裹起来,并提供一个状态化的可迭代接口。
返回迭代器的最后状态由state_dict()方法返回,并且可以使用load_state_dict()方法加载。
- Parameters:
根 (BaseNode[T]) – 数据管道的根节点。
重启在停止迭代时 (布尔值) – 是否在到达末尾时重新启动迭代器。默认值为True
- load_state_dict(state_dict: Dict[str, Any])¶
加载一个state_dict,它将用于初始化从这个加载器请求的下一个iter()。
- Parameters:
state_dict (Dict[str, Any]) – 从 state_dict() 调用生成的状态字典。
- state_dict() Dict[str, Any]¶
返回一个状态字典,可以在将来通过传递给 load_state_dict() 来恢复迭代。
状态字典将来自最近一次调用 iter() 返回的迭代器。 如果尚未创建迭代器,将创建一个新的迭代器并返回其状态字典。
- torchdata.nodes.MapStyleWrapper(map_dataset: Mapping[K, T], sampler: Sampler[K]) BaseNode[T]¶
薄的包装器,将任何 MapDataset 转换为 torchdata.node 如果你想实现并行性,请复制这个,并将 Mapper 替换为 ParallelMapper。
- Parameters:
map_dataset (Mapping[K, T]) –
将 map_dataset.__getitem__ 应用于采样器的输出。
采样器 (采样器[K]) –
- torchdata.nodes.Mapper(source: BaseNode[X], map_fn: Callable[[X], T]) ParallelMapper[T]¶
返回一个
ParallelMapper节点,其中 num_workers=0,将在当前进程/线程中执行 map_fn。- Parameters:
源 (BaseNode[X]) – 要映射的源节点。
map_fn (Callable[[X], T]) – 应用到每个源节点项上的函数。
- class torchdata.nodes.MultiNodeWeightedSampler(source_nodes: Mapping[str, BaseNode[T]], weights: Dict[str, float], stop_criteria: str = 'CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED', rank: Optional[int] = None, world_size: Optional[int] = None, seed: int = 0)¶
Bases:
BaseNode[T]一个从多个数据集采样并带有权重的节点。
此节点期望接收一个源节点的字典和一个权重的字典。 源节点和权重的键必须相同。权重用于从源节点中采样。 我们使用 torch.multinomial 从源节点中采样,请参阅 https://pytorch.org/docs/stable/generated/torch.multinomial.html 以了解如何使用权重进行采样。 seed 用于初始化随机数生成器。
该节点使用以下键来实现状态: - DATASET_NODE_STATES_KEY: 每个源节点的状态字典。 - DATASETS_EXHAUSTED_KEY: 每个源节点是否耗尽的布尔字典。 - EPOCH_KEY: 用于初始化随机数生成器的epoch计数器。 - NUM_YIELDED_KEY: 输出的项目数量。 - WEIGHTED_SAMPLER_STATE_KEY: 加权采样器的状态。
我们支持多种停止条件: - CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED:循环遍历源节点,直到所有数据集耗尽。这是默认行为。 - FIRST_DATASET_EXHAUSTED:当第一个数据集耗尽时停止。 - ALL_DATASETS_EXHAUSTED:当所有数据集耗尽时停止。
当源节点完全耗尽时,该节点将引发 StopIteration 异常。
- Parameters:
source_nodes (Mapping[str, BaseNode[T]]) – 一个源节点的字典。
权重 (Dict[字符串, 浮点数]) – 一个包含每个源节点权重的字典。
停止标准 (字符串) – 停止标准。默认值为CYCLE_UNTIL_ALL_DATASETS_EXHAUST
排名 (整数) – 当前进程的排名。默认值为None,在这种情况下,将从分布式环境中获取排名。
world_size (int) – 世界环境的分布式大小。默认值为None,在这种情况下,将从分布式环境中获取世界大小。
种子 (整数) – 随机数生成器的种子。默认值为0。
- get_state() Dict[str, Any]¶
子类必须实现这个方法,而不是 state_dict()。只应在 BaseNode 中调用。 :return: Dict[str, Any] - 一个可能在将来某个时间点通过 reset() 方法传递的状态字典
- next() T¶
子类必须实现这个方法,而不是
__next。只应由基节点调用。 :return: T - 序列中的下一个值,或者抛出StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开始位置,或者重置为初始状态传递的值。
重置是一个放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时懒惰地被调用。 子类必须调用
super().reset(initial_state)。- Parameters:
初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。
- class torchdata.nodes.ParallelMapper(source: BaseNode[X], map_fn: Callable[[X], T], num_workers: int, in_order: bool = True, method: Literal['thread', 'process'] = 'thread', multiprocessing_context: Optional[str] = None, max_concurrent: Optional[int] = None, snapshot_frequency: int = 1)¶
Bases:
BaseNode[T]ParallelMapper 并行执行 map_fn,要么在 num_workers 个线程中运行,要么在进程中运行。对于进程,multiprocessing_context 可以是 spawn、forkserver、fork 或 None(选择操作系统默认值)。最多 max_concurrent 个任务将被处理或存放在迭代器的输出队列中,以限制 CPU 和内存使用量。如果为 None(默认),则该值将设置为 2 * num_workers。
最多只有一个 iter() 从源创建,且最多只有一个线程同时调用 next()。
如果在_order 是 true,迭代器将按它们到达源迭代器的顺序返回项目,即使其他项目可用也可能阻塞。
- Parameters:
源 (BaseNode[X]) – 要映射的源节点。
map_fn (Callable[[X], T]) – 应用到每个源节点项上的函数。
num_workers (int) – 使用进行并行处理的工人数量。
按顺序 (bool) – 是否返回从到达顺序中获取的项目。默认值为True。
方法 (Literal["线程", "进程"]) – 使用进行并行处理的方法。默认是“线程”。
multiprocessing_context (Optional[str]) – 使用进行并行处理的多进程上下文。默认值为None。
最大并发 (可选[整数]) – 一次处理的最大项目数量。默认值为None。
快照频率 (整数) – 源节点状态的快照频率。默认值为1。
- get_state() Dict[str, Any]¶
子类必须实现这个方法,而不是 state_dict()。只应在 BaseNode 中调用。 :return: Dict[str, Any] - 一个可能在将来某个时间点通过 reset() 方法传递的状态字典
- next()¶
子类必须实现这个方法,而不是
__next。只应由基节点调用。 :return: T - 序列中的下一个值,或者抛出StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开始位置,或者重置为初始状态传递的值。
重置是一个放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时懒惰地被调用。 子类必须调用
super().reset(initial_state)。- Parameters:
初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。
- class torchdata.nodes.PinMemory(source: BaseNode[T], pin_memory_device: str = '', snapshot_frequency: int = 1)¶
Bases:
BaseNode[T]将底层节点的数据绑定到设备上。这由 torch.utils.data._utils.pin_memory._pin_memory_loop 支持。
- Parameters:
源 (BaseNode[T]) – 要固定数据的源节点。
pin_memory_device (str) – 将数据绑定到的设备。默认值为空。
快照频率 (整数) – 源节点状态的快照频率。默认值为 1,这意味着源节点的状态将在每个项目后进行快照。如果设置为更高的值,源节点的状态将每 snapshot_frequency 个项目进行快照。
- get_state() Dict[str, Any]¶
子类必须实现这个方法,而不是 state_dict()。只应在 BaseNode 中调用。 :return: Dict[str, Any] - 一个可能在将来某个时间点通过 reset() 方法传递的状态字典
- next()¶
子类必须实现这个方法,而不是
__next。只应由基节点调用。 :return: T - 序列中的下一个值,或者抛出StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开始位置,或者重置为初始状态传递的值。
重置是一个放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时懒惰地被调用。 子类必须调用
super().reset(initial_state)。- Parameters:
初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。
- class torchdata.nodes.Prefetcher(source: BaseNode[T], prefetch_factor: int, snapshot_frequency: int = 1)¶
Bases:
BaseNode[T]从源节点预取数据并将其存储在队列中。
- Parameters:
源 (BaseNode[T]) – 用于预取数据的源节点。
prefetch_factor (int) – 预先缓存的项目数量。
快照频率 (整数) – 源节点状态的快照频率。默认值为 1,这意味着源节点的状态将在每个项目后进行快照。如果设置为更高的值,源节点的状态将每 snapshot_frequency 个项目进行快照。
- get_state() Dict[str, Any]¶
子类必须实现这个方法,而不是 state_dict()。只应在 BaseNode 中调用。 :return: Dict[str, Any] - 一个可能在将来某个时间点通过 reset() 方法传递的状态字典
- next()¶
子类必须实现这个方法,而不是
__next。只应由基节点调用。 :return: T - 序列中的下一个值,或者抛出StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开始位置,或者重置为初始状态传递的值。
重置是一个放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时懒惰地被调用。 子类必须调用
super().reset(initial_state)。- Parameters:
初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。
- class torchdata.nodes.SamplerWrapper(sampler: Sampler[T], initial_epoch: int = 0, epoch_updater: Optional[Callable[[int], int]] = None)¶
Bases:
BaseNode[T]将采样器转换为BaseNode。这几乎与IterableWrapper相同,除了它包括一个钩子,在支持的情况下调用set_epoch方法。
- Parameters:
抽样器 (Sampler) – 用于包装的抽样器。
初始轮次 (整数) – 在采样器上设置的初始轮次
epoch_updater (Optional[Callable[[int], int]] = None) – callback to update epoch at start of new iteration. It’s called at the beginning of each iterator request, except the first one.
- get_state() Dict[str, Any]¶
子类必须实现这个方法,而不是 state_dict()。只应在 BaseNode 中调用。 :return: Dict[str, Any] - 一个可能在将来某个时间点通过 reset() 方法传递的状态字典
- next() T¶
子类必须实现这个方法,而不是
__next。只应由基节点调用。 :return: T - 序列中的下一个值,或者抛出StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开始位置,或者重置为初始状态传递的值。
重置是一个放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时懒惰地被调用。 子类必须调用
super().reset(initial_state)。- Parameters:
初始状态 – Optional[字典] – 一个传递给节点的状态字典。如果为None,则重置为开始。
- class torchdata.nodes.Stateful(*args, **kwargs)¶
Bases:
Protocol协议对象,既支持
state_dict()又支持load_state_dict(state_dict: Dict[str, Any])
- class torchdata.nodes.StopCriteria¶
Bases:
object数据采样器的停止条件。
直到所有数据集耗尽:在最后一个未见过的数据集耗尽时停止。 所有数据集至少被看到一次。在某些情况下,当还有未耗尽的数据集时,某些数据集可能会被看到多次。
所有数据集耗尽:当所有数据集都耗尽时停止。每个数据集仅被看到一次。不会进行循环或重新开始。
当第一个数据集耗尽时停止。