注意
转到末尾下载完整的示例代码。
开始使用数据收集和存储¶
作者: Vincent Moens
注意
要在笔记本中运行本教程,请添加安装单元 开头包含:
!pip install tensordict !pip install torchrl
没有数据就没有学习。在监督式学习中,用户是
习惯使用等
将数据集成到他们的训练循环中。
Dataloader 是可迭代对象,可为您提供所需的数据
用于训练您的模型。
TorchRL 以类似的方式处理数据加载问题,尽管
它在 RL 库的生态系统中出奇地独特。TorchRL 的
Dataloader 称为 。大多数时候,
数据收集并不止于原始数据的收集,
因为数据需要临时存储在缓冲区中
(或策略 SOTA 实现的等效结构)
由 loss 模块。本教程将探讨
这两个类。DataCollectors
数据收集器¶
这里讨论的主要数据收集器是 ,这是本文的重点
文档。从根本上讲,收集器是一个简单的
类负责在环境中执行策略,
必要时重置环境,并提供批量的
predefined size 的 Predefined size。与
方法
在 env 教程中演示,收集器不会
在连续的数据批次之间重置。因此,两个连续的
批量数据可能包含来自同一轨迹的元素。
您需要传递给收集器的基本参数是
batchs()、长度(可能
infinite) 的 URLERATOR、策略和环境。为简单起见,
在此示例中,我们将使用虚拟的 random 策略。frames_per_batch
import torch
torch.manual_seed(0)
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.envs.utils import RandomPolicy
env = GymEnv("CartPole-v1")
env.set_seed(0)
policy = RandomPolicy(env.action_spec)
collector = SyncDataCollector(env, policy, frames_per_batch=200, total_frames=-1)
/pytorch/rl/torchrl/envs/common.py:2989: DeprecationWarning: Your wrapper was not given a device. Currently, this value will default to 'cpu'. From v0.5 it will default to `None`. With a device of None, no device casting is performed and the resulting tensordicts are deviceless. Please set your device accordingly.
warnings.warn(
我们现在预计我们的收集器将交付大小为 no 的批次
重要的是收集过程中发生的情况。换句话说,我们可能有多个
这批的轨迹!表示
collector 应该是。值 将产生一个 never
end collector 的 Collector 中。200
total_frames
-1
让我们遍历收集器以获得一种感觉 这些数据是什么样的:
for data in collector:
print(data)
break
TensorDict(
fields={
action: Tensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=False),
collector: TensorDict(
fields={
traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([200]),
device=None,
is_shared=False),
done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([200]),
device=None,
is_shared=False),
observation: Tensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([200]),
device=None,
is_shared=False)
如您所见,我们的数据通过一些特定于收集器的元数据进行了扩充
分组到我们在环境推出期间没有看到的 sub-tensordict 中。这对于跟踪
轨迹 ID。在下面的列表中,每个项目都标记了轨迹
number 对应的 transition 属于:"collector"
print(data["collector", "traj_ids"])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5,
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9,
9, 9, 9, 9, 9, 9, 9, 9])
数据收集器在编码最先进的技术时非常有用
sota-implementations,因为性能通常是通过
在给定数量的交互中解决问题的特定技术
环境 (收集器中的参数)。
因此,我们示例中的大多数训练循环如下所示:total_frames
>>> for data in collector:
... # your algorithm here
重放缓冲区¶
现在我们已经探索了如何收集数据,我们想知道如何 储存它。在 RL 中,典型的设置是收集、存储数据 暂时清除,过了一会儿给出了一些启发式方法: 先进先出或其他。典型的伪代码如下所示:
>>> for data in collector:
... storage.store(data)
... for i in range(n_optim):
... sample = storage.sample()
... loss_val = loss_fn(sample)
... loss_val.backward()
... optim.step() # etc
在 TorchRL 中存储数据的父类
称为 。TorchRL 的重播
缓冲区是可组合的:您可以编辑存储类型及其采样
技术、写作启发式或应用于它们的转换。我们会的
将花哨的东西留给专门的深入教程。通用重播
buffer 只需要知道它必须使用什么存储空间。一般来说,我们
推荐一个子类,它将起作用
在大多数情况下很好。我们将在本教程中使用
它,它有两个不错的属性:首先,“懒惰”,
您无需提前明确告诉它您的数据是什么样子的。
其次,它用作后端来保存
您的数据以有效的方式存储在磁盘上。您唯一需要知道的是
您希望缓冲区有多大。
TensorStorage
MemoryMappedTensor
from torchrl.data.replay_buffers import LazyMemmapStorage, ReplayBuffer
buffer = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000))
可以通过 (single element) 或 (multiple elements) 方法填充缓冲区。
用
我们刚刚收集的数据,我们一次性初始化并填充缓冲区:
indices = buffer.extend(data)
我们可以检查缓冲区现在的元素数量是否与缓冲区的元素数量相同 我们从收藏家那里得到:
assert len(buffer) == collector.frames_per_batch
唯一需要了解的是如何从缓冲区收集数据。
当然,这取决于方法。因为我们没有指定必须在没有
重复,则不能保证从我们的缓冲区收集的样本
将是唯一的:
sample = buffer.sample(batch_size=30)
print(sample)
TensorDict(
fields={
action: Tensor(shape=torch.Size([30, 2]), device=cpu, dtype=torch.int64, is_shared=False),
collector: TensorDict(
fields={
traj_ids: Tensor(shape=torch.Size([30]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([30]),
device=cpu,
is_shared=False),
done: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([30, 4]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([30]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([30, 4]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([30]),
device=cpu,
is_shared=False)
同样,我们的样本看起来与我们从 收藏家!
后续步骤¶
如果您有多个节点,TorchRL 还提供分布式收集器 用于推理。在 API 参考中查看它们。
查看专用的 Replay Buffer 教程以了解 有关构建缓冲区时的选项的更多信息,或涵盖 详。重播缓冲区具有无数功能,例如多线程 采样、优先体验重播等等......
我们省略了要迭代的重放缓冲区的容量 单纯。自己试一试:构建一个缓冲区并指明其 batch-size,然后尝试迭代它。这是 相当于在循环中调用
rb.sample()
脚本总运行时间:(0 分 18.780 秒)
估计内存使用量:28 MB