目录

torch.utils.data

PyTorch 数据加载实用程序的核心是类。它表示数据集上的 Python 可迭代对象,支持

这些选项由 a 的构造函数参数配置,该参数具有 signature:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

以下各节详细介绍了这些选项的效果和用法。

数据集类型

constructor 最重要的参数是 ,它表示要加载数据的 dataset 对象 从。PyTorch 支持两种不同类型的数据集:dataset

地图样式数据集

地图风格的数据集是实现 and 协议的数据集,并表示来自 (可能是非整数) indices/keys 添加到数据样本中。__getitem__()__len__()

例如,这样的数据集在使用 访问时可以读取 磁盘上文件夹中的第 -th 映像及其相应的标签。dataset[idx]idx

有关更多详细信息,请参阅

可迭代样式的数据集

可迭代样式的数据集是实现协议的子类的实例,并表示 数据样本。这种类型的数据集特别适用于以下情况 随机读取是昂贵的,甚至是不可能的,并且批处理大小取决于 在获取的数据上。__iter__()

例如,这样的数据集在调用 时可能会返回一个 从数据库、远程服务器甚至生成的日志读取的数据流 实时。iter(dataset)

有关更多详细信息,请参阅

注意

当使用具有多进程数据加载的 时。一样 dataset 对象将复制到每个工作进程上,因此 副本必须以不同的方式配置以避免重复数据。请参阅文档以了解如何 实现这一点。

数据加载顺序 和

对于可迭代样式的数据集,数据加载顺序 完全由用户定义的可迭代对象控制。这允许更容易 块读取和动态批量大小的实现(例如,通过生成 批量采样)。

本节的其余部分涉及地图样式数据集的情况。类用于指定数据加载中使用的索引/键序列。 它们表示数据集索引上的可迭代对象。例如,在 随机梯度 Decent (SGD) 的常见情况,a 可以随机排列索引列表 并一次产生每一个,或者产生少量的小批量 新币。

顺序或随机采样器将根据 的参数自动构造。 或者,用户可以使用该参数来指定 自定义对象,该对象在每次 要获取的下一个索引/键。shufflesampler

生成 batch 列表的自定义 indices 可以作为参数传递。 也可以通过 and 参数启用自动批处理。有关更多详细信息,请参阅下一节 在这个。batch_samplerbatch_sizedrop_last

注意

既不兼容,也不兼容 iterable 样式的数据集,因为此类数据集没有 key 或 指数。samplerbatch_sampler

加载批处理和非批处理数据

支持自动分套 单个通过参数 、 、 和 将数据样本提取到批次中。batch_sizedrop_lastbatch_sampler

自动批处理 (默认)

这是最常见的情况,对应于获取 data 并将它们整理成批量样本,即包含 一个维度是批次维度(通常是第一个维度)。

当 (default ) 为 not 时,数据加载器会生成 批量样本,而不是单个样本。 和 arguments 用于指定数据加载器如何获取 批量的数据集键。对于地图样式的数据集,用户也可以 specify ,一次生成一个键列表。batch_size1Nonebatch_sizedrop_lastbatch_sampler

注意

和 参数基本上被使用 构造一个 from .对于地图样式 数据集,则 要么由用户提供,要么由 基于参数。对于可迭代样式的数据集,this 是一个虚拟的无限 1。有关更多详细信息,请参阅此部分 取样。batch_sizedrop_lastbatch_samplersamplersamplershufflesampler

注意

当使用多重处理可迭代样式的数据集中获取时,该参数会删除每个工作程序的数据集副本的最后一个非完整批次。drop_last

使用 sampler 中的 indices 获取样本列表后,函数 作为参数传递的 Passed 用于整理样本列表 分批进行。collate_fn

在这种情况下,从地图样式数据集加载大致相当于:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

从可迭代样式的数据集加载大致等同于:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

自定义可用于自定义排序规则,例如填充 Sequential data 设置为批处理的最大长度。请参阅本节,了解有关 .collate_fncollate_fn

