
Reinforcement Learning (PPO) with TorchRL Tutorial

Created On: Mar 15, 2023 | Last Updated: May 16, 2024 | Last Verified: Nov 05, 2024

Author: Vincent Moens

This tutorial demonstrates how to use PyTorch and torchrl to train a parametric policy network to solve the Inverted Pendulum task from the OpenAI-Gym/Farama-Gymnasium control library.

Figure: the Inverted Pendulum task.

Key learnings:

  • How to create an environment in TorchRL, transform its outputs, and collect data from this environment;

  • How to use TensorDict;

  • The basics of building your training loop with TorchRL:

    • How to compute the advantage signal for policy gradient methods;

    • How to create a stochastic policy using a probabilistic neural network;

    • How to create a dynamic replay buffer and sample from it without repetition.

We will cover six crucial components of TorchRL: environments, transforms, models (policy and value function), loss modules, data collectors, and replay buffers.

If you are running this in Google Colab, make sure you install the following dependencies:

!pip3 install torchrl
!pip3 install gym[mujoco]
!pip3 install tqdm

Proximal Policy Optimization (PPO) is a policy-gradient algorithm in which a batch of data is collected and directly consumed to train the policy to maximize the expected return given some proximality constraints. You can think of it as a sophisticated version of REINFORCE, the foundational policy-optimization algorithm. For more information, see the Proximal Policy Optimization Algorithms paper.

PPO is usually regarded as a fast and efficient method for online, on-policy reinforcement learning. TorchRL provides a loss module that does all the work for you, so that you can rely on this implementation and focus on solving your problem rather than re-inventing the wheel every time you want to train a policy.

For completeness, here is a brief overview of what the loss computes, even though this is taken care of by our ClipPPOLoss module. The algorithm works as follows:

  1. We will sample a batch of data by playing the policy in the environment for a given number of steps.

  2. Then, we will perform a given number of optimization steps with random sub-samples of this batch using a clipped version of the REINFORCE loss.

  3. The clipping will put a pessimistic bound on our loss: lower return estimates will be favored compared to higher ones.

The precise formula of the loss is:

\[L(s,a,\theta_k,\theta) = \min\left( \frac{\pi_{\theta}(a|s)}{\pi_{\theta_k}(a|s)} A^{\pi_{\theta_k}}(s,a), \;\; g(\epsilon, A^{\pi_{\theta_k}}(s,a)) \right),\]

There are two components in that loss: in the first part of the minimum operator, we simply compute an importance-weighted version of the REINFORCE loss (that is, a REINFORCE loss that we have corrected for the fact that the current policy configuration lags behind the one that was used for the data collection). The second part of that minimum operator is a similar loss where we have clipped the ratios when they exceeded or were below a given pair of thresholds.
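Here, g denotes the clipped term. In the standard PPO formulation (written out here for reference, since the tutorial does not spell it out), it reads:

\[g(\epsilon, A) = \begin{cases} (1+\epsilon)\, A & \text{if } A \geq 0, \\ (1-\epsilon)\, A & \text{otherwise.} \end{cases}\]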

This loss ensures that whether the advantage is positive or negative, policy updates that would produce a significant shift from the previous configuration are discouraged.

This tutorial is structured as follows:

  1. First, we will define a set of hyperparameters we will be using for training.

  2. Next, we will focus on creating our environment, or simulator, using TorchRL's wrappers and transforms.

  3. Next, we will design the policy network and the value model, which is indispensable to the loss function. These modules will be used to configure our loss module.

  4. Next, we will create the replay buffer and data loader.

  5. Finally, we will run our training loop and analyze the results.

Throughout this tutorial, we will be using the tensordict library. TensorDict is the lingua franca of TorchRL: it helps us abstract what a module reads and writes and care less about the specific data description and more about the algorithm itself.
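As a quick, illustrative sketch (not part of the tutorial's own code), a TensorDict behaves much like a dictionary of tensors that also carries a common batch size:

import torch
from tensordict import TensorDict

# a dict-like container of tensors that share their leading (batch) dimensions
td = TensorDict({"observation": torch.randn(4, 11)}, batch_size=[4])
td["action"] = torch.randn(4, 1)   # write a new entry in-place
print(td["observation"].shape)     # torch.Size([4, 11])
print(td.batch_size)               # torch.Size([4])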

import warnings
warnings.filterwarnings("ignore")
from torch import multiprocessing


from collections import defaultdict

import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (Compose, DoubleToFloat, ObservationNorm, StepCounter,
                          TransformedEnv)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tqdm import tqdm

Define Hyperparameters

We set the hyperparameters for our algorithm. Depending on the resources available, one may choose to execute the policy on GPU or on another device. The frame_skip will control how many frames a single action is executed for, and the rest of the arguments that count frames must be corrected for this value (since one environment step will actually return frame_skip frames).

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
num_cells = 256  # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0

Data collection parameters

When collecting data, we will be able to choose how big each batch will be by defining a frames_per_batch parameter. We will also define how many frames (such as the number of interactions with the simulator) we will allow ourselves to use. In general, the goal of an RL algorithm is to learn to solve the task as fast as it can in terms of environment interactions: the lower the total_frames, the better.

frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M
total_frames = 50_000

PPO parameters

At each data collection (or batch collection), we will run the optimization over a certain number of epochs, each time consuming the entire data we just acquired in a nested training loop. Here, the sub_batch_size is different from the frames_per_batch above: recall that we are working with a "batch of data" coming from our collector, whose size is defined by frames_per_batch, and that we will further split into smaller sub-batches during the inner training loop. The size of these sub-batches is controlled by sub_batch_size.

sub_batch_size = 64  # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10  # optimization steps per batch of data collected
clip_epsilon = (
    0.2  # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4

Define an environment

In RL, an environment is usually the way we refer to a simulator or a control system. Various libraries provide simulation environments for reinforcement learning, including Gymnasium (previously OpenAI Gym), DeepMind Control Suite, and many others. As a general library, TorchRL's goal is to provide an interchangeable interface to a large panel of RL simulators, allowing you to easily swap one environment with another. For example, creating a wrapped gym environment can be achieved with a few characters:

base_env = GymEnv("InvertedDoublePendulum-v4", device=device)

There are a few things to notice in this code: first, we created the environment by calling the GymEnv wrapper. If extra keyword arguments are passed, they will be transmitted to the gym.make method, hence covering the most common environment construction commands. Alternatively, one could also directly create a gym environment using gym.make(env_name, **kwargs) and wrap it in a GymWrapper class.

Also note the device argument: for gym, this only controls the device where input actions and observed states will be stored, but the execution will always be done on CPU. The reason for this is simply that gym does not support on-device execution, unless specified otherwise. For other libraries, we have control over the execution device and, as much as we can, we try to stay consistent in terms of storing and execution backends.

Transforms

We will append some transforms to our environments to prepare the data for the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different approach, more similar to other PyTorch domain libraries, through the use of transforms. To add transforms to an environment, one should simply wrap it in a TransformedEnv instance and append the sequence of transforms to it. The transformed environment will inherit the device and meta-data of the wrapped environment, and transform these depending on the sequence of transforms it contains.

Normalization

The first transform to encode is a normalization transform. As a rule of thumb, it is preferable to have data that loosely matches a unit Gaussian distribution: to obtain this, we will run a certain number of random steps in the environment and compute the summary statistics of these observations.

We will append two other transforms: the DoubleToFloat transform will convert double-precision numbers into single-precision numbers, ready to be read by the policy. The StepCounter transform will be used to count the steps before the environment is terminated. We will use this measure as a supplementary measure of performance.

As we will see later, many of TorchRL's classes rely on TensorDict to communicate. You can think of it as a Python dictionary with some extra tensor features. In practice, this means that many modules we will be working with need to be told what key to read (in_keys) and what key to write (out_keys) in the tensordict they will receive. Usually, if out_keys is omitted, it is assumed that the in_keys entries will be updated in-place. For our transforms, the only entry we are interested in is referred to as "observation", and our transform layers will be told to modify this entry and this entry only:

env = TransformedEnv(
    base_env,
    Compose(
        # normalize observations
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        StepCounter(),
    ),
)

As you may have noticed, we have created a normalization layer but we did not set its normalization parameters. To do this, ObservationNorm can automatically gather the summary statistics of our environment:

env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)

The ObservationNorm transform has now been populated with a location and a scale that will be used to normalize the data.

Let's do a little sanity check for the shape of our summary stats:

print("normalization constant shape:", env.transform[0].loc.shape)
normalization constant shape: torch.Size([11])

An environment is not only defined by its simulator and transforms, but also by a series of metadata that describe what can be expected during its execution. For efficiency purposes, TorchRL is quite stringent when it comes to environment specs, but you can easily check that your environment specs are adequate. In our example, the GymWrapper and GymEnv that inherits from it already take care of setting the proper specs for your environment, so you should not have to care about this.

Nevertheless, let's look at a concrete example using our transformed environment. There are three specs to look at: observation_spec, which defines what is to be expected when executing an action in the environment; reward_spec, which indicates the reward domain; and finally the input_spec (which contains the action_spec), which represents everything an environment requires to execute a single step.

print("observation_spec:", env.observation_spec)
print("reward_spec:", env.reward_spec)
print("input_spec:", env.input_spec)
print("action_spec (as defined by input_spec):", env.action_spec)
observation_spec: CompositeSpec(
    observation: UnboundedContinuousTensorSpec(
        shape=torch.Size([11]),
        space=None,
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    step_count: BoundedTensorSpec(
        shape=torch.Size([1]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
            high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
        device=cpu,
        dtype=torch.int64,
        domain=continuous),
    device=cpu,
    shape=torch.Size([]))
reward_spec: UnboundedContinuousTensorSpec(
    shape=torch.Size([1]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
    device=cpu,
    dtype=torch.float32,
    domain=continuous)
input_spec: CompositeSpec(
    full_state_spec: CompositeSpec(
        step_count: BoundedTensorSpec(
            shape=torch.Size([1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
            device=cpu,
            dtype=torch.int64,
            domain=continuous),
        device=cpu,
        shape=torch.Size([])),
    full_action_spec: CompositeSpec(
        action: BoundedTensorSpec(
            shape=torch.Size([1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
            device=cpu,
            dtype=torch.float32,
            domain=continuous),
        device=cpu,
        shape=torch.Size([])),
    device=cpu,
    shape=torch.Size([]))
action_spec (as defined by input_spec): BoundedTensorSpec(
    shape=torch.Size([1]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
    device=cpu,
    dtype=torch.float32,
    domain=continuous)

The check_env_specs() function runs a small rollout and compares its output against the environment specs. If no error is raised, we can be confident that the specs are properly defined:

check_env_specs(env)

For fun, let's see what a simple random rollout looks like. You can call env.rollout(n_steps) and get an overview of what the environment inputs and outputs look like. Actions will automatically be drawn from the action spec domain, so you don't need to care about designing a random sampler.

Typically, at each step an RL environment receives an action as input, and outputs an observation, a reward, and a done state. The observation may be composite, meaning that it could be composed of more than one tensor. This is not a problem for TorchRL, since the whole set of observations is automatically packed in the output TensorDict. After executing a rollout (for example, a sequence of environment steps and random action generations) over a given number of steps, we will retrieve a TensorDict instance with a shape that matches this trajectory length:

rollout = env.rollout(3)
print("rollout of three steps:", rollout)
print("Shape of the rollout TensorDict:", rollout.batch_size)
rollout of three steps: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
Shape of the rollout TensorDict: torch.Size([3])

Our rollout data has a shape of torch.Size([3]), which matches the number of steps we ran it for. The "next" entry points to the data coming after the current step. In most cases, the "next" data at time t matches the data at t+1, but this may not be the case if we are using some specific transformations (for example, multi-step).
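As a small illustration (reusing the rollout obtained above), nested entries such as the "next" sub-tensordict can be read with tuple keys:

# the observation at step t and the "next" observation produced by that step
obs_t = rollout["observation"]             # shape [3, 11]
obs_tp1 = rollout["next", "observation"]   # shape [3, 11]
# for a plain rollout, the "next" observation of step i should equal the
# observation of step i + 1
print(torch.allclose(obs_t[1:], obs_tp1[:-1]))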

Policy

PPO utilizes a stochastic policy to handle exploration. This means that our neural network will have to output the parameters of a distribution, rather than a single value corresponding to the action taken.

As the data is continuous, we use a Tanh-Normal distribution to respect the action space boundaries. TorchRL provides such a distribution, and the only thing we need to care about is building a neural network that outputs the right number of parameters for the policy to work with (a location, or mean, and a scale):

\[f_{\theta}(\text{observation}) = \mu_{\theta}(\text{observation}), \sigma^{+}_{\theta}(\text{observation})\]

The only extra difficulty that is brought up here is to split our output in two equal parts and map the second one to a strictly positive space.

We design the policy in three steps:

  1. Define a neural network D_obs -> 2 * D_action. Indeed, our loc (mu) and scale (sigma) both have dimension D_action.

  2. Append a NormalParamExtractor to extract a location and a scale (for example, splits the input into two equal parts and applies a positive transformation to the scale parameter).

  3. Create a probabilistic TensorDictModule that can generate this distribution and sample from it.

To enable the policy to "talk" with the environment through the tensordict data carrier, we wrap the nn.Module in a TensorDictModule. This class will simply read the in_keys it is provided with and write the outputs in-place at the registered out_keys.
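The network definition that the next cell relies on (actor_net) is missing from this page; a sketch consistent with the steps above and the hyperparameters defined earlier (hidden layers of num_cells units, an output of 2 * D_action features, followed by a NormalParamExtractor) would be:

actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    # 2 * D_action outputs: one location and one scale per action dimension
    nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
    NormalParamExtractor(),
)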

policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)

We now need to build a distribution out of the location and scale of our normal distribution. To do so, we instruct the ProbabilisticActor class to build a TanhNormal out of the location and scale parameters. We also provide the minimum and maximum values of this distribution, which we gather from the environment specs.

The name of the in_keys (and hence the name of the out_keys of the TensorDictModule above) cannot be set to just any value one may like, as the TanhNormal distribution constructor expects the loc and scale keyword arguments. That being said, ProbabilisticActor also accepts Dict[str, str] typed in_keys, where the key-value pairs indicate what in_key string should be used for every keyword argument that is to be used.

policy_module = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "min": env.action_spec.space.low,
        "max": env.action_spec.space.high,
    },
    return_log_prob=True,
    # we'll need the log-prob for the numerator of the importance weights
)

Value network

The value network is a crucial component of the PPO algorithm, even though it won't be used at inference time. This module will read the observations and return an estimation of the discounted return for the following trajectory. This allows us to amortize learning by relying on some utility estimation that is learned on the fly during training. Our value network shares the same structure as the policy, but for simplicity we assign it its own set of parameters.

value_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(1, device=device),
)

value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
)

Let's try our policy and value modules. As we said earlier, the usage of TensorDictModule makes it possible to directly read the output of the environment to run these modules, as they know what information to read and where to write it:

print("Running policy:", policy_module(env.reset()))
print("Running value:", value_module(env.reset()))
Running policy: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        loc: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
Running value: TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

Data collector

TorchRL provides a set of DataCollector classes. Briefly, these classes execute three operations: reset an environment, compute an action given the latest observation, execute a step in the environment, and repeat the last two steps until the environment signals a stop (or reaches a done state).

They allow you to control how many frames to collect at each iteration (via the frames_per_batch parameter), when to reset the environment (via the max_frames_per_traj argument), on which device the policy should be executed, and so on. They are also designed to work efficiently with batched and multiprocessed environments.

The simplest data collector is the SyncDataCollector: it is an iterator that you can use to get batches of data of a given length, and that will stop once a total number of frames (total_frames) has been collected. Other data collectors (MultiSyncDataCollector and MultiaSyncDataCollector) will execute the same operations in synchronous and asynchronous manner over a set of multiprocessed workers.

As with the policy and environment before, the data collector will return TensorDict instances with a total number of elements that matches frames_per_batch. Using TensorDict to pass data to the training loop allows you to write data-loading pipelines that are 100% oblivious to the actual specificities of the rollout content.

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)

Replay buffer

Replay buffers are a common building piece of off-policy RL algorithms. In on-policy contexts, a replay buffer is refilled every time a batch of data is collected, and its data is repeatedly consumed for a certain number of epochs.

TorchRL's replay buffers are built using a common container, ReplayBuffer, which takes as arguments the components of the buffer: a storage, a writer, a sampler, and possibly some transforms. Only the storage (which indicates the replay buffer capacity) is mandatory. We also specify a sampler without repetition to avoid sampling the same item multiple times in one epoch. Using a replay buffer for PPO is not mandatory and we could simply sample the sub-batches from the collected batch, but using these classes makes it easy for us to build the inner training loop in a reproducible way.

replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)

Loss function

The PPO loss can be directly imported from TorchRL for convenience using the ClipPPOLoss class. This is the easiest way of utilizing PPO: it hides away the mathematical operations of PPO and the control flow that goes with it.

PPO requires some "advantage estimation" to be computed. In short, an advantage is a value that reflects an expectancy over the return value while dealing with the bias/variance tradeoff. To compute the advantage, one just needs to (1) build the advantage module, which utilizes our value operator, and (2) pass each batch of data through it before each epoch. The GAE module will update the input tensordict with new "advantage" and "value_target" entries. The "value_target" is a gradient-free tensor that represents the empirical value that the value network should output for the input observation. Both of these will be used by ClipPPOLoss to return the policy and value losses.

advantage_module = GAE(
    gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True
)

loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    # these keys match by default but we set this for completeness
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)

Training loop

We now have all the pieces needed to code our training loop. The steps include:

  • Collect data

    • Compute the advantage

      • Loop over the collected data to compute loss values

      • Back propagate

      • Optimize

      • Repeat

    • Repeat

  • Repeat

logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""

# We iterate over the collector until it reaches the total number of frames it was
# designed to collect:
for i, tensordict_data in enumerate(collector):
    # we now have a batch of data to work with. Let's learn something from it.
    for _ in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network which is updated in the inner loop.
        advantage_module(tensordict_data)
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())
        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(device))
            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            # Optimization: backward, grad clipping and optimization step
            loss_value.backward()
            # this is not strictly mandatory but it's good practice to keep
            # your gradient norm bounded
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    pbar.update(tensordict_data.numel())
    cum_reward_str = (
        f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
    )
    logs["step_count"].append(tensordict_data["step_count"].max().item())
    stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    if i % 10 == 0:
        # We evaluate the policy once every 10 batches of data.
        # Evaluation is rather simple: execute the policy without exploration
        # (take the expected value of the action distribution) for a given
        # number of steps (1000, which is our ``env`` horizon).
        # The ``rollout`` method of the ``env`` can take a policy as argument:
        # it will then execute this policy at each step.
        with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
            # execute a rollout with the trained policy
            eval_rollout = env.rollout(1000, policy_module)
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(
                eval_rollout["next", "reward"].sum().item()
            )
            logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (
                f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
                f"eval step-count: {logs['eval step_count'][-1]}"
            )
            del eval_rollout
    pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))

    # We're also using a learning rate scheduler. Like the gradient clipping,
    # this is a nice-to-have but nothing necessary for PPO to work.
    scheduler.step()
  0%|          | 0/50000 [00:00<?, ?it/s]
  2%|2         | 1000/50000 [00:03<03:10, 256.79it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.0998 (init= 9.0998), step count (max): 16, lr policy:  0.0003:   2%|2         | 1000/50000 [00:03<03:10, 256.79it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.0998 (init= 9.0998), step count (max): 16, lr policy:  0.0003:   4%|4         | 2000/50000 [00:07<03:04, 259.51it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.1175 (init= 9.0998), step count (max): 14, lr policy:  0.0003:   4%|4         | 2000/50000 [00:07<03:04, 259.51it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.1175 (init= 9.0998), step count (max): 14, lr policy:  0.0003:   6%|6         | 3000/50000 [00:11<02:59, 261.82it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.1509 (init= 9.0998), step count (max): 14, lr policy:  0.0003:   6%|6         | 3000/50000 [00:11<02:59, 261.82it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.1509 (init= 9.0998), step count (max): 14, lr policy:  0.0003:   8%|8         | 4000/50000 [00:15<02:54, 263.16it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.1931 (init= 9.0998), step count (max): 22, lr policy:  0.0003:   8%|8         | 4000/50000 [00:15<02:54, 263.16it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.1931 (init= 9.0998), step count (max): 22, lr policy:  0.0003:  10%|#         | 5000/50000 [00:19<02:50, 264.70it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2155 (init= 9.0998), step count (max): 27, lr policy:  0.0003:  10%|#         | 5000/50000 [00:19<02:50, 264.70it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2155 (init= 9.0998), step count (max): 27, lr policy:  0.0003:  12%|#2        | 6000/50000 [00:22<02:45, 265.97it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2189 (init= 9.0998), step count (max): 25, lr policy:  0.0003:  12%|#2        | 6000/50000 [00:22<02:45, 265.97it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2189 (init= 9.0998), step count (max): 25, lr policy:  0.0003:  14%|#4        | 7000/50000 [00:26<02:41, 266.81it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2371 (init= 9.0998), step count (max): 47, lr policy:  0.0003:  14%|#4        | 7000/50000 [00:26<02:41, 266.81it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2371 (init= 9.0998), step count (max): 47, lr policy:  0.0003:  16%|#6        | 8000/50000 [00:30<02:37, 267.35it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2277 (init= 9.0998), step count (max): 36, lr policy:  0.0003:  16%|#6        | 8000/50000 [00:30<02:37, 267.35it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2277 (init= 9.0998), step count (max): 36, lr policy:  0.0003:  18%|#8        | 9000/50000 [00:33<02:32, 268.43it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2517 (init= 9.0998), step count (max): 41, lr policy:  0.0003:  18%|#8        | 9000/50000 [00:33<02:32, 268.43it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2517 (init= 9.0998), step count (max): 41, lr policy:  0.0003:  20%|##        | 10000/50000 [00:37<02:32, 262.98it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2600 (init= 9.0998), step count (max): 41, lr policy:  0.0003:  20%|##        | 10000/50000 [00:37<02:32, 262.98it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2600 (init= 9.0998), step count (max): 41, lr policy:  0.0003:  22%|##2       | 11000/50000 [00:41<02:26, 265.59it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2575 (init= 9.0998), step count (max): 38, lr policy:  0.0003:  22%|##2       | 11000/50000 [00:41<02:26, 265.59it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2575 (init= 9.0998), step count (max): 38, lr policy:  0.0003:  24%|##4       | 12000/50000 [00:45<02:23, 265.18it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2730 (init= 9.0998), step count (max): 56, lr policy:  0.0003:  24%|##4       | 12000/50000 [00:45<02:23, 265.18it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2730 (init= 9.0998), step count (max): 56, lr policy:  0.0003:  26%|##6       | 13000/50000 [00:48<02:18, 267.12it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2719 (init= 9.0998), step count (max): 55, lr policy:  0.0003:  26%|##6       | 13000/50000 [00:48<02:18, 267.12it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2719 (init= 9.0998), step count (max): 55, lr policy:  0.0003:  28%|##8       | 14000/50000 [00:52<02:14, 268.62it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2725 (init= 9.0998), step count (max): 102, lr policy:  0.0003:  28%|##8       | 14000/50000 [00:52<02:14, 268.62it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2725 (init= 9.0998), step count (max): 102, lr policy:  0.0003:  30%|###       | 15000/50000 [00:56<02:09, 269.74it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2774 (init= 9.0998), step count (max): 95, lr policy:  0.0002:  30%|###       | 15000/50000 [00:56<02:09, 269.74it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2774 (init= 9.0998), step count (max): 95, lr policy:  0.0002:  32%|###2      | 16000/50000 [01:00<02:05, 270.42it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2724 (init= 9.0998), step count (max): 59, lr policy:  0.0002:  32%|###2      | 16000/50000 [01:00<02:05, 270.42it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2724 (init= 9.0998), step count (max): 59, lr policy:  0.0002:  34%|###4      | 17000/50000 [01:03<02:01, 271.23it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2809 (init= 9.0998), step count (max): 89, lr policy:  0.0002:  34%|###4      | 17000/50000 [01:03<02:01, 271.23it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2809 (init= 9.0998), step count (max): 89, lr policy:  0.0002:  36%|###6      | 18000/50000 [01:07<01:57, 271.33it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2828 (init= 9.0998), step count (max): 83, lr policy:  0.0002:  36%|###6      | 18000/50000 [01:07<01:57, 271.33it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2828 (init= 9.0998), step count (max): 83, lr policy:  0.0002:  38%|###8      | 19000/50000 [01:11<01:54, 271.37it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2828 (init= 9.0998), step count (max): 69, lr policy:  0.0002:  38%|###8      | 19000/50000 [01:11<01:54, 271.37it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2828 (init= 9.0998), step count (max): 69, lr policy:  0.0002:  40%|####      | 20000/50000 [01:14<01:50, 270.62it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2765 (init= 9.0998), step count (max): 66, lr policy:  0.0002:  40%|####      | 20000/50000 [01:14<01:50, 270.62it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2765 (init= 9.0998), step count (max): 66, lr policy:  0.0002:  42%|####2     | 21000/50000 [01:18<01:46, 271.33it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2970 (init= 9.0998), step count (max): 121, lr policy:  0.0002:  42%|####2     | 21000/50000 [01:18<01:46, 271.33it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2970 (init= 9.0998), step count (max): 121, lr policy:  0.0002:  44%|####4     | 22000/50000 [01:22<01:44, 267.08it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.3032 (init= 9.0998), step count (max): 125, lr policy:  0.0002:  44%|####4     | 22000/50000 [01:22<01:44, 267.08it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.3032 (init= 9.0998), step count (max): 125, lr policy:  0.0002:  46%|####6     | 23000/50000 [01:26<01:42, 262.83it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2970 (init= 9.0998), step count (max): 78, lr policy:  0.0002:  46%|####6     | 23000/50000 [01:26<01:42, 262.83it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2970 (init= 9.0998), step count (max): 78, lr policy:  0.0002:  48%|####8     | 24000/50000 [01:29<01:37, 265.77it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2985 (init= 9.0998), step count (max): 113, lr policy:  0.0002:  48%|####8     | 24000/50000 [01:29<01:37, 265.77it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2985 (init= 9.0998), step count (max): 113, lr policy:  0.0002:  50%|#####     | 25000/50000 [01:33<01:33, 267.84it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.3044 (init= 9.0998), step count (max): 102, lr policy:  0.0002:  50%|#####     | 25000/50000 [01:33<01:33, 267.84it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.3044 (init= 9.0998), step count (max): 102, lr policy:  0.0002:  52%|#####2    | 26000/50000 [01:37<01:29, 269.15it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2937 (init= 9.0998), step count (max): 87, lr policy:  0.0001:  52%|#####2    | 26000/50000 [01:37<01:29, 269.15it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2937 (init= 9.0998), step count (max): 87, lr policy:  0.0001:  54%|#####4    | 27000/50000 [01:41<01:25, 268.28it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2961 (init= 9.0998), step count (max): 70, lr policy:  0.0001:  54%|#####4    | 27000/50000 [01:41<01:25, 268.28it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2961 (init= 9.0998), step count (max): 70, lr policy:  0.0001:  56%|#####6    | 28000/50000 [01:44<01:21, 268.42it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2842 (init= 9.0998), step count (max): 60, lr policy:  0.0001:  56%|#####6    | 28000/50000 [01:44<01:21, 268.42it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2842 (init= 9.0998), step count (max): 60, lr policy:  0.0001:  58%|#####8    | 29000/50000 [01:48<01:17, 269.30it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2952 (init= 9.0998), step count (max): 67, lr policy:  0.0001:  58%|#####8    | 29000/50000 [01:48<01:17, 269.30it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2952 (init= 9.0998), step count (max): 67, lr policy:  0.0001:  60%|######    | 30000/50000 [01:52<01:14, 270.23it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2988 (init= 9.0998), step count (max): 75, lr policy:  0.0001:  60%|######    | 30000/50000 [01:52<01:14, 270.23it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2988 (init= 9.0998), step count (max): 75, lr policy:  0.0001:  62%|######2   | 31000/50000 [01:55<01:10, 270.84it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.2974 (init= 9.0998), step count (max): 77, lr policy:  0.0001:  62%|######2   | 31000/50000 [01:55<01:10, 270.84it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.2974 (init= 9.0998), step count (max): 77, lr policy:  0.0001:  64%|######4   | 32000/50000 [01:59<01:07, 267.85it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3021 (init= 9.0998), step count (max): 100, lr policy:  0.0001:  64%|######4   | 32000/50000 [01:59<01:07, 267.85it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3021 (init= 9.0998), step count (max): 100, lr policy:  0.0001:  66%|######6   | 33000/50000 [02:03<01:03, 268.88it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3097 (init= 9.0998), step count (max): 175, lr policy:  0.0001:  66%|######6   | 33000/50000 [02:03<01:03, 268.88it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3097 (init= 9.0998), step count (max): 175, lr policy:  0.0001:  68%|######8   | 34000/50000 [02:06<00:59, 270.19it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3168 (init= 9.0998), step count (max): 140, lr policy:  0.0001:  68%|######8   | 34000/50000 [02:06<00:59, 270.19it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3168 (init= 9.0998), step count (max): 140, lr policy:  0.0001:  70%|#######   | 35000/50000 [02:10<00:56, 264.93it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3159 (init= 9.0998), step count (max): 117, lr policy:  0.0001:  70%|#######   | 35000/50000 [02:10<00:56, 264.93it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3159 (init= 9.0998), step count (max): 117, lr policy:  0.0001:  72%|#######2  | 36000/50000 [02:14<00:52, 267.19it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3155 (init= 9.0998), step count (max): 132, lr policy:  0.0001:  72%|#######2  | 36000/50000 [02:14<00:52, 267.19it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3155 (init= 9.0998), step count (max): 132, lr policy:  0.0001:  74%|#######4  | 37000/50000 [02:18<00:48, 268.67it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3185 (init= 9.0998), step count (max): 118, lr policy:  0.0001:  74%|#######4  | 37000/50000 [02:18<00:48, 268.67it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3185 (init= 9.0998), step count (max): 118, lr policy:  0.0001:  76%|#######6  | 38000/50000 [02:21<00:44, 270.04it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3231 (init= 9.0998), step count (max): 147, lr policy:  0.0000:  76%|#######6  | 38000/50000 [02:21<00:44, 270.04it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3231 (init= 9.0998), step count (max): 147, lr policy:  0.0000:  78%|#######8  | 39000/50000 [02:25<00:40, 270.74it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3233 (init= 9.0998), step count (max): 173, lr policy:  0.0000:  78%|#######8  | 39000/50000 [02:25<00:40, 270.74it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3233 (init= 9.0998), step count (max): 173, lr policy:  0.0000:  80%|########  | 40000/50000 [02:29<00:36, 271.60it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3168 (init= 9.0998), step count (max): 135, lr policy:  0.0000:  80%|########  | 40000/50000 [02:29<00:36, 271.60it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3168 (init= 9.0998), step count (max): 135, lr policy:  0.0000:  82%|########2 | 41000/50000 [02:32<00:33, 272.21it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3168 (init= 9.0998), step count (max): 135, lr policy:  0.0000:  82%|########2 | 41000/50000 [02:32<00:33, 272.21it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3168 (init= 9.0998), step count (max): 135, lr policy:  0.0000:  84%|########4 | 42000/50000 [02:36<00:29, 270.49it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3261 (init= 9.0998), step count (max): 166, lr policy:  0.0000:  84%|########4 | 42000/50000 [02:36<00:29, 270.49it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3261 (init= 9.0998), step count (max): 166, lr policy:  0.0000:  86%|########6 | 43000/50000 [02:40<00:25, 271.36it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3296 (init= 9.0998), step count (max): 193, lr policy:  0.0000:  86%|########6 | 43000/50000 [02:40<00:25, 271.36it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3296 (init= 9.0998), step count (max): 193, lr policy:  0.0000:  88%|########8 | 44000/50000 [02:43<00:22, 271.95it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3385 (init= 9.0998), step count (max): 182, lr policy:  0.0000:  88%|########8 | 44000/50000 [02:43<00:22, 271.95it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3385 (init= 9.0998), step count (max): 182, lr policy:  0.0000:  90%|######### | 45000/50000 [02:47<00:18, 272.30it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3294 (init= 9.0998), step count (max): 189, lr policy:  0.0000:  90%|######### | 45000/50000 [02:47<00:18, 272.30it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3294 (init= 9.0998), step count (max): 189, lr policy:  0.0000:  92%|#########2| 46000/50000 [02:51<00:15, 266.35it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3320 (init= 9.0998), step count (max): 197, lr policy:  0.0000:  92%|#########2| 46000/50000 [02:51<00:15, 266.35it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3320 (init= 9.0998), step count (max): 197, lr policy:  0.0000:  94%|#########3| 47000/50000 [02:55<00:11, 268.64it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3278 (init= 9.0998), step count (max): 160, lr policy:  0.0000:  94%|#########3| 47000/50000 [02:55<00:11, 268.64it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3278 (init= 9.0998), step count (max): 160, lr policy:  0.0000:  96%|#########6| 48000/50000 [02:58<00:07, 270.15it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3257 (init= 9.0998), step count (max): 162, lr policy:  0.0000:  96%|#########6| 48000/50000 [02:58<00:07, 270.15it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3257 (init= 9.0998), step count (max): 162, lr policy:  0.0000:  98%|#########8| 49000/50000 [03:02<00:03, 271.35it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3230 (init= 9.0998), step count (max): 118, lr policy:  0.0000:  98%|#########8| 49000/50000 [03:02<00:03, 271.35it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3230 (init= 9.0998), step count (max): 118, lr policy:  0.0000: 100%|##########| 50000/50000 [03:06<00:00, 272.31it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3355 (init= 9.0998), step count (max): 348, lr policy:  0.0000: 100%|##########| 50000/50000 [03:06<00:00, 272.31it/s]

Results

Before the 1M step cap is reached, the algorithm should have reached a max step count of 1000 steps, which is the maximum number of steps before the trajectory is truncated.

plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.plot(logs["reward"])
plt.title("training rewards (average)")
plt.subplot(2, 2, 2)
plt.plot(logs["step_count"])
plt.title("Max step count (training)")
plt.subplot(2, 2, 3)
plt.plot(logs["eval reward (sum)"])
plt.title("Return (test)")
plt.subplot(2, 2, 4)
plt.plot(logs["eval step_count"])
plt.title("Max step count (test)")
plt.show()
Figure: training rewards (average), max step count (training), return (test), max step count (test).

Conclusion and next steps

In this tutorial, we have learned:

  1. How to create and customize an environment with torchrl;

  2. How to write a model and a loss function;

  3. How to set up a typical training loop.

If you want to experiment with this tutorial a bit more, you can apply the following modifications:

  • From an efficiency perspective, we could run several simulations in parallel to speed up data collection. Check ParallelEnv for further information (a small sketch follows this list).

  • From a logging perspective, one could add a torchrl.record.VideoRecorder transform to the environment after asking for rendering, to get a visual rendering of the inverted pendulum in action. Check torchrl.record to know more.
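As mentioned in the first point above, a minimal sketch of running several environments in parallel (an illustration, not part of the tutorial's code) could look like this:

from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

def make_env():
    # each worker process builds its own copy of the base environment
    return GymEnv("InvertedDoublePendulum-v4")

parallel_env = ParallelEnv(4, make_env)  # 4 workers, one environment each
rollout = parallel_env.rollout(3)
# the leading batch dimensions should be (num_workers, num_steps), i.e. [4, 3]
print(rollout.batch_size)
# transforms can be appended on top with TransformedEnv, and the result can
# replace ``env`` in the SyncDataCollector above
parallel_env.close()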

Total running time of the script: (3 minutes 7.796 seconds)
