目录

torch.utils.data

PyTorch 数据加载工具的核心是 torch.utils.data.DataLoader 类。它表示一个对数据集的 Python 可迭代对象,并提供支持

这些选项通过DataLoader的构造函数参数进行配置,其签名如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

以下部分将详细描述这些选项的效果和用法。

数据集类型

最重要的参数是 DataLoader 构造函数的 dataset,它表示一个数据集对象以加载数据 从。PyTorch 支持两种不同类型的数据集:

地图样式数据集

一种基于映射的数据集是指实现了 __getitem__()__len__() 协议,并表示从(可能为非整数的)索引/键到数据样本的映射。

例如,通过使用dataset[idx]访问这样的数据集时,可以从磁盘上的文件夹中读取第idx张图像及其对应的标签。

详见 Dataset 以获取更多详细信息。

可迭代风格的数据集

可迭代样式数据集是 IterableDataset 的子类实例 它实现了 __iter__() 协议,并表示对数据样本的可迭代对象。这种类型的数据集特别适用于随机读取成本较高甚至不可能的情况,以及批量大小取决于获取的数据的情况。

例如,当这样的数据集被调用为iter(dataset)时,它可以返回从数据库、远程服务器甚至实时生成的日志中读取的数据流。

详见 IterableDataset 以获取更多详细信息。

注意

当使用 IterableDataset多进程数据加载。每个工作进程都会复制同一个 数据集对象,因此必须对这些副本进行不同的配置以避免数据重复。请参阅 IterableDataset 文档了解如何实现这一点。

数据加载顺序和 Sampler

对于 可迭代式数据集,数据加载顺序完全由用户定义的可迭代对象控制。这使得实现分块读取和动态批量大小(例如,每次生成一个批量样本)变得更加容易。

本节的其余部分涉及 映射风格的数据集torch.utils.data.Sampler 类用于指定数据加载中使用的索引/键序列。 它们表示数据集索引上的可迭代对象。例如,在常见的随机梯度下降(SGD)情况下, Sampler 可以随机打乱一个索引列表, 并逐个提供每个索引,或者为小批量 SGD 提供一小部分索引。

一个顺序或洗牌采样器将根据传递给 shuffle 参数的 DataLoader 自动构建。 或者,用户可以使用 sampler 参数指定一个自定义的 Sampler 对象,该对象在每次调用时生成下一个索引/键以获取数据。

自定义的 Sampler 可以一次传递一个批次索引列表作为 batch_sampler 参数。 还可以通过 batch_sizedrop_last 参数启用自动批处理。有关详细信息,请参阅 下一节

注意

Neither sampler nor batch_sampler is compatible with iterable-style datasets, since such datasets have no notion of a key or an index。

加载批量和非批量数据

DataLoader 支持通过参数 batch_size, drop_last, batch_sampler, 和 collate_fn(具有默认功能)自动将单个获取的数据样本整理成批次。

自动批处理(默认)

这是最常见的情况,对应于获取一个数据小批量(minibatch)并将其整理为批量样本,即包含张量的维度中有一个是批量维度(通常是第一个维度)。

batch_size(默认值为 1)不是 None 时,数据加载器将提供批量样本而不是单个样本。使用 batch_sizedrop_last 参数来指定数据加载器如何获取数据集键的批次。对于映射式数据集,用户也可以选择性地指定 batch_sampler,它每次生成一组键。

注意

batch_sizedrop_last 参数主要用于 从 sampler 构建一个 batch_sampler。对于映射风格的数据集, sampler 由用户提供或根据 shuffle 参数构建。 对于可迭代风格的数据集,sampler 是一个虚拟的无限数据集。 有关采样器的更多详细信息,请参阅 此部分

注意

当从使用 可迭代数据集 并启用 多进程 时,drop_last 参数会丢弃每个工作进程数据集副本的最后一个不完整的批次。