禁用自动批处理

在某些情况下,用户可能希望在数据集代码中手动处理批处理。 或者简单地加载单个样品。例如,直接 加载批处理数据(例如,从数据库批量读取或连续读取 块内存),或者批处理大小取决于数据,或者程序是 设计用于处理单个样品。在这些情况下,很可能会 最好不要使用自动批处理(其中 用于 整理样本),但让数据加载器直接返回 对象。collate_fndataset

当 和 are (默认 的值已经),自动批处理是 禁用。从 中获得的每个样品都使用 函数作为参数传递。batch_sizebatch_samplerNonebatch_samplerNonedatasetcollate_fn

禁用自动批处理时,默认的 将 NumPy 数组转换为 PyTorch 张量,并保持其他所有内容保持不变。collate_fn

在这种情况下,从地图样式数据集加载大致相当于:

for index in sampler:
    yield collate_fn(dataset[index])

从可迭代样式的数据集加载大致等同于:

for data in iter(dataset):
    yield collate_fn(data)

请参阅本节,了解有关 .collate_fn

使用collate_fn

当自动批处理时,用途略有不同 enabled 或 disabled。collate_fn

禁用自动批处理时,使用 每个单独的数据样本和输出都是从 Data Loader 生成的 迭 代。在这种情况下,默认只是将 NumPy 数组。collate_fncollate_fn

启用自动批处理后,使用列表调用 的数据样本。它应将输入样本整理到 用于从 Data Loader 迭代器生成 Batch 的 Batch。本节的其余部分 描述在这种情况下 default 的行为。collate_fncollate_fn

例如,如果每个数据样本都由一个 3 通道图像和一个积分 class 标签,即 dataset 的每个元素都返回一个元组,默认整理一个 此类元组转换为批处理图像张量和批处理类的单个元组 label Tensor 的 Tensor 中。具体而言,默认值如下 性能:(image, class_index)collate_fncollate_fn

  • 它始终将新维度作为批处理维度。

  • 它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。

  • 它保留了数据结构,例如,如果每个样本都是一个字典,则它 输出具有相同键集但将 Tensor 作为值的字典 (如果值无法转换为 Tensor,则列出)。相同 用于 s、s、s 等。listtuplenamedtuple

用户可以使用 customized 来实现自定义批处理,例如: 沿第一个维度以外的维度进行整理,填充序列 各种长度,或添加对自定义数据类型的支持。collate_fn

单进程和多进程数据加载

A 使用单进程数据加载 违约。

在 Python 进程中,全局解释器锁 (GIL) 可防止真正跨线程完全并行化 Python 代码。为避免阻塞 计算代码与数据加载一起,PyTorch 提供了一个简单的开关来执行 通过简单地将参数设置为正整数来多进程数据加载。num_workers

单进程数据加载(默认)

在此模式下,数据获取在初始化的同一进程中完成。因此,数据加载 可能会阻止计算。但是,当使用资源时,此模式可能是首选 用于在进程之间共享数据(例如,共享内存、文件描述符)是 limited,或者当整个数据集很小并且可以完全加载到 记忆。此外,单进程加载通常显示更具可读性的错误 跟踪,因此可用于调试。

多进程数据加载

将参数设置为正整数将 使用指定数量的 Loader 工作线程开启多进程数据加载 过程。num_workers

警告

经过几次迭代后,loader 工作进程将消耗 与所有 Python 的父进程相同的 CPU 内存量 父进程中可从 worker 访问的对象 过程。如果 Dataset 包含大量 data 中(例如,您正在 Dataset 中加载一个非常大的文件名列表 施工时间)和/或您正在使用大量工人(总体 memory usage 为 )。这 最简单的解决方法是将 Python 对象替换为 non-refcounted 表示形式,例如 Pandas、Numpy 或 PyArrow 对象。查看问题 #13246,了解有关发生这种情况的原因以及如何操作的示例代码的更多详细信息 解决方法 这些问题。number of workers * size of parent process

