torch.utils.data¶
PyTorch 数据加载实用程序的核心是类。它表示数据集上的 Python 可迭代对象,支持
这些选项由 a 的构造函数参数配置,该参数具有 signature:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
以下各节详细介绍了这些选项的效果和用法。
数据集类型¶
constructor 最重要的参数是
,它表示要加载数据的 dataset 对象
从。PyTorch 支持两种不同类型的数据集:
地图样式数据集¶
地图风格的数据集是实现 and 协议的数据集,并表示来自 (可能是非整数)
indices/keys 添加到数据样本中。__getitem__()
__len__()
例如,这样的数据集在使用 访问时可以读取
磁盘上文件夹中的第 -th 映像及其相应的标签。dataset[idx]
idx
数据加载顺序 和
¶
对于可迭代样式的数据集,数据加载顺序 完全由用户定义的可迭代对象控制。这允许更容易 块读取和动态批量大小的实现(例如,通过生成 批量采样)。
本节的其余部分涉及地图样式数据集的情况。类用于指定数据加载中使用的索引/键序列。
它们表示数据集索引上的可迭代对象。例如,在
随机梯度 Decent (SGD) 的常见情况,a
可以随机排列索引列表
并一次产生每一个,或者产生少量的小批量
新币。
顺序或随机采样器将根据 的参数自动构造。
或者,用户可以使用该
参数来指定
自定义
对象,该对象在每次
要获取的下一个索引/键。
shuffle
生成 batch 列表的自定义
indices 可以作为参数传递。
也可以通过 and 参数启用自动批处理。有关更多详细信息,请参阅下一节
在这个。
batch_sampler
batch_size
drop_last
加载批处理和非批处理数据¶
支持自动分套
单个通过参数 、 、 和 (具有 default 函数) 将数据样本提取到批次中。
batch_size
drop_last
batch_sampler
collate_fn
自动批处理 (默认)¶
这是最常见的情况,对应于获取 data 并将它们整理成批量样本,即包含 一个维度是批次维度(通常是第一个维度)。
当 (default ) 为 not 时,数据加载器会生成
批量样本,而不是单个样本。 和 arguments 用于指定数据加载器如何获取
批量的数据集键。对于地图样式的数据集,用户也可以
specify ,一次生成一个键列表。batch_size
1
None
batch_size
drop_last
batch_sampler
注意
和 参数基本上被使用
构造一个 from .对于地图样式
数据集,则 要么由用户提供,要么由
基于参数。对于可迭代样式的数据集,this
是一个虚拟的无限 1。有关更多详细信息,请参阅此部分
取样。
batch_size
drop_last
batch_sampler
shuffle
使用 sampler 中的 indices 获取样本列表后,函数
作为参数传递的 Passed 用于整理样本列表
分批进行。collate_fn
在这种情况下,从地图样式数据集加载大致相当于:
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
从可迭代样式的数据集加载大致等同于:
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
自定义可用于自定义排序规则,例如填充
Sequential data 设置为批处理的最大长度。请参阅本节,了解有关 .collate_fn
collate_fn
禁用自动批处理¶
在某些情况下,用户可能希望在数据集代码中手动处理批处理。
或者简单地加载单个样品。例如,直接
加载批处理数据(例如,从数据库批量读取或连续读取
块内存),或者批处理大小取决于数据,或者程序是
设计用于处理单个样品。在这些情况下,很可能会
最好不要使用自动批处理(其中 用于
整理样本),但让数据加载器直接返回
对象。
collate_fn
当 和 are (默认
的值已经),自动批处理是
禁用。从 中获得的每个样品都使用
函数作为参数传递。
batch_size
batch_sampler
None
batch_sampler
None
collate_fn
禁用自动批处理时,默认的
将 NumPy 数组转换为 PyTorch 张量,并保持其他所有内容保持不变。collate_fn
在这种情况下,从地图样式数据集加载大致相当于:
for index in sampler:
yield collate_fn(dataset[index])
从可迭代样式的数据集加载大致等同于:
for data in iter(dataset):
yield collate_fn(data)
请参阅本节,了解有关 .collate_fn
使用collate_fn
¶
当自动批处理时,用途略有不同
enabled 或 disabled。collate_fn
禁用自动批处理时,使用
每个单独的数据样本和输出都是从 Data Loader 生成的
迭 代。在这种情况下,默认只是将 NumPy
数组。collate_fn
collate_fn
启用自动批处理后,使用列表调用
的数据样本。它应将输入样本整理到
用于从 Data Loader 迭代器生成 Batch 的 Batch。本节的其余部分
描述默认 () 的行为。
collate_fn
collate_fn
例如,如果每个数据样本都由一个 3 通道图像和一个积分
class 标签,即 dataset 的每个元素都返回一个元组,默认整理一个
此类元组转换为批处理图像张量和批处理类的单个元组
label Tensor 的 Tensor 中。具体而言,默认值如下
性能:(image, class_index)
collate_fn
collate_fn
它始终将新维度作为批处理维度。
它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。
它保留了数据结构,例如,如果每个样本都是一个字典,则它 输出具有相同键集但将 Tensor 作为值的字典 (如果值无法转换为 Tensor,则列出)。相同 用于 s、s、s 等。
list
tuple
namedtuple
用户可以使用 customized 来实现自定义批处理,例如:
沿第一个维度以外的维度进行整理,填充序列
各种长度,或添加对自定义数据类型的支持。collate_fn
单进程和多进程数据加载¶
在 Python 进程中,全局解释器锁 (GIL) 可防止真正跨线程完全并行化 Python 代码。为避免阻塞
计算代码与数据加载一起,PyTorch 提供了一个简单的开关来执行
通过简单地将参数设置为正整数来多进程数据加载。num_workers
单进程数据加载(默认)¶
在此模式下,数据获取在初始化的同一进程中完成。因此,数据加载
可能会阻止计算。但是,当使用资源时,此模式可能是首选
用于在进程之间共享数据(例如,共享内存、文件描述符)是
limited,或者当整个数据集很小并且可以完全加载到
记忆。此外,单进程加载通常显示更具可读性的错误
跟踪,因此可用于调试。
多进程数据加载¶
将参数设置为正整数将
使用指定数量的 Loader 工作线程开启多进程数据加载
过程。num_workers
警告
经过几次迭代后,loader 工作进程将消耗
与所有 Python 的父进程相同的 CPU 内存量
父进程中可从 worker 访问的对象
过程。如果 Dataset 包含大量
data 中(例如,您正在 Dataset 中加载一个非常大的文件名列表
施工时间)和/或您正在使用大量工人(总体
memory usage 为 )。这
最简单的解决方法是将 Python 对象替换为 non-refcounted
表示形式,例如 Pandas、Numpy 或 PyArrow 对象。查看问题 #13246,了解有关发生这种情况的原因以及如何操作的示例代码的更多详细信息
解决方法 这些问题。number of workers * size of parent process
在此模式下,每次创建 a 的迭代器时(例如,当您调用 时),都会创建 worker 进程。此时,将
、 、 和 传递给每个
worker,它们用于初始化和获取数据。这意味着
数据集访问及其内部 IO 一起转换
(包括 ) 在 worker 进程中运行。
enumerate(dataloader)
num_workers
collate_fn
worker_init_fn
collate_fn
返回各种有用的信息
在 worker 进程中(包括 worker id、数据集副本、初始种子、
等),并在主进程中返回。用户可以在
数据集代码和/或单独配置每个
数据集副本,并确定代码是否在 worker 中运行
过程。例如,这在对数据集进行分片时特别有用。
None
worker_init_fn
对于地图样式的数据集,主进程使用 生成索引并将其发送给 worker。所以任何随机化都是
在主进程中完成,该进程通过为 Load 分配索引来指导加载。
对于可迭代样式的数据集,由于每个 worker 进程都会获得对象的副本,因此简单的多进程加载通常会导致
重复数据。使用
and/或 ,用户可以单独配置每个副本。(请参阅
文档了解如何实现
这。) 出于类似的原因,在多进程加载中,该参数会丢弃每个 worker 的可迭代样式数据集的最后一个非完整批次
复制品。
worker_init_fn
drop_last
一旦到达迭代结束,或者当 iterator 变为垃圾回收。
警告
一般不建议在多进程中返回 CUDA 张量
loading 的原因,因为使用 CUDA 和在
multiprocessing (请参阅 multiprocessing 中的 CUDA)。相反,我们建议
使用自动内存固定(即 setting ),从而可以将数据快速传输到启用 CUDA 的
GPU 的 GPU 。pin_memory=True
特定于平台的行为¶
由于 worker 依赖于 Python,因此 worker 启动行为是
Windows 与 Unix 不同。
在 Unix 上,是默认
的启动方法。 使用 ,子工作程序通常可以访问
和 Python 参数直接通过克隆的地址空间执行函数。
fork()
fork()
在 Windows 或 MacOS 上,是默认
的启动方法。 使用 ,将启动另一个解释器,该解释器运行您的主脚本 后跟内部 worker 函数,该函数通过
序列化接收
、 和其他参数。
spawn()
spawn()
collate_fn
这种单独的序列化意味着您应该采取两个步骤来确保 在使用多进程数据加载时与 Windows 兼容:
内存固定¶
主机到 GPU 的副本源自固定(页面锁定)时要快得多 记忆。有关何时以及如何使用的更多详细信息,请参阅使用固定内存缓冲区 固定内存。
对于数据加载,传递给 将自动将获取的数据放入
Tensor 的 Tensor 存储在固定内存中,从而可以更快地将数据传输到支持 CUDA 的
GPU 的 GPU 。
pin_memory=True
默认内存固定逻辑仅识别 Tensor 和 map 以及可迭代对象
包含 Tensor。默认情况下,如果固定逻辑看到一个
自定义类型(如果您的 a 返回
custom batch 类型),或者如果 Batch 的每个元素都是自定义类型,则
pinning logic 将无法识别它们,并且会返回该 batch(或那些
元素),而无需固定内存。为自定义启用内存固定
batch 或数据类型,在自定义
类型。collate_fn
pin_memory()
请参阅下面的示例。
例:
class SimpleCustomBatch:
def __init__(self, data):
transposed_data = list(zip(*data))
self.inp = torch.stack(transposed_data[0], 0)
self.tgt = torch.stack(transposed_data[1], 0)
# custom memory pinning method on custom type
def pin_memory(self):
self.inp = self.inp.pin_memory()
self.tgt = self.tgt.pin_memory()
return self
def collate_wrapper(batch):
return SimpleCustomBatch(batch)
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
pin_memory=True)
for batch_ndx, sample in enumerate(loader):
print(sample.inp.is_pinned())
print(sample.tgt.is_pinned())
- 类 torch.utils.data 中。DataLoader(dataset, batch_size=1, shuffle=无, sampler=无, batch_sampler=无、num_workers=0、collate_fn=无、pin_memory=False、drop_last=False、超时=0, worker_init_fn=无, multiprocessing_context=无, generator=无, *, prefetch_factor=无, persistent_workers=假, pin_memory_device='')[来源]¶
Data loader 将 dataset 和 sampler 组合在一起,并在给定数据集上提供可迭代对象。
支持
map-style 和 具有单进程或多进程加载、自定义的可迭代样式数据集 加载顺序和可选的自动批处理(排序规则)和内存固定。
- 参数
dataset (Dataset) – 从中加载数据的数据集。
batch_size (int, optional) – 每批要加载的样本数 (默认值:)。
1
shuffle (bool, optional) – 设置为重新洗牌数据 在每个 epoch (默认值: )。
True
False
sampler (Sampler 或 Iterable,可选) – 定义要绘制的策略 数据集中的样本。可以是任何已实施的。如果指定,则不得指定。
Iterable
__len__
shuffle
batch_sampler (Sampler 或 Iterable,可选) – 类似于
,但 一次返回一批索引。与 、 、
互斥 和。
batch_size
shuffle
drop_last
num_workers (int, optional) – 用于数据的子进程数 装载。 表示数据将在主进程中加载。 (默认:
0
0
)collate_fn (Callable, optional) – 合并样本列表以形成 小批量的 Tensor 中。当使用 batch loading from 地图样式数据集。
pin_memory (bool, optional) – 如果 ,数据加载器将复制 Tensor 放入 device/CUDA 固定内存中。如果您的数据元素 是自定义类型,或者您返回的批次是自定义类型, 请参阅下面的示例。
True
collate_fn
drop_last (bool, optional) – 设置为 以删除最后一个未完成的批次, 如果数据集大小不能被批量大小整除。If 和 数据集的大小不能被批次大小整除,然后是最后一个批次 会更小。(默认:
True
False
False
)timeout (numeric, optional) – 如果为正数,则为收集批次的超时值 从工人。应始终为非负数。(默认:
0
)worker_init_fn (Callable, optional) – 如果不是 ,则将在每个 worker 子进程,其中 worker id ( int in ) 为 input、seeding 之后和 data loading 之前。(默认:
None
[0, num_workers - 1]
None
)multiprocessing_context (str 或 multiprocessing.context.BaseContext,可选) – 如果 ,则操作系统的默认多处理上下文将 被使用。(默认:
None
None
)发电机 (Torch.生成器,可选) – 如果没有,将使用此 RNG 通过 RandomSampler 生成随机索引,并通过 multiprocessing 为 worker 生成。(默认:
None
base_seed
None
)prefetch_factor (int, optional, keyword-only arg) – 加载的批次数 由每个 worker 提前完成。 表示总共会有 2 * num_workers 个批次,在所有工作程序中预取。(默认值取决于 在 num_workers 的 Set 值上。如果值 num_workers=0,则默认值为 。 否则,如果 default 的值为 )。
2
None
num_workers > 0
2
persistent_workers (bool, optional) – 如果 ,则数据加载器不会关闭 工作程序在 dataset 被使用一次后进行处理。这允许 保持 worker Dataset 实例处于活动状态。(默认:
True
False
)pin_memory_device (str, optional) – 如果设备为 。
pin_memory
pin_memory
True
警告
如果使用 start 方法,则不能是不可封存的对象,例如 lambda 函数。有关更多详细信息,请参阅多处理最佳实践 添加到 PyTorch 中的 multiprocessing 中。
spawn
worker_init_fn
警告
len(dataloader)
启发式 (heuristic) 基于所使用的采样器的长度。 当为
时 , 相反,它返回基于 的估计值,并使用适当的 舍入取决于 ,而不考虑多进程加载 配置。这代表了 PyTorch 可以做出的最佳猜测,因为 PyTorch 信任用户
代码正确处理多进程 loading 以避免重复数据。
len(dataset) / batch_size
drop_last
但是,如果分片导致多个 worker 具有不完整的最后一批,则 此估计仍然可能不准确,因为 (1) 否则完整的批次可能 被分成多个 1 和 (2) 多个批次的样品可以是 set 时丢弃。不幸的是,PyTorch 无法检测到此类 一般情况。
drop_last
请参阅 数据集类型 有关这两种类型的数据集以及如何与多进程数据加载交互的更多详细信息。
警告
有关随机种子相关问题,请参阅可重复性和我的数据加载器工作程序返回相同的随机数和多进程数据加载说明中的随机性。
- 类 torch.utils.data 中。数据集[来源]¶
-
表示从键到数据样本的映射的所有数据集都应子类化 它。所有子类都应该覆盖 ,支持获取 data 样本。子类也可以选择覆盖 ,预计许多实现和默认选项将返回数据集
的大小 的
.子类也可以 (可选)实施 ,用于加速批处理样本 装载。此方法接受 batch 的样本索引列表并返回 样本列表。
__getitem__()
__len__()
__getitems__()
- 类 torch.utils.data 中。IterableDataset[来源]¶
一个可迭代的 Dataset。
表示数据样本可迭代对象的所有数据集都应将其子类化。 当数据来自流时,这种形式的数据集特别有用。
所有子类都应覆盖 ,这将返回一个 此数据集中样本的迭代器。
__iter__()
当子类与
一起使用时,每个 item 将从迭代器中生成
。当 ,每个工作进程都将具有 数据集对象的不同副本,因此通常需要配置 每个副本,以避免从 工人。
,在 worker 中调用 process 返回有关工作程序的信息。它可以用于 dataset 的方法或
的选项来修改每个副本的行为。
num_workers > 0
__iter__()
worker_init_fn
示例 1:在 中的所有工作线程之间拆分工作负载:
__iter__()
>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... worker_info = torch.utils.data.get_worker_info() ... if worker_info is None: # single-process data loading, return the full iterator ... iter_start = self.start ... iter_end = self.end ... else: # in a worker process ... # split workload ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... iter_start = self.start + worker_id * per_worker ... iter_end = min(iter_start + per_worker, self.end) ... return iter(range(iter_start, iter_end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [tensor([3]), tensor([4]), tensor([5]), tensor([6])] >>> # Mult-process loading with two worker processes >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])]
示例 2:使用 :
worker_init_fn
>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... return iter(range(self.start, self.end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] >>> >>> # Directly doing multi-process loading yields duplicate data >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [3, 3, 4, 4, 5, 5, 6, 6] >>> # Define a `worker_init_fn` that configures each dataset copy differently >>> def worker_init_fn(worker_id): ... worker_info = torch.utils.data.get_worker_info() ... dataset = worker_info.dataset # the dataset copy in this worker process ... overall_start = dataset.start ... overall_end = dataset.end ... # configure the dataset to only process the split workload ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... dataset.start = overall_start + worker_id * per_worker ... dataset.end = min(dataset.start + per_worker, overall_end) ... >>> # Mult-process loading with the custom `worker_init_fn` >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) [3, 5, 4, 6] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn))) [3, 4, 5, 6]
- 类 torch.utils.data 中。TensorDataset(*tensors)[来源]¶
数据集包装张量。
每个样本将通过沿第一维索引张量来检索。
- 参数
*tensors (Tensor) – 与第一维大小相同的张量。
- 类 torch.utils.data 中。StackDataset(*args, **kwargs)[来源]¶
Dataset 作为多个数据集的堆叠。
此类可用于组合复杂输入数据的不同部分,以数据集的形式给出。
例
>>> images = ImageDataset() >>> texts = TextDataset() >>> tuple_stack = StackDataset(images, texts) >>> tuple_stack[0] == (images[0], texts[0]) >>> dict_stack = StackDataset(image=images, text=texts) >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}
- 类 torch.utils.data 中。ConcatDataset(数据集)[来源]¶
Dataset 作为多个数据集的串联。
此类可用于组合不同的现有数据集。
- 参数
datasets (sequence) – 要连接的数据集列表
- 类 torch.utils.data 中。ChainDataset(datasets)[来源]¶
-
此类可用于组合不同的现有数据集流。这 链接操作是动态完成的,因此大规模连接 具有此类的数据集将非常高效。
- 参数
datasets (iterableDataset 的 iterable ) – 要链接在一起的数据集
- 类 torch.utils.data 中。子集(数据集,索引)[来源]¶
位于指定索引处的数据集子集。
- 参数
dataset (Dataset) – 整个 Dataset
indices (sequence) – 为子集选择的整个集合中的索引
- torch.utils.data._utils.collate 中。collate(batch, *, collate_fn_map=None)[来源]¶
通用 collate 函数,用于处理每个批次中元素的集合类型。
该函数还会打开函数注册表以处理特定的元素类型。default_collate_fn_map 为张量、numpy 数组、数字和字符串提供默认的排序函数。
- 参数
例子
>>> def collate_tensor_fn(batch, *, collate_fn_map): ... # Extend this function to handle batch of tensors ... return torch.stack(batch, 0) >>> def custom_collate(batch): ... collate_map = {torch.Tensor: collate_tensor_fn} ... return collate(batch, collate_fn_map=collate_map) >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map` >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})
注意
每个 collate 函数都需要一个 batch 的位置参数和一个关键字参数 对于 collate 函数的字典,则为 collate_fn_map。
- torch.utils.data 中。default_collate(批次)[来源]¶
获取一批数据,并将该批次中的元素放入具有附加外部维度 - batch size 的张量中。
确切的输出类型可以是 a
、 a Sequence 、
a 的集合
,或保持不变,具体取决于输入类型。 当 batch_size 或 batch_sampler 在 中
定义时,此函数用作排序规则的默认函数。
以下是到输出类型映射的常规输入类型(基于批处理中元素的类型):
int ->
torch.Tensor
str -> str (不变)
bytes -> 字节(未更改)
映射[K, V_i] -> 映射[K, default_collate([V_1, V_2, ...])]
NamedTuple[V1_i, V2_i, ...]-> NamedTuple[default_collate([V1_1, V1_2, ...]), default_collate([V2_1, V2_2, ...]), ...]
序列[V1_i, V2_i, ...]-> 序列[default_collate([V1_1, V1_2, ...]), default_collate([V2_1, V2_2, ...]), ...]
- 参数
batch – 要整理的单个批次
例子
>>> # Example with a batch of `int`s: >>> default_collate([0, 1, 2, 3]) tensor([0, 1, 2, 3]) >>> # Example with a batch of `str`s: >>> default_collate(['a', 'b', 'c']) ['a', 'b', 'c'] >>> # Example with `Map` inside the batch: >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]) {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} >>> # Example with `NamedTuple` inside the batch: >>> Point = namedtuple('Point', ['x', 'y']) >>> default_collate([Point(0, 0), Point(1, 1)]) Point(x=tensor([0, 1]), y=tensor([0, 1])) >>> # Example with `Tuple` inside the batch: >>> default_collate([(0, 1), (2, 3)]) [tensor([0, 2]), tensor([1, 3])] >>> # Example with `List` inside the batch: >>> default_collate([[0, 1], [2, 3]]) [tensor([0, 2]), tensor([1, 3])] >>> # Two options to extend `default_collate` to handle specific type >>> # Option 1: Write custom collate function and invoke `default_collate` >>> def custom_collate(batch): ... elem = batch[0] ... if isinstance(elem, CustomType): # Some custom condition ... return ... ... else: # Fall back to `default_collate` ... return default_collate(batch) >>> # Option 2: In-place modify `default_collate_fn_map` >>> def collate_customtype_fn(batch, *, collate_fn_map=None): ... return ... >>> default_collate_fn_map.update(CustomType, collate_customtype_fn) >>> default_collate(batch) # Handle `CustomType` automatically
- torch.utils.data 中。default_convert(数据)[来源]¶
-
如果输入是 Sequence、Collection 或 Mapping,它会尝试将里面的每个元素转换为
. 如果输入不是 NumPy 数组,则保持不变。 当 batch_sampler 和 batch_size 均未在 中
定义时,此函数用作排序规则的默认函数。
常规输入类型到输出类型的映射类似于 的
.有关更多详细信息,请参阅那里的描述。
- 参数
data – 要转换的单个数据点
例子
>>> # Example with `int` >>> default_convert(0) 0 >>> # Example with NumPy array >>> default_convert(np.array([0, 1])) tensor([0, 1]) >>> # Example with NamedTuple >>> Point = namedtuple('Point', ['x', 'y']) >>> default_convert(Point(0, 0)) Point(x=0, y=0) >>> default_convert(Point(np.array(0), np.array(0))) Point(x=tensor(0), y=tensor(0)) >>> # Example with List >>> default_convert([np.array([0, 1]), np.array([2, 3])]) [tensor([0, 1]), tensor([2, 3])]
- torch.utils.data 中。get_worker_info()[来源]¶
-
在 worker 中调用时,这将返回一个保证具有 以下属性:
id
:当前 worker ID。num_workers
:worker 总数。seed
:当前工作程序的随机种子集。该值为 由主进程 RNG 和 worker ID 决定。有关更多详细信息,请参阅文档。
在主进程中调用时,这将返回 。
None
注意
当用于传递给 时
,此方法可用于 以不同的方式设置每个 worker 进程,例如,用于将对象配置为仅读取 分片数据集,或用于为 Dataset 中使用的其他库设定种子 法典。
worker_init_fn
worker_id
dataset
seed
- 返回类型
可选[WorkerInfo]
- torch.utils.data 中。random_split(dataset, lengths, generator=<torch._C.Generator object>)[来源]¶
将数据集随机拆分为给定长度的非重叠新数据集。
如果给出了总和为 1 的分数列表,则 长度将自动计算为 floor(frac * len(dataset)) 提供的每个分数。
计算长度后,如果有任何余数,则 1 个计数将为 以循环方式分发到各个长度 直到没有剩余部分为止。
可选择修复生成器以获得可重复的结果,例如:
例
>>> generator1 = torch.Generator().manual_seed(42) >>> generator2 = torch.Generator().manual_seed(42) >>> random_split(range(10), [3, 7], generator=generator1) >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
- 类 torch.utils.data 中。采样器(data_source=无)[来源]¶
所有 Sampler 的基类。
每个 Sampler 子类都必须提供一个方法,提供 迭代数据集元素的索引或索引列表(批次)的方法, 并且可以提供返回返回的迭代器长度的方法。
__iter__()
__len__()
- 参数
data_source (Dataset) – 此参数未使用,将在 2.2.0 中删除。 您可能仍有使用它的自定义实现。
例
>>> class AccedingSequenceLengthSampler(Sampler[int]): >>> def __init__(self, data: List[str]) -> None: >>> self.data = data >>> >>> def __len__(self) -> int: >>> return len(self.data) >>> >>> def __iter__(self) -> Iterator[int]: >>> sizes = torch.tensor([len(x) for x in self.data]) >>> yield from torch.argsort(sizes).tolist() >>> >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]): >>> def __init__(self, data: List[str], batch_size: int) -> None: >>> self.data = data >>> self.batch_size = batch_size >>> >>> def __len__(self) -> int: >>> return (len(self.data) + self.batch_size - 1) // self.batch_size >>> >>> def __iter__(self) -> Iterator[List[int]]: >>> sizes = torch.tensor([len(x) for x in self.data]) >>> for batch in torch.chunk(torch.argsort(sizes), len(self)): >>> yield batch.tolist()
- 类 torch.utils.data 中。SequentialSampler(data_source)[源代码]¶
按顺序对元素进行采样,始终按相同的顺序进行采样。
- 参数
data_source (Dataset) – 要从中采样的数据集
- 类 torch.utils.data 中。RandomSampler(data_source, replacement=False, num_samples=None, generator=None)[来源]¶
随机采样元素。如果没有替换,则从随机数据集中采样。
如果带有 replace,则用户可以指定绘制。
num_samples
- 类 torch.utils.data 中。SubsetRandomSampler(indices, generator=None)[来源]¶
从给定的索引列表中随机采样元素,无需替换。
- 参数
indices (sequence) – 索引序列
generator (Generator) – 采样中使用的生成器。
- 类 torch.utils.data 中。WeightedRandomSampler(权重,num_samples,替换=True,生成器=无)[来源]¶
从中采样具有给定概率 (权重) 的元素。
[0,..,len(weights)-1]
- 参数
例
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) [4, 4, 1, 4, 5] >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) [0, 1, 4, 3, 2]
- 类 torch.utils.data 中。BatchSampler(sampler, batch_size, drop_last)[来源]¶
包装另一个采样器以生成一小批索引。
- 参数
例
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
- 类 torch.utils.data.distributed。DistributedSampler(数据集, num_replicas=无, rank=无, shuffle=True, 种子=0,drop_last=False)[来源]¶
将数据加载到数据集子集的 Sampler。
它与 结合使用
时特别有用。在这种情况下,每个 进程可以将实例作为
sampler 传递,并加载 原始数据集。
DistributedSampler
注意
假定 Dataset 的大小是恒定的,并且它的任何实例始终 以相同的顺序返回相同的元素。
- 参数
dataset (Dataset) – 用于采样的数据集。
num_replicas (int, optional) – 参与的进程数 分布式训练。默认情况下,是从 当前分布式组。
world_size
rank (int, optional) – 当前进程在 中的排名。 默认情况下,从当前分布式 群。
num_replicas
rank
shuffle (bool, optional) – 如果 (默认),采样器将对 指标。
True
seed (int, optional) – 用于随机排序采样器的随机种子,如果 .此数字在所有 分布式组中的进程。违约:。
shuffle=True
0
drop_last (bool, optional) – 如果 ,则采样器将删除 tail 的数据,使其在 副本。如果 ,采样器将添加额外的索引以使 数据可在副本之间均匀整除。违约:。
True
False
False
警告
在分布式模式下,调用 创建 iterator 之前每个 epoch 的开始 对于使 shuffle 在多个 epoch 中正常工作是必要的。否则 将始终使用相同的 Sequences。
set_epoch()
DataLoader
例:
>>> sampler = DistributedSampler(dataset) if is_distributed else None >>> loader = DataLoader(dataset, shuffle=(sampler is None), ... sampler=sampler) >>> for epoch in range(start_epoch, n_epochs): ... if is_distributed: ... sampler.set_epoch(epoch) ... train(loader)