在使用采样器中的索引获取样本列表后,作为collate_fn参数传递的函数用于将样本列表整理成批次。

在这种情况下,从映射式数据集加载大致等同于:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

并从可迭代样式数据集加载大致等同于:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

可以使用自定义的 collate_fn 来自定义排序规则,例如将序列数据填充到批次的最大长度。有关 collate_fn 的更多信息,请参阅 此部分

禁用自动批处理

在某些情况下,用户可能希望在数据集代码中手动处理批处理, 或仅加载单个样本。例如,直接加载批处理数据可能更经济高效(例如,从数据库进行批量读取或读取连续的内存块), 或者批处理大小取决于数据,或者程序设计为处理单个样本。在这些场景下, 可能最好不使用自动批处理(其中使用 collate_fn 来合并样本),而是让数据加载器直接返回 dataset 对象中的每个成员。

batch_sizebatch_sampler 都为 Nonebatch_sampler 的默认值已经是 None)时,自动批处理将被禁用。从 dataset 获取的每个样本都会使用作为 collate_fn 参数传递的函数进行处理。

当自动批处理被禁用时,默认值 collate_fn 会简单地将 NumPy 数组转换为 PyTorch 张量,并保持其他内容不变。

在这种情况下,从映射式数据集加载大致等同于:

for index in sampler:
    yield collate_fn(dataset[index])

并从可迭代样式数据集加载大致等同于:

for data in iter(dataset):
    yield collate_fn(data)

请参阅此部分以了解有关collate_fn的更多信息。

使用 collate_fn

当启用了自动批处理或禁用时,collate_fn 的使用略有不同。

当自动批处理被禁用时collate_fn 会针对每个单独的数据样本进行调用,并从数据加载器迭代器中生成输出。在这种情况下,默认的 collate_fn 会简单地将 NumPy 数组转换为 PyTorch 张量。

当启用了自动批处理时collate_fn 每次会以一个数据样本列表作为参数调用。预期该函数将这些输入样本整理成一个批次,并从数据加载器迭代器中返回。本节其余部分描述默认的 collate_fn (default_collate()) 的行为。

例如,如果每个数据样本由一个3通道图像和一个整数类别标签组成,即数据集的每个元素返回一个元组 (image, class_index),默认的 collate_fn 会将这些元组列表合并为一个包含批量图像张量和批量类别标签张量的元组。特别是,默认的 collate_fn 具有以下属性:

  • 它始终会在前面添加一个新的维度作为批量维度。

  • 它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。

  • 它保留了数据结构,例如,如果每个样本是一个字典,它会输出一个具有相同键集的字典,但值是批量化的张量(Tensor)(如果无法转换为张量,则为列表)。对于 list s, tuple s, namedtuple s 等也是如此。

用户可以使用自定义的 collate_fn 来实现自定义批处理,例如, 沿着除第一个维度外的其他维度进行拼接,填充不同长度的序列,或添加对自定义数据类型的支持。

如果你遇到一种情况,其中 DataLoader 的输出维度或类型与你的预期不同,你可能需要检查你的 collate_fn

单进程和多进程数据加载

A DataLoader 默认使用单进程数据加载。

在一个Python进程中, 全局解释器锁(GIL) 会阻止Python代码在多线程之间实现真正的并行执行。为了避免数据加载阻塞计算代码,PyTorch 提供了一个简单的切换机制,通过将参数 num_workers 设置为一个正整数来实现多进程数据加载。

单进程数据加载(默认)

在此模式下,数据获取在与 DataLoader 初始化的同一进程中完成。因此,数据加载可能会阻塞计算。然而,当用于在进程之间共享数据的资源(例如,共享内存、文件描述符)有限时,或者整个数据集较小且可以完全加载到内存中时,可能更倾向于使用此模式。此外,单进程加载通常显示更易读的错误跟踪信息,因此对于调试很有用。

多进程数据加载

将参数 num_workers 设置为正整数将启用多进程数据加载,并使用指定数量的加载器工作进程。