在此模式下,每次创建 a 的迭代器时(例如,当您调用 时),都会创建 worker 进程。此时,将 、 、 和 传递给每个 worker,它们用于初始化和获取数据。这意味着 数据集访问及其内部 IO 一起转换 (包括 ) 在 worker 进程中运行。enumerate(dataloader)num_workersdatasetcollate_fnworker_init_fncollate_fn

返回各种有用的信息 在 worker 进程中(包括 worker id、数据集副本、初始种子、 等),并在主进程中返回。用户可以在 数据集代码和/或单独配置每个 数据集副本,并确定代码是否在 worker 中运行 过程。例如,这在对数据集进行分片时特别有用。Noneworker_init_fn

对于地图样式的数据集,主进程使用 生成索引并将其发送给 worker。所以任何随机化都是 在主进程中完成,该进程通过为 Load 分配索引来指导加载。sampler

对于可迭代样式的数据集,由于每个 worker 进程都会获得对象的副本,因此简单的多进程加载通常会导致 重复数据。使用 and/或 ,用户可以单独配置每个副本。(请参阅文档了解如何实现 这。) 出于类似的原因,在多进程加载中,该参数会丢弃每个 worker 的可迭代样式数据集的最后一个非完整批次 复制品。datasetworker_init_fndrop_last

一旦到达迭代结束,或者当 iterator 变为垃圾回收。

警告

一般不建议在多进程中返回 CUDA 张量 loading 的原因,因为使用 CUDA 和在 multiprocessing (请参阅 multiprocessing 中的 CUDA)。相反,我们建议 使用自动内存固定(即 setting ),从而可以将数据快速传输到启用 CUDA 的 GPU 的 GPU 。pin_memory=True

特定于平台的行为

由于 worker 依赖于 Python,因此 worker 启动行为是 Windows 与 Unix 不同。

  • 在 Unix 上,是默认的启动方法。 使用 ,子工作程序通常可以访问 和 Python 参数直接通过克隆的地址空间执行函数。fork()fork()dataset

  • 在 Windows 或 MacOS 上,是默认的启动方法。 使用 ,将启动另一个解释器,该解释器运行您的主脚本 后跟内部 worker 函数,该函数通过序列化接收 、 和其他参数。spawn()spawn()datasetcollate_fn

这种单独的序列化意味着您应该采取两个步骤来确保 在使用多进程数据加载时与 Windows 兼容:

  • 将大部分主脚本的代码包装在 block 中, 确保它不会再次运行(很可能生成错误),当每个 worker 进程。您可以将数据集和实例创建逻辑放在此处,因为它不需要在 worker 中重新执行。if __name__ == '__main__':

  • 确保在检查之外将任何 custom 或 code 声明为顶级定义。这可确保它们在工作进程中可用。 (这是必需的,因为函数仅作为引用被腌制,而不是 .)collate_fnworker_init_fndataset__main__bytecode

多进程数据加载的随机性

