
Note

June 2024 Status Update: Removing DataPipes and DataLoader V2

We are re-focusing the torchdata repository to be an iterative enhancement of torch.utils.data.DataLoader. We do not plan on continuing development or maintaining the [DataPipes] and [DataLoaderV2] solutions, and they will be removed from the torchdata repository. We will also revisit the DataPipes references in pytorch/pytorch. In the torchdata==0.8.0 release (July 2024) they will be marked as deprecated, and in 0.9.0 (October 2024) they will be deleted. Existing users are advised to pin to torchdata==0.8.0 or an older version until they are able to migrate away; subsequent releases will not include DataPipes or DataLoaderV2. Please reach out if you have suggestions or comments (please use this issue for feedback).
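For example, pinning can be done directly with pip (a minimal sketch; adapt it to your own environment and package manager):

pip install "torchdata==0.8.0"  # pin until you have migrated off DataPipes/DataLoaderV2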

Stateful DataLoader Tutorial

Saving and Loading State

StatefulDataLoader adds the state_dict and load_state_dict methods to torch.utils.data.DataLoader. Fetching and setting state can be done as follows:

from torchdata.stateful_dataloader import StatefulDataLoader

dataloader = StatefulDataLoader(dataset, num_workers=2)
for i, batch in enumerate(dataloader):
    ...
    if i == 10:
        state_dict = dataloader.state_dict()
        break

# Training run resumes with the previous checkpoint
dataloader = StatefulDataLoader(dataset, num_workers=2)
# Resume state with DataLoader
dataloader.load_state_dict(state_dict)
for i, batch in enumerate(dataloader):
    ...
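The returned state_dict is an ordinary picklable dictionary, so one straightforward way to persist it across process restarts is with torch.save/torch.load. A minimal sketch; the checkpoint path here is hypothetical, and in practice you would save it alongside your model and optimizer state:

import torch

# Hypothetical checkpoint path; bundle the loader state with the rest of the checkpoint.
torch.save({"dataloader": dataloader.state_dict()}, "checkpoint.pt")

# Later, in a fresh process:
checkpoint = torch.load("checkpoint.pt")
dataloader = StatefulDataLoader(dataset, num_workers=2)
dataloader.load_state_dict(checkpoint["dataloader"])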

Saving Custom State with Map-Style Datasets

To efficiently resume a map-style dataset, you can resume iteration by defining state_dict/load_state_dict methods in your sampler. If your dataset has worker-specific state (e.g. RNG transform state), you can add state_dict/load_state_dict methods to your dataset as well.

from typing import *
import torch
import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader

# If you are using the default RandomSampler and BatchSampler in torch.utils.data,
# they are patched when you import torchdata.stateful_dataloader, so defining a
# custom sampler here is unnecessary
class MySampler(torch.utils.data.Sampler[int]):
    def __init__(self, high: int, seed: int, limit: int):
        self.seed, self.high, self.limit = seed, high, limit
        self.g = torch.Generator()
        self.g.manual_seed(self.seed)
        self.i = 0

    def __iter__(self):
        while self.i < self.limit:
            val = int(torch.randint(high=self.high, size=(1,), generator=self.g))
            self.i += 1
            yield val

    def load_state_dict(self, state_dict: Dict[str, Any]):
        self.i = state_dict["i"]
        self.g.set_state(state_dict["rng"])

    def state_dict(self) -> Dict[str, Any]:
        return {"i": self.i, "rng": self.g.get_state()}

# Optional: save dataset random transform state
class NoisyRange(torch.utils.data.Dataset):
    def __init__(self, high: int, mean: float, std: float):
        self.high, self.mean, self.std = high, torch.tensor([float(mean)]), float(std)

    def __len__(self):
        return self.high

    def __getitem__(self, idx: int) -> float:
        if not (0 <= idx < self.high):
            raise IndexError()
        x = torch.normal(self.mean, self.std)
        noise = x.item()
        return idx + noise

    def load_state_dict(self, state_dict):
        torch.set_rng_state(state_dict["rng"])

    def state_dict(self):
        return {"rng": torch.get_rng_state()}

# Test both single/multiprocess dataloading
for num_workers in [0, 2]:
    print(f"{num_workers=}")
    dl = StatefulDataLoader(NoisyRange(5, 1, 1), sampler=MySampler(5, 1, 10),
        batch_size=2, drop_last=False, num_workers=num_workers)

    batches = []
    for i, batch in enumerate(dl):
        batches.append(batch)
        if i == 2:
            sd = dl.state_dict()

    dl.load_state_dict(sd)
    batches2 = list(dl)

    print(batches[3:])
    print(batches2)

"""
Output:
num_workers=0
[tensor([-0.4526,  3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
[tensor([-0.4526,  3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
num_workers=2
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
"""

Saving Custom State with Iterable-Style Datasets

Tracking iteration order with iterable-style datasets requires the state of each worker-level instance of the dataset to be captured. You can define state_dict/load_state_dict methods on your dataset to capture worker-level state. StatefulDataLoader handles aggregating state across workers and distributing it back to the workers. Calling load_state_dict requires the StatefulDataLoader to have the same num_workers as the loader that produced the state_dict.

from typing import *
import torch
import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader


class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, high: int, seed: int):
        self.high, self.seed = high, seed
        self.g = torch.Generator()
        self.i = 0

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
        else:
            worker_id = 0
            num_workers = 1
        self.g.manual_seed(self.seed)
        arr = torch.randperm(self.high, generator=self.g)
        arr = arr[worker_id:self.high:num_workers]
        for idx in range(self.i, len(arr)):
            self.i += 1
            yield arr[idx]
        self.i = 0

    def state_dict(self):
        return {"i": self.i}

    def load_state_dict(self, state_dict):
        self.i = state_dict["i"]

# Test both single/multiprocess dataloading
for num_workers in [0, 2]:
    print(f"{num_workers=}")
    dl = StatefulDataLoader(
        MyIterableDataset(12, 0), batch_size=2, drop_last=False,
        num_workers=num_workers)

    batches = []
    for i, batch in enumerate(dl):
        batches.append(batch)
        if i == 2:
            sd = dl.state_dict()

    dl.load_state_dict(sd)
    batches2 = list(dl)

    print(batches[3:])
    print(batches2)

"""
Output:
num_workers=0
[tensor([ 2, 10]), tensor([3, 1]), tensor([11,  6])]
[tensor([ 2, 10]), tensor([3, 1]), tensor([11,  6])]
num_workers=2
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
"""