警告

经过几次迭代后,加载器工作进程将消耗与父进程相同数量的CPU内存,这些内存用于父进程中所有被工作进程访问的Python对象。如果数据集包含大量数据(例如,在构建数据集时加载了非常大的文件名列表)和/或使用了很多工作进程(总体内存使用量为number of workers * size of parent process),这可能会出现问题。最简单的解决方法是用非引用计数表示形式(如Pandas、Numpy或PyArrow对象)替换Python对象。有关此问题发生原因及如何解决这些问题的示例代码,请查看 issue #13246.

在此模式下,每次创建一个 DataLoader 的迭代器(例如,当你调用 enumerate(dataloader) 时),都会创建 num_workers 个工作进程。此时,datasetcollate_fnworker_init_fn 会被传递给每个 工作进程,在那里它们被用来初始化并获取数据。这意味着 数据集访问及其内部 IO、转换 (包括 collate_fn)在工作进程中运行。

torch.utils.data.get_worker_info() 返回工作进程中各种有用的信息(包括工作进程 ID、数据集副本、初始种子等),而在主进程中返回 None。用户可以在数据集代码和/或 worker_init_fn 中使用此函数,为每个数据集副本分别进行配置,并确定代码是否在工作进程中运行。例如,这在对数据集进行分片时特别有帮助。

对于地图样式的数据集,主进程使用 sampler 生成索引并将其发送到工作进程。因此,任何洗牌随机化都在主进程中完成,通过分配要加载的索引来指导加载过程。

对于可迭代风格的数据集,由于每个工作进程都会获得dataset对象的副本,简单的多进程加载通常会导致数据重复。使用torch.utils.data.get_worker_info()和/或worker_init_fn,用户可以独立配置每个副本。(参见IterableDataset文档了解如何实现这一点。)出于类似的原因,在多进程加载中,drop_last参数会丢弃每个工作进程的可迭代风格数据集副本中的最后一个不完整批次。

在迭代结束时,或者当迭代器被垃圾回收时,工作进程将被关闭。

警告

通常不建议在多进程加载中返回CUDA张量,因为使用CUDA和在多进程中共享CUDA张量有许多细微之处(参见多进程中的CUDA)。相反,我们建议使用自动内存固定(即设置pin_memory=True),这可以实现快速的数据传输到CUDA-enabled GPU。

特定平台的行为

由于工作进程依赖于 Python multiprocessing,因此在 Windows 上启动工作进程的行为与 Unix 不同。

  • 在 Unix 系统中,fork() 是默认的 multiprocessing 启动方法。 使用 fork() 时,子工作进程通常可以通过克隆的地址空间直接访问 dataset 和 Python 参数函数。

  • 在Windows或MacOS上,spawn() 是默认的 multiprocessing 启动方法。 使用 spawn(),将启动另一个解释器来运行你的主脚本, 然后是接收 datasetcollate_fn 和其他参数的内部工作函数,这些参数通过 pickle 序列化传递。

这种独立的序列化意味着在使用多进程数据加载时,为确保与 Windows 兼容,您应采取两个步骤:

  • 将你主脚本中的大部分代码放在 if __name__ == '__main__': 块中, 以确保在每个工作进程启动时不会再次运行(很可能会产生错误)。你可以在此处放置数据集和 DataLoader 实例的创建逻辑,因为它不需要在工作进程中重新执行。

  • 确保任何自定义的 collate_fnworker_init_fndataset 代码都声明为顶级定义,位于 __main__ 检查之外。这可以确保它们在工作进程中可用。 (这是必要的,因为函数仅以引用形式进行序列化,而不是 bytecode。)

多进程数据加载中的随机性

默认情况下,每个工作进程的 PyTorch 种子将被设置为 base_seed + worker_id, 其中 base_seed 是由主进程使用其随机数生成器(从而强制消耗一个随机数状态)生成的长整数,或者是指定的 generator。然而,在初始化工作进程时,其他库的种子可能会重复,导致每个工作进程返回相同的随机数。(请参阅 FAQ 中的 此部分。)。