默认情况下,每个工作程序的 PyTorch 种子都将设置为 , 其中 是主进程使用其 RNG 生成的 long(因此, 强制使用 RNG 状态)或指定的 .然而,其他 库可能会在初始化 worker 时被复制,从而导致每个 worker 返回 相同的随机数。(请参阅常见问题解答中的此部分base_seed + worker_idbase_seedgenerator

在 中,您可以访问每个工作线程的 PyTorch 种子集 替换为 或 并使用它在 data 之前为其他库设定种子 装载。worker_init_fn

内存固定

主机到 GPU 的副本源自固定(页面锁定)时要快得多 记忆。有关何时以及如何使用的更多详细信息,请参阅使用固定内存缓冲区 固定内存。

对于数据加载,传递给 将自动将获取的数据放入 Tensor 的 Tensor 存储在固定内存中,从而可以更快地将数据传输到支持 CUDA 的 GPU 的 GPU 。pin_memory=True

默认内存固定逻辑仅识别 Tensor 和 map 以及可迭代对象 包含 Tensor。默认情况下,如果固定逻辑看到一个 自定义类型(如果您的 a 返回 custom batch 类型),或者如果 Batch 的每个元素都是自定义类型,则 pinning logic 将无法识别它们,并且会返回该 batch(或那些 元素),而无需固定内存。为自定义启用内存固定 batch 或数据类型,在自定义 类型。collate_fnpin_memory()

请参阅下面的示例。

例:

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())
class datasetbatch_size=1shuffle=Falsesampler=Nonebatch_sampler=Nonenum_workers=0collate_fn=pin_memory=Falsedrop_last=False超时=0worker_init_fn=multiprocessing_context=生成器=*prefetch_factor=2persistent_workers=False[来源]torch.utils.data.DataLoader

数据加载器。将数据集和采样器组合在一起,并在 给定的数据集。

支持 map-style 和 具有单进程或多进程加载、自定义的可迭代样式数据集 加载顺序和可选的自动批处理(排序规则)和内存固定。

有关更多详细信息,请参阅文档页面。

参数
  • datasetDataset) – 从中加载数据的数据集。

  • batch_sizeintoptional) – 每批要加载的样本数 (默认值:)。1

  • shufflebooloptional) – 设置为重新洗牌数据 在每个 epoch (默认值: )。TrueFalse

  • samplerSamplerIterable可选) – 定义要绘制的策略 数据集中的样本。可以是任何已实施的。如果指定,则不得指定。Iterable__len__shuffle

  • batch_samplerSamplerIterable可选) – 类似于 ,但 一次返回一批索引。与 、 、 互斥 和。samplerbatch_sizeshufflesamplerdrop_last

  • num_workersintoptional) – 用于数据的子进程数 装载。 表示数据将在主进程中加载。 (默认:00)

  • collate_fn可调用可选) – 合并样本列表以形成 小批量的 Tensor 中。当使用 batch loading from 地图样式数据集。

  • pin_memorybooloptional) – 如果 ,数据加载器将复制 Tensor 放入 CUDA 固定内存中。如果您的数据元素 是自定义类型,或者您返回的批次是自定义类型, 请参阅下面的示例。Truecollate_fn

  • drop_lastbooloptional) – 设置为 以删除最后一个未完成的批次, 如果数据集大小不能被批量大小整除。If 和 数据集的大小不能被批次大小整除,然后是最后一个批次 会更小。(默认:TrueFalseFalse)

  • timeoutnumericoptional) – 如果为正数,则为收集批次的超时值 从工人。应始终为非负数。(默认:0)

  • worker_init_fncallableoptional) – 如果不是,则将在每个 worker 子进程,其中 worker id ( int in ) 为 input、seeding 之后和 data loading 之前。(默认:None[0, num_workers - 1]None)

  • 发电机Torch.生成器可选) – 如果没有,将使用此 RNG 通过 RandomSampler 生成随机索引,通过 multiprocessing 为 worker 生成 base_seed。(默认:NoneNone)

  • prefetch_factorintoptionalkeyword-only arg) – 加载的样本数 由每个 worker 提前完成。 表示总共会有 2 * num_workers 个在所有工作程序中预取的样本。(默认:22)

  • persistent_workersbooloptional) – 如果 ,则数据加载器不会关闭 工作程序在 dataset 被使用一次后进行处理。这允许 保持 worker Dataset 实例处于活动状态。(默认:TrueFalse)

警告

如果使用 start 方法,则不能是不可封存的对象,例如 lambda 函数。有关更多详细信息,请参阅多处理最佳实践 添加到 PyTorch 中的 multiprocessing 中。spawnworker_init_fn

警告

