torch.utils.data¶
PyTorch 数据加载工具的核心是 torch.utils.data.DataLoader
类。它表示一个对数据集的 Python 可迭代对象,并提供支持
这些选项通过DataLoader的构造函数参数进行配置,其签名如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
以下部分将详细描述这些选项的效果和用法。
数据集类型¶
构造函数中最重要的参数是 DataLoader,
dataset,它表示一个用于加载数据的数据集对象。PyTorch 支持两种不同类型的 dataset:
地图样式数据集¶
一种基于映射的数据集是指实现了 __getitem__() 和
__len__() 协议,并表示从(可能为非整数的)索引/键到数据样本的映射。
例如,通过使用dataset[idx]访问这样的数据集时,可以从磁盘上的文件夹中读取第idx张图像及其对应的标签。
详见 Dataset 以获取更多详细信息。
可迭代风格的数据集¶
可迭代样式数据集是 IterableDataset 的子类实例
它实现了 __iter__() 协议,并表示对数据样本的可迭代对象。这种类型的数据集特别适用于随机读取成本较高甚至不可能的情况,以及批量大小取决于获取的数据的情况。
例如,当这样的数据集被调用为iter(dataset)时,它可以返回从数据库、远程服务器甚至实时生成的日志中读取的数据流。
详见 IterableDataset 以获取更多详细信息。
注意
当使用 IterableDataset 与
多进程数据加载。相同的数据集对象会在每个工作进程中复制一次,因此必须对这些副本进行不同的配置以避免重复数据。请参阅
IterableDataset 文档了解如何实现这一点。
数据加载顺序和 Sampler¶
对于 可迭代式数据集,数据加载顺序完全由用户定义的可迭代对象控制。这使得实现分块读取和动态批量大小(例如,每次生成一个批量样本)变得更加容易。
本节的其余部分涉及
映射风格的数据集。 torch.utils.data.Sampler
类用于指定数据加载中使用的索引/键序列。
它们表示数据集索引上的可迭代对象。例如,在常见的随机梯度下降(SGD)情况下,
Sampler 可以随机打乱一个索引列表,
并逐个提供每个索引,或者为小批量 SGD 提供一小部分索引。
将根据传递给 shuffle 的参数自动构建一个顺序或随机的采样器 DataLoader.
或者,用户可以使用 sampler 参数来指定一个
自定义的 Sampler 对象,该对象每次生成下一个要获取的索引/键。
自定义的 Sampler 可以一次传递一个批次索引列表作为 batch_sampler 参数。
还可以通过 batch_size 和 drop_last 参数启用自动批处理。有关详细信息,请参阅
下一节。
注意
Neither sampler nor batch_sampler is compatible with
iterable-style datasets, since such datasets have no notion of a key or an
index.
加载批量和非批量数据¶
DataLoader 支持通过参数
batch_size, drop_last, batch_sampler, 和
collate_fn(具有默认功能)自动将单个获取的数据样本整理成批次。
自动批处理(默认)¶
这是最常见的情况,对应于获取一个数据小批量(minibatch)并将其整理为批量样本,即包含张量的维度中有一个是批量维度(通常是第一个维度)。
当 batch_size(默认值为 1)不是 None 时,数据加载器将提供批量样本而不是单个样本。使用 batch_size 和 drop_last 参数来指定数据加载器如何获取数据集键的批次。对于映射式数据集,用户也可以选择性地指定 batch_sampler,它每次生成一组键。
注意
batch_size 和 drop_last 参数本质上用于从 sampler 构建一个 batch_sampler。对于映射式数据集,sampler 由用户提供或基于 shuffle 参数构建。对于可迭代式数据集,sampler 是一个虚拟的无限采样器。有关采样器的更多详细信息,请参见 本节。
在使用采样器中的索引获取样本列表后,作为collate_fn参数传递的函数用于将样本列表整理成批次。
在这种情况下,从映射式数据集加载大致等同于:
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
并从可迭代样式数据集加载大致等同于:
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
可以使用自定义的 collate_fn 来自定义排序规则,例如将序列数据填充到批次的最大长度。有关 collate_fn 的更多信息,请参阅
此部分。
禁用自动批处理¶
在某些情况下,用户可能希望在数据集代码中手动处理批处理,或者仅加载单个样本。例如,直接加载批量数据(例如,从数据库进行批量读取或读取连续的内存块)可能更高效,或者批量大小依赖于数据本身,或者程序设计用于处理单个样本。在这些场景下,最好不使用自动批处理(其中collate_fn 用于合并样本),而是让数据加载器直接返回 dataset 对象中的每个成员。
当 batch_size 和 batch_sampler 都为 None(batch_sampler 的默认值已经是 None)时,自动批处理将被禁用。从 dataset 中获取的每个样本都会使用作为 collate_fn 参数传递的函数进行处理。
当自动批处理被禁用时,默认值 collate_fn 会简单地将 NumPy 数组转换为 PyTorch 张量,并保持其他内容不变。
在这种情况下,从映射式数据集加载大致等同于:
for index in sampler:
yield collate_fn(dataset[index])
并从可迭代样式数据集加载大致等同于:
for data in iter(dataset):
yield collate_fn(data)
请参阅此部分以了解有关collate_fn的更多信息。
使用 collate_fn¶
当启用了自动批处理或禁用时,collate_fn 的使用略有不同。
当自动批处理被禁用时,collate_fn 会针对每个单独的数据样本进行调用,并从数据加载器迭代器中生成输出。在这种情况下,默认的 collate_fn 会简单地将 NumPy 数组转换为 PyTorch 张量。
当启用了自动批处理时,collate_fn 每次会以一个数据样本列表作为参数调用。预期该函数将这些输入样本整理成一个批次,并从数据加载器迭代器中返回。本节其余部分描述默认的 collate_fn
(default_collate()) 的行为。
例如,如果每个数据样本由一个3通道图像和一个整数类别标签组成,即数据集的每个元素返回一个元组
(image, class_index),默认的 collate_fn 会将这些元组列表合并为一个包含批量图像张量和批量类别标签张量的元组。特别是,默认的 collate_fn 具有以下属性:
它始终会在前面添加一个新的维度作为批量维度。
它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。
它保留了数据结构,例如,如果每个样本是一个字典,它会输出一个具有相同键集的字典,但值是批量化的张量(Tensor)(如果无法转换为张量,则为列表)。对于
lists,tuples,namedtuples 等也是如此。
用户可以使用自定义的 collate_fn 来实现自定义批处理,例如,
沿着除第一个维度外的其他维度进行拼接,填充不同长度的序列,或添加对自定义数据类型的支持。
如果你遇到一种情况,其中 DataLoader 的输出维度或类型与你的预期不同,你可能需要检查你的 collate_fn。
单进程和多进程数据加载¶
A DataLoader 默认使用单进程数据加载。
在一个Python进程中,
全局解释器锁(GIL)
会阻止Python代码在多线程之间实现真正的并行执行。为了避免数据加载阻塞计算代码,PyTorch 提供了一个简单的切换机制,通过将参数 num_workers 设置为一个正整数来实现多进程数据加载。
单进程数据加载(默认)¶
在此模式下,数据获取在与
DataLoader 初始化的同一进程中完成。因此,数据加载可能会阻塞计算。然而,当用于在进程之间共享数据的资源(例如,共享内存、文件描述符)有限时,或者整个数据集较小且可以完全加载到内存中时,可能更倾向于使用此模式。此外,单进程加载通常显示更易读的错误跟踪信息,因此对于调试很有用。
多进程数据加载¶
将参数 num_workers 设置为正整数将启用多进程数据加载,并使用指定数量的加载器工作进程。
警告
经过几次迭代后,加载器工作进程将消耗与父进程相同数量的CPU内存,这些内存用于父进程中所有被工作进程访问的Python对象。如果数据集包含大量数据(例如,在构建数据集时加载了非常大的文件名列表)和/或使用了很多工作进程(总体内存使用量为number of workers * size of parent process),这可能会出现问题。最简单的解决方法是用非引用计数表示形式(如Pandas、Numpy或PyArrow对象)替换Python对象。有关此问题发生原因及如何解决这些问题的示例代码,请查看
issue #13246.
在此模式下,每次创建一个 DataLoader
的迭代器(例如,当你调用 enumerate(dataloader) 时),会创建 num_workers
个 worker 进程。此时,dataset、
collate_fn 和 worker_init_fn 会被传递给每个
worker,在那里用于初始化和获取数据。这意味着数据集访问及其内部 IO、转换操作
(包括 collate_fn)将在 worker 进程中运行。
torch.utils.data.get_worker_info() 返回工作进程中各种有用的信息(包括工作进程 ID、数据集副本、初始种子等),而在主进程中返回 None。用户可以在数据集代码和/或 worker_init_fn 中使用此函数,为每个数据集副本分别进行配置,并确定代码是否在工作进程中运行。例如,这在对数据集进行分片时特别有帮助。
对于基于映射的数据集,主进程使用
sampler 生成索引并将其发送给工作进程。因此,任何洗牌的随机化都在主进程中完成,通过分配要加载的索引来指导数据加载。
对于可迭代样式的数据集,由于每个工作进程都会获得
dataset 对象的副本,因此普通的多进程加载通常会导致数据重复。使用 torch.utils.data.get_worker_info() 和/或
worker_init_fn,用户可以独立配置每个副本。(有关如何实现这一点,请参阅
IterableDataset 文档。) 出于类似的原因,在多进程加载中,drop_last
参数会丢弃每个工作进程的可迭代样式数据集副本中的最后一个非完整批次。
在迭代结束时,或者当迭代器被垃圾回收时,工作进程将被关闭。
警告
通常不建议在多进程加载中返回CUDA张量,因为使用CUDA和在多进程中共享CUDA张量有许多细微之处(参见多进程中的CUDA)。相反,我们建议使用自动内存固定(即设置pin_memory=True),这可以实现快速的数据传输到CUDA-enabled GPU。
特定平台的行为¶
由于工作进程依赖于 Python multiprocessing,因此在 Windows 上启动工作进程的行为与 Unix 不同。
在 Unix 系统中,
fork()是默认的multiprocessing启动方法。 使用fork(),子工作进程通常可以直接通过克隆的地址空间访问dataset和 Python 参数函数。在 Windows 或 MacOS 上,
spawn()是默认的multiprocessing启动方法。 使用spawn(),会启动另一个解释器来运行你的主脚本, 然后是内部的工作函数,它通过dataset、collate_fn和其他参数的pickle序列化接收它们。
这种独立的序列化意味着在使用多进程数据加载时,为确保与 Windows 兼容,您应采取两个步骤:
将你主脚本中的大部分代码放在
if __name__ == '__main__':块中, 以确保在每个工作进程启动时不会再次运行(很可能会产生错误)。你可以在此处放置数据集和DataLoader实例的创建逻辑,因为它不需要在工作进程中重新执行。请确保任何自定义的
collate_fn,worker_init_fn或dataset代码被声明为顶级定义,位于__main__检查之外。这可以确保它们在工作进程中可用。 (这是必需的,因为函数仅作为引用被序列化,而不是bytecode。)
多进程数据加载中的随机性¶
默认情况下,每个工作进程的 PyTorch 种子将被设置为 base_seed + worker_id,
其中 base_seed 是由主进程使用其随机数生成器(从而强制消耗一个随机数状态)生成的长整数,或者是指定的 generator。然而,在初始化工作进程时,其他库的种子可能会重复,导致每个工作进程返回相同的随机数。(请参阅 FAQ 中的 此部分。)。
在 worker_init_fn 中,您可以使用 torch.utils.data.get_worker_info().seed
或 torch.initial_seed() 访问每个工作器的 PyTorch 随机种子集,并在数据加载之前用它来设置其他库的随机种子。
内存固定¶
从固定(页面锁定)内存中进行主机到 GPU 的复制速度更快。有关何时以及如何一般性地使用固定内存的更多详细信息,请参阅 使用固定内存缓冲区。
对于数据加载,将 pin_memory=True 传递给
DataLoader 将会自动把获取的数据
张量放入固定内存中,从而实现更快地传输到支持 CUDA 的 GPU 上。
默认的内存固定逻辑只识别张量(Tensors)以及包含张量的地图和可迭代对象。如果固定逻辑发现一个批次是自定义类型(如果你有一个collate_fn返回自定义批次类型),或者如果你批次中的每个元素都是自定义类型,那么固定逻辑将无法识别它们,并且会不进行内存固定地返回该批次(或这些元素)。要为自定义批次或数据类型启用内存固定,请在你的自定义类型上定义一个pin_memory()方法。
请参见下面的示例。
Example:
class SimpleCustomBatch:
def __init__(self, data):
transposed_data = list(zip(*data))
self.inp = torch.stack(transposed_data[0], 0)
self.tgt = torch.stack(transposed_data[1], 0)
# custom memory pinning method on custom type
def pin_memory(self):
self.inp = self.inp.pin_memory()
self.tgt = self.tgt.pin_memory()
return self
def collate_wrapper(batch):
return SimpleCustomBatch(batch)
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
pin_memory=True)
for batch_ndx, sample in enumerate(loader):
print(sample.inp.is_pinned())
print(sample.tgt.is_pinned())
- class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')[source]¶
数据加载器将数据集和采样器结合在一起,并为给定的数据集提供一个可迭代对象。
The
DataLoader支持单进程或多进程加载的映射式和可迭代式数据集,可以自定义加载顺序,并具有可选的自动批处理(整理)和内存固定功能。请参阅
torch.utils.data文档页面以获取更多详细信息。- Parameters
数据集 (Dataset) – 用于加载数据的数据集。
batch_size (int, 可选) – 每个批次加载多少样本 (默认:
1)。shuffle (bool, optional) – 设置为
True以在每个训练周期重新洗牌数据 (默认:False)。sampler (Sampler 或 Iterable, 可选) – 定义从数据集中抽取样本的策略。可以是任何实现了
Iterable的__len__。 如果指定了此参数,则必须不能指定shuffle。batch_sampler (Sampler or Iterable, optional) – 类似于
sampler,但 每次返回一批索引。与batch_size,shuffle,sampler, 和drop_last互斥。num_workers (int, optional) – 使用多少个子进程来加载数据。
0表示数据将在主进程中加载。(默认值:0)collate_fn (Callable, 可选) – 合并一个样本列表以形成一个 张量的小批量。当使用来自映射式数据集的批量加载时使用。
pin_memory (bool, optional) – 如果
True,数据加载器将在返回张量之前将其复制到设备/CUDA固定内存中。如果您的数据元素是自定义类型,或者您的collate_fn返回的批次是自定义类型,请参见下面的示例。drop_last (bool, optional) – 设置为
True以丢弃最后一个不完整的批次, 如果数据集大小不能被批次大小整除。如果设置为False并且数据集大小不能被批次大小整除,则最后一个批次会更小。(默认:False)timeout (数字类型, 可选) – 如果为正数,则是从工作进程中收集一个批次数据的超时值。该值应始终为非负数。(默认:
0)worker_init_fn (Callable, optional) – 如果不是
None,则在每个 工作子进程上会使用工作器 ID(一个在[0, num_workers - 1]中的 int)作为 输入,在设置随机种子之后和加载数据之前调用。 (默认值:None)multiprocessing_context (str 或 multiprocessing.context.BaseContext, 可选) – 如果
None, 将使用操作系统默认的 multiprocessing context。 (默认:None)生成器 (torch.Generator, 可选) – 如果不是
None,此 RNG 将被 RandomSampler 用来生成随机索引,并用于多进程生成base_seed给工作人员。 (默认值:None)prefetch_factor (int, optional, keyword-only arg) – 每个 worker 提前加载的批次数量。
2表示所有 worker 总共会预取 2 * num_workers 批次。(默认值取决于 num_workers 的设置值。如果 num_workers=0,默认值是None。否则,如果num_workers > 0的值为默认值,则是2).persistent_workers (bool, 可选) – 如果为
True,数据加载器在数据集被消费一次后不会关闭工作进程。这允许保持工作进程 Dataset 个实例处于活动状态。(默认:False)pin_memory_device (str, optional) – the device to
pin_memoryto ifpin_memoryisTrue.
警告
如果使用
spawn启动方法,worker_init_fn不能是不可序列化的对象,例如 lambda 函数。有关 PyTorch 中多进程的更多详细信息,请参阅 多进程最佳实践。警告
len(dataloader)启发式方法是基于所使用的采样器的长度。 当dataset是一个IterableDataset时, 它会根据len(dataset) / batch_size返回一个估计值,并根据drop_last进行适当的 四舍五入,无论多进程加载配置如何。这代表了 PyTorch 能够做出的最佳猜测,因为 PyTorch 信任用户正确处理多进程加载以避免重复数据的dataset代码。然而,如果分片导致多个工作器拥有不完整的最后一批数据, 这个估计仍然可能不准确,因为 (1) 一个原本完整的批次可以 被拆分成多个批次,以及 (2) 当设置
drop_last时,可能会丢弃超过一个批次的数据。不幸的是,PyTorch 无法在一般情况下检测到这些情况。查看 数据集类型 以了解这两种数据集的更多详细信息,以及
IterableDataset如何与 多进程数据加载 交互。警告
查看 可复现性,以及 我的数据加载器工作进程返回相同的随机数 和 多进程数据加载中的随机性 的说明以了解与随机种子相关的问题。
- class torch.utils.data.Dataset(*args, **kwds)[source]¶
一个表示
Dataset的抽象类。所有将键映射到数据样本的数据集都应继承它。所有子类应覆盖
__getitem__(),以支持根据给定的键获取一个数据样本。子类还可以选择性地覆盖__len__(),许多Sampler实现和DataLoader的默认选项期望该方法返回数据集的大小。子类还可以选择性地实现__getitems__(),以加快批量样本加载的速度。此方法接受一批样本的索引列表,并返回样本列表。注意
DataLoader默认会构建一个索引采样器,该采样器生成整数索引。要使其与使用非整数索引/键的映射式数据集一起工作,必须提供自定义采样器。
- class torch.utils.data.IterableDataset(*args, **kwds)[source]¶
一个可迭代的数据集。
所有表示数据样本可迭代的数据集都应继承它。 当数据来自流时,这种形式的数据集特别有用。
所有子类都应覆盖
__iter__(),该方法将返回一个迭代器,用于遍历此数据集中的样本。当使用子类与
DataLoader时,数据集中的每个项目将从DataLoader迭代器中生成。当num_workers > 0时,每个工作进程将拥有一个不同的数据集对象副本,因此通常需要独立配置每个副本以避免从工作进程中返回重复的数据。get_worker_info()在工作进程中调用时会返回有关该工作进程的信息。它可以在数据集的__iter__()方法或DataLoader的worker_init_fn选项中使用,以修改每个副本的行为。示例 1:在
__iter__()中将工作负载分配到所有工作人员上:>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... worker_info = torch.utils.data.get_worker_info() ... if worker_info is None: # single-process data loading, return the full iterator ... iter_start = self.start ... iter_end = self.end ... else: # in a worker process ... # split workload ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... iter_start = self.start + worker_id * per_worker ... iter_end = min(iter_start + per_worker, self.end) ... return iter(range(iter_start, iter_end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [tensor([3]), tensor([4]), tensor([5]), tensor([6])] >>> # Mult-process loading with two worker processes >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])]
示例 2:使用
worker_init_fn将工作负载分配到所有工作人员上:>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... return iter(range(self.start, self.end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] >>> >>> # Directly doing multi-process loading yields duplicate data >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [3, 3, 4, 4, 5, 5, 6, 6] >>> # Define a `worker_init_fn` that configures each dataset copy differently >>> def worker_init_fn(worker_id): ... worker_info = torch.utils.data.get_worker_info() ... dataset = worker_info.dataset # the dataset copy in this worker process ... overall_start = dataset.start ... overall_end = dataset.end ... # configure the dataset to only process the split workload ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... dataset.start = overall_start + worker_id * per_worker ... dataset.end = min(dataset.start + per_worker, overall_end) ... >>> # Mult-process loading with the custom `worker_init_fn` >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) [3, 5, 4, 6] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn))) [3, 4, 5, 6]
- class torch.utils.data.TensorDataset(*tensors)[source]¶
封装张量的数据集。
每个样本将通过沿着第一个维度对张量进行索引来检索。
- Parameters
*张量 (张量) – 张量的第一个维度大小相同。
- class torch.utils.data.StackDataset(*args, **kwargs)[source]¶
将多个数据集堆叠起来形成的数据集。
此类对于组合作为数据集提供的复杂输入数据的不同部分非常有用。
示例
>>> images = ImageDataset() >>> texts = TextDataset() >>> tuple_stack = StackDataset(images, texts) >>> tuple_stack[0] == (images[0], texts[0]) >>> dict_stack = StackDataset(image=images, text=texts) >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}
- class torch.utils.data.ConcatDataset(datasets)[source]¶
将多个数据集串联起来形成的数据集。
此类有助于整合不同的现有数据集。
- Parameters
数据集 (序列) – 要连接的数据集列表
- class torch.utils.data.ChainDataset(datasets)[source]¶
用于连接多个
IterableDataset的数据集。此类对于组合不同的现有数据集流非常有用。链接操作是在运行时进行的,因此使用此类连接大规模数据集将非常高效。
- Parameters
数据集 (可迭代对象 of IterableDataset) – 要串联在一起的数据集
- class torch.utils.data.Subset(dataset, indices)[source]¶
指定索引处的数据集子集。
- Parameters
dataset (数据集) – 整个数据集
索引 (序列) – 在整个数据集中选择子集的索引
- torch.utils.data._utils.collate.collate(batch, *, collate_fn_map=None)[source]¶
一个通用的 collate 函数,用于处理每个批次中元素的集合类型。
该函数还打开函数注册表以处理特定元素类型。 default_collate_fn_map 为张量、numpy数组、数字和字符串提供默认的整理函数。
- Parameters
示例
>>> def collate_tensor_fn(batch, *, collate_fn_map): >>> # Extend this function to handle batch of tensors ... return torch.stack(batch, 0) >>> def custom_collate(batch): ... collate_map = {torch.Tensor: collate_tensor_fn} ... return collate(batch, collate_fn_map=collate_map) >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map` >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})
注意
每个 collate 函数都需要一个用于 batch 的位置参数和一个用于 collate 函数字典的关键字参数,表示为 collate_fn_map。
- torch.utils.data.default_collate(batch)[source]¶
输入一批数据,并将批次中的元素放入一个具有额外外维(即批量大小)的张量中。
确切的输出类型可以是
torch.Tensor,一个包含 Sequence 的torch.Tensor,一个torch.Tensor的集合,或者保持不变,具体取决于输入类型。 当在DataLoader中定义了 batch_size 或 batch_sampler 时,这将用作整理的默认函数。下面是通用的输入类型(基于批量中元素的类型)到输出类型的映射:
torch.Tensor->torch.Tensor(with an added outer dimension batch size)NumPy Arrays ->
torch.Tensorfloat ->
torch.Tensorint ->
torch.Tensorstr -> str (unchanged)
bytes -> bytes (unchanged)
Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]
NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]
Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]
- Parameters
batch – 要进行拼接的单个批次
示例
>>> # Example with a batch of `int`s: >>> default_collate([0, 1, 2, 3]) tensor([0, 1, 2, 3]) >>> # Example with a batch of `str`s: >>> default_collate(['a', 'b', 'c']) ['a', 'b', 'c'] >>> # Example with `Map` inside the batch: >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]) {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} >>> # Example with `NamedTuple` inside the batch: >>> Point = namedtuple('Point', ['x', 'y']) >>> default_collate([Point(0, 0), Point(1, 1)]) Point(x=tensor([0, 1]), y=tensor([0, 1])) >>> # Example with `Tuple` inside the batch: >>> default_collate([(0, 1), (2, 3)]) [tensor([0, 2]), tensor([1, 3])] >>> # Example with `List` inside the batch: >>> default_collate([[0, 1], [2, 3]]) [tensor([0, 2]), tensor([1, 3])] >>> # Two options to extend `default_collate` to handle specific type >>> # Option 1: Write custom collate function and invoke `default_collate` >>> def custom_collate(batch): ... elem = batch[0] ... if isinstance(elem, CustomType): # Some custom condition ... return ... ... else: # Fall back to `default_collate` ... return default_collate(batch) >>> # Option 2: In-place modify `default_collate_fn_map` >>> def collate_customtype_fn(batch, *, collate_fn_map=None): ... return ... >>> default_collate_fn_map.update(CustoType, collate_customtype_fn) >>> default_collate(batch) # Handle `CustomType` automatically
- torch.utils.data.default_convert(data)[source]¶
将每个NumPy数组元素转换为
torch.Tensor。如果输入是 Sequence, Collection 或 Mapping,它会尝试将每个内部元素转换为
torch.Tensor。 如果输入不是 NumPy 数组,则保持不变。 当在DataLoader中未定义 batch_sampler 和 batch_size 时,此函数用作默认的合并函数。从一般输入类型到输出类型的映射类似于
default_collate()。有关更多详细信息,请参阅那里的描述。- Parameters
data – 要转换的单个数据点
示例
>>> # Example with `int` >>> default_convert(0) 0 >>> # Example with NumPy array >>> default_convert(np.array([0, 1])) tensor([0, 1]) >>> # Example with NamedTuple >>> Point = namedtuple('Point', ['x', 'y']) >>> default_convert(Point(0, 0)) Point(x=0, y=0) >>> default_convert(Point(np.array(0), np.array(0))) Point(x=tensor(0), y=tensor(0)) >>> # Example with List >>> default_convert([np.array([0, 1]), np.array([2, 3])]) [tensor([0, 1]), tensor([2, 3])]
- torch.utils.data.get_worker_info()[source]¶
返回当前
DataLoader迭代器工作进程的信息。当在工作进程中调用时,这将返回一个对象,该对象保证具有以下属性:
id:当前工作进程的ID。num_workers:工作者的总数。seed:为当前工作进程设置的随机种子。该值由主进程的 RNG 和工作进程 ID 决定。有关更多详细信息,请参阅DataLoader的文档。dataset: 在此进程中的数据集对象的副本。请注意,这在其他进程中将是一个不同的对象,而不是主进程中的那个。
当在主进程中调用时,这将返回
None。注意
当在
worker_init_fn传递给DataLoader时,此方法可以用于 为每个工作进程设置不同的配置,例如使用worker_id来配置dataset对象以仅读取分片数据集的特定部分, 或者使用seed来为数据集中使用的其他库设置种子。- Return type
可选[工作信息]
- torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)[source]¶
将数据集随机划分为给定长度的不重叠的新数据集。
如果提供了一组加起来等于 1 的分数列表, 长度将自动计算为 floor(frac * len(dataset)),每个提供的分数都会如此计算。
在计算长度后,如果存在任何余数,将以循环方式将 1 的计数分配给这些长度,直到没有余数为止。
可选地固定生成器以获得可重复的结果,例如:
示例
>>> generator1 = torch.Generator().manual_seed(42) >>> generator2 = torch.Generator().manual_seed(42) >>> random_split(range(10), [3, 7], generator=generator1) >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
- class torch.utils.data.Sampler(data_source=None)[source]¶
所有 Samplers 的基类。
每个Sampler子类都必须提供一个
__iter__()方法,该方法提供了一种迭代数据集元素索引或索引列表(批次)的方式,以及一个__len__()方法, 该方法返回返回的迭代器的长度。- Parameters
data_source (数据集) – 此参数未被使用,并将在 2.2.0 版本中移除。 您可能仍然有自定义实现会利用它。
示例
>>> class AccedingSequenceLengthSampler(Sampler[int]): >>> def __init__(self, data: List[str]) -> None: >>> self.data = data >>> >>> def __len__(self) -> int: >>> return len(self.data) >>> >>> def __iter__(self) -> Iterator[int]: >>> sizes = torch.tensor([len(x) for x in self.data]) >>> yield from torch.argsort(sizes).tolist() >>> >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]): >>> def __init__(self, data: List[str], batch_size: int) -> None: >>> self.data = data >>> self.batch_size = batch_size >>> >>> def __len__(self) -> int: >>> return (len(self.data) + self.batch_size - 1) // self.batch_size >>> >>> def __iter__(self) -> Iterator[List[int]]: >>> sizes = torch.tensor([len(x) for x in self.data]) >>> for batch in torch.chunk(torch.argsort(sizes), len(self)): >>> yield batch.tolist()
注意
The
__len__()method isn’t strictly required byDataLoader, but is expected in any calculation involving the length of aDataLoader.
- class torch.utils.data.SequentialSampler(data_source)[source]¶
按顺序依次采样元素,始终以相同的顺序进行。
- Parameters
data_source (数据集) – 用于采样的数据集
- class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)[source]¶
随机选取元素。如果无放回,则从打乱的数据集中采样。
如果允许替换,则用户可以指定
num_samples来绘制。
- class torch.utils.data.SubsetRandomSampler(indices, generator=None)[source]¶
从给定的索引列表中无放回地随机选取元素。
- Parameters
索引 (序列) – 一个索引序列
生成器 (Generator) – 采样中使用的生成器。
- class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)[source]¶
根据给定的概率(权重)从
[0,..,len(weights)-1]中采样元素。- Parameters
示例
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) [4, 4, 1, 4, 5] >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) [0, 1, 4, 3, 2]
- class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)[source]¶
将另一个采样器包装起来,以生成一个索引的小批量。
- Parameters
示例
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
- class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)[source]¶
限制数据加载到数据集子集的采样器。
它与
torch.nn.parallel.DistributedDataParallel结合使用尤其有用。在这种情况下,每个进程可以传递一个DistributedSampler实例作为DataLoader采样器,并加载仅属于它的原始数据集的子集。注意
假设数据集的大小是固定的,并且其任何实例始终以相同的顺序返回相同的元素。
- Parameters
数据集 – 用于采样的数据集。
num_replicas (int, optional) – 参与分布式训练的进程数量。默认情况下,
world_size会从当前分布式组中获取。rank (int, 可选) – 当前进程在
num_replicas中的等级。 默认情况下,从当前分布式组中获取rank。shuffle (bool, optional) – 如果为
True(默认值),采样器将打乱索引。seed (int, optional) – 如果
shuffle=True,则用于打乱采样器的随机种子。 分布式组中的所有进程应使用相同的数字。默认值:0。drop_last (bool, 可选) – 如果
True,则采样器将丢弃数据的尾部以使其在副本数量上均匀分配。如果False,采样器将添加额外的索引以使数据在副本之间均匀分配。默认值:False。
警告
在分布式模式下,在每个 epoch 开始时调用
set_epoch()方法 之前 创建DataLoader迭代器是必要的,以确保跨多个 epoch 的正确打乱。否则, 将始终使用相同的顺序。Example:
>>> sampler = DistributedSampler(dataset) if is_distributed else None >>> loader = DataLoader(dataset, shuffle=(sampler is None), ... sampler=sampler) >>> for epoch in range(start_epoch, n_epochs): ... if is_distributed: ... sampler.set_epoch(epoch) ... train(loader)