worker_init_fn 中,您可以使用 torch.utils.data.get_worker_info().seedtorch.initial_seed() 访问每个工作器的 PyTorch 随机种子集,并在数据加载之前用它来设置其他库的随机种子。

内存固定

从固定(页面锁定)内存中进行主机到 GPU 的复制速度更快。有关何时以及如何一般性地使用固定内存的更多详细信息,请参阅 使用固定内存缓冲区

对于数据加载,将 pin_memory=True 传递给 DataLoader 将会自动把获取的数据 张量放入固定内存中,从而实现更快地传输到支持 CUDA 的 GPU 上。

默认的内存固定逻辑只识别张量(Tensors)以及包含张量的地图和可迭代对象。如果固定逻辑发现一个批次是自定义类型(如果你有一个collate_fn返回自定义批次类型),或者如果你批次中的每个元素都是自定义类型,那么固定逻辑将无法识别它们,并且会不进行内存固定地返回该批次(或这些元素)。要为自定义批次或数据类型启用内存固定,请在你的自定义类型上定义一个pin_memory()方法。

请参见下面的示例。

Example:

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False, pin_memory_device='')[source]

数据加载器。将数据集和采样器结合起来,并提供对给定数据集的可迭代对象。

The DataLoader 支持单进程或多进程加载的映射式和可迭代式数据集,可以自定义加载顺序,并具有可选的自动批处理(整理)和内存固定功能。

请参阅 torch.utils.data 文档页面以获取更多详细信息。