len(dataloader)启发式 (heuristic) 基于所使用的采样器的长度。 当 为 时 , 相反,它返回基于 的估计值,并使用适当的 舍入取决于 ,而不考虑多进程加载 配置。这代表了 PyTorch 可以做出的最佳猜测,因为 PyTorch 信任用户代码正确处理多进程 loading 以避免重复数据。datasetlen(dataset) / batch_sizedrop_lastdataset

但是,如果分片导致多个 worker 具有不完整的最后一批,则 此估计仍然可能不准确,因为 (1) 否则完整的批次可能 被分成多个 1 和 (2) 多个批次的样品可以是 set 时丢弃。不幸的是,PyTorch 无法检测到此类 一般情况。drop_last

请参阅 数据集类型 有关这两种类型的数据集以及如何与多进程数据加载交互的更多详细信息。

class *args**kwds[来源]torch.utils.data.Dataset

表示 .

表示从键到数据样本的映射的所有数据集都应子类化 它。所有子类都应该覆盖 ,支持获取 data 样本。子类也可以选择覆盖 ,预计许多实现和默认选项将返回数据集的大小 的 .__getitem__()__len__()

注意

默认情况下,构造一个索引 sampler 生成整数索引。使其使用地图样式 数据集,则必须提供自定义采样器。

class *args**kwds[来源]torch.utils.data.IterableDataset

一个可迭代的 Dataset。

表示数据样本可迭代对象的所有数据集都应将其子类化。 当数据来自流时,这种形式的数据集特别有用。

所有子类都应覆盖 ,这将返回一个 此数据集中样本的迭代器。__iter__()

当子类与 一起使用时,每个 item 将从迭代器中生成。当 ,每个工作进程都将具有 数据集对象的不同副本,因此通常需要配置 每个副本,以避免从 工人。,在 worker 中调用 process 返回有关工作程序的信息。它可以用于 dataset 的方法或 的选项来修改每个副本的行为。num_workers > 0__iter__()worker_init_fn

示例 1:在 中的所有工作线程之间拆分工作负载:__iter__()

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]

>>> # Mult-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]

示例 2:使用 :worker_init_fn

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
*Tensors[来源]torch.utils.data.TensorDataset

数据集包装张量。

每个样本将通过沿第一维索引张量来检索。

参数

*tensorsTensor) – 与第一维大小相同的张量。

数据集[来源]torch.utils.data.ConcatDataset

Dataset 作为多个数据集的串联。

此类可用于组合不同的现有数据集。

参数

datasetssequence) – 要连接的数据集列表

数据集[来源]torch.utils.data.ChainDataset

用于链接多个 s 的数据集。

此类可用于组合不同的现有数据集流。这 链接操作是动态完成的,因此大规模连接 具有此类的数据集将非常高效。

参数

datasetsiterableDataset 的 iterable ) – 要链接在一起的数据集

class 数据集indices[来源]torch.utils.data.Subset

位于指定索引处的数据集子集。

参数
  • datasetDataset) – 整个 Dataset

  • indicessequence) – 为子集选择的整个集合中的索引

torch.utils.data.get_worker_info()[来源]

返回有关当前迭代器 worker 进程的信息。

在 worker 中调用时,这将返回一个保证具有 以下属性:

  • id:当前 worker ID。

  • num_workers:worker 总数。

  • seed:当前工作程序的随机种子集。该值为 由主进程 RNG 和 worker ID 决定。有关更多详细信息,请参阅 文档。

  • dataset进程中 dataset 对象的副本。注意 这将是不同进程中的不同对象 在主进程中。

在主进程中调用时,这将返回 。None

注意

当用于传递给 时,此方法可用于 以不同的方式设置每个 worker 进程,例如,用于将对象配置为仅读取 分片数据集,或用于为 Dataset 中使用的其他库设定种子 法典。worker_init_fnworker_iddatasetseed