Parameters
  • 数据集 (Dataset) – 用于加载数据的数据集。

  • batch_size (int, 可选) – 每个批次加载多少样本 (默认: 1)。

  • shuffle (bool, optional) – 设置为 True 以在每个训练周期重新洗牌数据 (默认: False)。

  • sampler (SamplerIterable, 可选) – 定义从数据集中抽取样本的策略。可以是任何实现了 Iterable__len__。 如果指定了此参数,则必须不能指定 shuffle

  • batch_sampler (SamplerIterable, 可选) – 类似于 sampler,但 每次返回一组索引。与 batch_size, shuffle, sampler, 以及 drop_last 互斥。

  • num_workers (int, optional) – 使用多少个子进程来加载数据。 0 表示数据将在主进程中加载。(默认值: 0)

  • collate_fn (Callable, 可选) – 合并一个样本列表以形成一个 张量的小批量。当使用来自映射式数据集的批量加载时使用。

  • pin_memory (bool, optional) – 如果 True,数据加载器将在返回张量之前将其复制到设备/CUDA固定内存中。如果您的数据元素是自定义类型,或者您的 collate_fn 返回的批次是自定义类型,请参见下面的示例。

  • drop_last (bool, optional) – 设置为 True 以丢弃最后一个不完整的批次, 如果数据集大小不能被批次大小整除。如果设置为 False 并且数据集大小不能被批次大小整除,则最后一个批次会更小。(默认: False)

  • timeout (数字类型, 可选) – 如果为正数,则是从工作进程中收集一个批次数据的超时值。该值应始终为非负数。(默认: 0)

  • worker_init_fn (Callable, optional) – 如果不是 None,则在每个 工作子进程上会使用工作器 ID(一个在 [0, num_workers - 1] 中的 int)作为 输入,在设置随机种子之后和加载数据之前调用。 (默认值: None)

  • 生成器 (torch.Generator, 可选) – 如果不为 None,此随机数生成器将被 RandomSampler 用于生成随机索引,并被多进程用于为工作进程生成 base_seed。 (默认值: None)

  • prefetch_factor (int, optional, keyword-only arg) – 每个工作线程提前加载的批次数量。 2 表示所有工作线程总共会提前加载 2 * num_workers 个批次。(默认值: 2)

  • persistent_workers (bool, 可选) – 如果为 True,数据加载器在数据集被消费一次后不会关闭工作进程。这允许保持工作进程 Dataset 个实例处于活动状态。(默认: False

  • pin_memory_device (str, optional) – 如果将 pin_memory 设置为 true,数据加载器会在返回 Tensor 之前将其复制到设备的 pinned 内存中。

警告

如果使用 spawn 启动方法,worker_init_fn 不能是不可序列化的对象,例如 lambda 函数。有关 PyTorch 中多进程的更多详细信息,请参阅 多进程最佳实践

警告

len(dataloader) 的启发式方法基于所使用的采样器的长度。 当 dataset 是一个 IterableDataset 时, 它会返回基于 len(dataset) / batch_size 的估计值,并根据 drop_last 进行适当的四舍五入, 无论多进程加载配置如何。这代表了 PyTorch 能够做出的最佳猜测,因为 PyTorch 信任用户提供的 dataset 代码能够正确处理多进程加载以避免重复数据。

然而,如果分片导致多个工作器拥有不完整的最后一批数据, 这个估计仍然可能不准确,因为 (1) 一个原本完整的批次可以 被拆分成多个批次,以及 (2) 当设置 drop_last 时,可能会丢弃超过一个批次的数据。不幸的是,PyTorch 无法在一般情况下检测到这些情况。

查看 数据集类型 以了解这两种数据集的更多详细信息,以及 IterableDataset 如何与 多进程数据加载 交互。

警告

查看 可复现性,以及 我的数据加载器工作进程返回相同的随机数多进程数据加载中的随机性 的说明以了解与随机种子相关的问题。

class torch.utils.data.Dataset(*args, **kwds)[source]

一个表示 Dataset 的抽象类。

所有表示从键到数据样本映射的数据集都应继承自它。所有子类都应该重写 __getitem__(),以支持根据给定的键获取一个数据样本。子类还可以选择性地重写 __len__(),这被许多 Sampler 实现和 DataLoader 的默认选项所期望返回数据集的大小。

注意

DataLoader 默认会构建一个索引采样器,生成整数索引。要使其与具有非整数索引/键的映射式数据集一起工作,必须提供自定义的采样器。

class torch.utils.data.IterableDataset(*args, **kwds)[source]

一个可迭代的数据集。

所有表示数据样本可迭代的数据集都应继承它。 当数据来自流时,这种形式的数据集特别有用。

所有子类都应覆盖 __iter__(),该方法将返回一个迭代器,用于遍历此数据集中的样本。

当使用子类与 DataLoader 时,数据集中的每个项目将从 DataLoader 迭代器中生成。当 num_workers > 0 时,每个工作进程将拥有一个不同的数据集对象副本,因此通常需要独立配置每个副本以避免从工作进程中返回重复的数据。 get_worker_info() 在工作进程中调用时会返回有关该工作进程的信息。它可以在数据集的 __iter__() 方法或 DataLoaderworker_init_fn 选项中使用,以修改每个副本的行为。

示例 1:在 __iter__() 中将工作负载分配到所有工作人员上:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]

>>> # Mult-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]

示例 2:使用 worker_init_fn 将工作负载分配到所有工作人员上:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
class torch.utils.data.TensorDataset(*tensors)[source]

封装张量的数据集。

每个样本将通过沿着第一个维度对张量进行索引来检索。

Parameters

*张量 (张量) – 张量的第一个维度大小相同。

class torch.utils.data.ConcatDataset(datasets)[source]

将多个数据集串联起来形成的数据集。

此类有助于整合不同的现有数据集。

Parameters

数据集 (序列) – 要连接的数据集列表

class torch.utils.data.ChainDataset(datasets)[source]

用于连接多个 IterableDataset 的数据集。

此类对于组合不同的现有数据集流非常有用。链接操作是在运行时进行的,因此使用此类连接大规模数据集将非常高效。