torch.utils.data.random_split(数据集长度生成器=<torch._C.Generator 对象>[来源]

将数据集随机拆分为给定长度的非重叠新数据集。 可选择修复生成器以获得可重复的结果,例如:

>>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
参数
  • datasetDataset) – 要拆分的数据集

  • lengthssequence) – 要生成的分割长度

  • generatorGenerator) – 用于随机排列的生成器。

data_source[来源]torch.utils.data.Sampler

所有 Sampler 的基类。

每个 Sampler 子类都必须提供一个方法,提供 迭代 dataset 元素索引的方法,以及 这将返回返回的迭代器的长度。__iter__()__len__()

注意

该方法不是严格要求的,但在任何 涉及 长度的计算。__len__()

data_source[来源]torch.utils.data.SequentialSampler

按顺序对元素进行采样,始终按相同的顺序进行采样。

参数

data_sourceDataset) – 要从中采样的数据集

class data_sourcereplacement=Falsenum_samples=Nonegenerator=None[来源]torch.utils.data.RandomSampler

随机采样元素。如果没有替换,则从随机数据集中采样。 如果带有 replace,则用户可以指定绘制。num_samples

参数
  • data_sourceDataset) – 要从中采样的数据集

  • replacementbool) – 使用替换按需绘制样本 if , default=''False''True

  • num_samplesint) - 要绘制的样本数,default='len(dataset)'。这个参数 应该仅在 replacement 为 时指定。True

  • generatorGenerator) – 采样中使用的生成器。

class indicesgenerator=None[来源]torch.utils.data.SubsetRandomSampler

从给定的索引列表中随机采样元素,无需替换。

参数
  • indicessequence) – 索引序列

  • generatorGenerator) – 采样中使用的生成器。

class weightsnum_samplesreplacement=Truegenerator=None[来源]torch.utils.data.WeightedRandomSampler

从中采样具有给定概率 (权重) 的元素。[0,..,len(weights)-1]

参数
  • weightssequence) – 权重序列,不必求和为 1

  • num_samplesint) – 要绘制的样本数

  • replacementbool) – 如果 ,则使用 replacement 绘制样本。 否则,它们将被绘制而不进行替换,这意味着当 为一行绘制样本索引,则不能为该行再次绘制该索引。True

  • generatorGenerator) – 采样中使用的生成器。

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
Samplerbatch_sizedrop_last[来源]torch.utils.data.BatchSampler

包装另一个采样器以生成一小批索引。

参数
  • samplerSamplerIterable) – Base sampler。可以是任何可迭代对象

  • batch_sizeint) – 小批量的大小。

  • drop_lastbool) – 如果 ,则采样器将丢弃最后一个批次 它的大小将小于Truebatch_size

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
数据集num_replicas=rank=shuffle=True种子=0drop_last=False[来源]torch.utils.data.distributed.DistributedSampler

将数据加载到数据集子集的 Sampler。

它与 结合使用时特别有用。在这种情况下,每个 进程可以将实例作为 sampler 传递,并加载 原始数据集。DistributedSampler

注意

假定 Dataset 的大小为常量。

参数
  • dataset (数据集) – 用于采样的数据集。

  • num_replicasintoptional) – 参与的进程数 分布式训练。默认情况下,是从 当前分布式组。world_size

  • rankintoptional) – 当前进程在 中的排名。 默认情况下,从当前分布式 群。num_replicasrank

  • shufflebooloptional) – 如果 (默认),采样器将对 指标。True

  • seedintoptional) – 用于随机排序采样器的随机种子,如果 .此数字在所有 分布式组中的进程。违约:。shuffle=True0

  • drop_lastbooloptional) – 如果 ,则采样器将删除 tail 的数据,使其在 副本。如果 ,采样器将添加额外的索引以使 数据可在副本之间均匀整除。违约:。TrueFalseFalse

警告

在分布式模式下,调用 创建 iterator 之前每个 epoch 的开始 对于使 shuffle 在多个 epoch 中正常工作是必要的。否则 将始终使用相同的 Sequences。set_epoch()DataLoader

例:

>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源