Parameters

数据集 (可迭代的 IterableDataset 对象集合) – 要串联在一起的数据集

class torch.utils.data.Subset(dataset, indices)[source]

指定索引处的数据集子集。

Parameters
  • dataset (数据集) – 整个数据集

  • 索引 (序列) – 在整个数据集中选择子集的索引

torch.utils.data.default_collate(batch)[source]

Function that takes in a batch of data and puts the elements within the batch into a tensor with an additional outer dimension - batch size. The exact output type can be a torch.Tensor, a Sequence of torch.Tensor, a Collection of torch.Tensor, or left unchanged, depending on the input type. This is used as the default function for collation when batch_size or batch_sampler is defined in DataLoader.

下面是通用的输入类型(基于批量中元素的类型)到输出类型的映射:

  • torch.Tensor -> torch.Tensor (with an added outer dimension batch size)

  • NumPy Arrays -> torch.Tensor

  • float -> torch.Tensor

  • int -> torch.Tensor

  • str -> str (unchanged)

  • bytes -> bytes (unchanged)

  • Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]

  • NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

  • Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

Parameters

batch – 要进行拼接的单个批次

示例

>>> # Example with a batch of `int`s:
>>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3])
>>> # Example with a batch of `str`s:
>>> default_collate(['a', 'b', 'c'])
['a', 'b', 'c']
>>> # Example with `Map` inside the batch:
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
{'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
>>> # Example with `NamedTuple` inside the batch:
>>> Point = namedtuple('Point', ['x', 'y'])
>>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1]))
>>> # Example with `Tuple` inside the batch:
>>> default_collate([(0, 1), (2, 3)])
[tensor([0, 2]), tensor([1, 3])]
>>> # Example with `List` inside the batch:
>>> default_collate([[0, 1], [2, 3]])
[tensor([0, 2]), tensor([1, 3])]
torch.utils.data.default_convert(data)[source]

Function that converts each NumPy array element into a torch.Tensor. If the input is a Sequence, Collection, or Mapping, it tries to convert each element inside to a torch.Tensor. If the input is not an NumPy array, it is left unchanged. This is used as the default function for collation when both batch_sampler and batch_size are NOT defined in DataLoader.

从一般输入类型到输出类型的映射类似于 default_collate()。有关更多详细信息,请参阅那里的描述。

Parameters

data – 要转换的单个数据点

示例

>>> # Example with `int`
>>> default_convert(0)
0
>>> # Example with NumPy array
>>> default_convert(np.array([0, 1]))
tensor([0, 1])
>>> # Example with NamedTuple
>>> Point = namedtuple('Point', ['x', 'y'])
>>> default_convert(Point(0, 0))
Point(x=0, y=0)
>>> default_convert(Point(np.array(0), np.array(0)))
Point(x=tensor(0), y=tensor(0))
>>> # Example with List
>>> default_convert([np.array([0, 1]), np.array([2, 3])])
[tensor([0, 1]), tensor([2, 3])]
torch.utils.data.get_worker_info()[source]

返回当前 DataLoader 迭代器工作进程的信息。

当在工作进程中调用时,这将返回一个对象,该对象保证具有以下属性:

  • id:当前工作进程的ID。

  • num_workers:工作者的总数。

  • seed:为当前工作进程设置的随机种子。该值由主进程的 RNG 和工作进程 ID 决定。有关更多详细信息,请参阅 DataLoader 的文档。

  • dataset: 数据集对象在此进程中的副本。请注意,这与主进程中的对象在不同进程中会是不同的对象。

当在主进程中调用时,这将返回 None

注意

当在 worker_init_fn 传递给 DataLoader 时,此方法可以用于 为每个工作进程设置不同的配置,例如使用 worker_id 来配置 dataset 对象以仅读取分片数据集的特定部分, 或者使用 seed 来为数据集中使用的其他库设置种子。

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)[source]

将数据集随机拆分成给定长度的非重叠新数据集。 可选地,固定生成器以获得可重复的结果,例如:

>>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
Parameters
  • 数据集 (数据集) – 要拆分的数据集

  • lengths (序列) – 要生成的拆分长度

  • 生成器 (Generator) – 用于随机排列的生成器。

class torch.utils.data.Sampler(data_source)[source]

所有 Samplers 的基类。

每个 Sampler 子类都必须提供一个 __iter__() 方法,用于遍历数据集元素的索引,并提供一个 __len__() 方法,用于返回迭代器的长度。

注意

The __len__() method isn’t strictly required by DataLoader, but is expected in any calculation involving the length of a DataLoader.

class torch.utils.data.SequentialSampler(data_source)[source]

按顺序依次采样元素,始终以相同的顺序进行。

Parameters

data_source (数据集) – 用于采样的数据集

class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)[source]

随机抽取样本元素。如果不放回,则从打乱的数据集中进行抽样。 如果放回,则用户可以指定 num_samples 进行抽取。

Parameters
  • data_source (数据集) – 用于采样的数据集

  • replacement (bool) – 如果为 True,则在需要时进行有放回抽样,default=``False``

  • num_samples (int) – 要抽取的样本数量,默认=`len(dataset)`。

  • 生成器 (Generator) – 采样中使用的生成器。

class torch.utils.data.SubsetRandomSampler(indices, generator=None)[source]

从给定的索引列表中无放回地随机选取元素。

Parameters
  • 索引 (序列) – 一个索引序列

  • 生成器 (Generator) – 采样中使用的生成器。

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)[source]

根据给定的概率(权重)从 [0,..,len(weights)-1] 中采样元素。

Parameters
  • weights (sequence) – 权重序列,不一定要加起来等于一

  • num_samples (int) – 要抽取的样本数量

  • replacement (bool) – 如果 True,则有放回地抽取样本。否则,无放回地抽取样本,这意味着当为某一行抽取一个样本索引时,该行不能再抽取相同的索引。

  • 生成器 (Generator) – 采样中使用的生成器。

示例

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)[source]

将另一个采样器包装起来,以生成一个索引的小批量。

Parameters
  • sampler (SamplerIterable) – 基础采样器。可以是任何可迭代对象

  • batch_size (int) – 小批量的大小。

  • drop_last (bool) – 如果为 True,采样器将丢弃最后一个批次,如果其大小小于 batch_size

示例

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)[source]

限制数据加载到数据集子集的采样器。

它与 torch.nn.parallel.DistributedDataParallel 结合使用尤其有用。在这种情况下,每个进程可以传递一个 DistributedSampler 实例作为 DataLoader 采样器,并加载仅属于它的原始数据集的子集。

注意

假设数据集的大小是固定的,并且其任何实例始终以相同的顺序返回相同的元素。

Parameters
  • 数据集 – 用于采样的数据集。

  • num_replicas (int, optional) – 参与分布式训练的进程数量。默认情况下,world_size 会从当前分布式组中获取。

  • rank (int, 可选) – 当前进程在 num_replicas 中的等级。 默认情况下,从当前分布式组中获取 rank

  • shuffle (bool, optional) – 如果为 True(默认值),采样器将打乱索引。

  • seed (int, optional) – 如果 shuffle=True,则用于打乱采样器的随机种子。 分布式组中的所有进程应使用相同的数字。默认值: 0

  • drop_last (bool, 可选) – 如果 True,则采样器将丢弃数据的尾部以使其在副本数量上均匀分配。如果 False,采样器将添加额外的索引以使数据在副本之间均匀分配。默认值: False

警告

在分布式模式下,在每个 epoch 开始时调用 set_epoch() 方法 之前 创建 DataLoader 迭代器是必要的,以确保跨多个 epoch 的正确打乱。否则, 将始终使用相同的顺序。

Example:

>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源