TorchRL trainer: A DQN example

Author: Vincent Moens

TorchRL provides a generic Trainer class to handle your training loop. The trainer executes a nested loop, where the outer loop is the data collection and the inner loop consumes this data, or some data retrieved from the replay buffer, to train the model. At various points of this training loop, hooks can be attached and executed at given intervals.

In this tutorial, we will be using the trainer class to train a DQN algorithm to solve a CartPole task from scratch.

Main takeaways:

  • Building a trainer with its essential components: data collector, loss module, replay buffer and optimizer.

  • Adding hooks to a trainer, such as loggers, target network updaters and such.

The trainer is fully customizable and offers a large set of functionalities. This tutorial is organized around its construction: we will first detail how to build each component of the library, and then put the pieces together using the Trainer class.

Along the road, we will also focus on some other aspects of the library:

  • how to build an environment in TorchRL, including transforms (e.g. data normalization, frame concatenation, resizing and turning to grayscale) and parallel execution. Unlike what we did in the DDPG tutorial, we will normalize the pixels and not the state vector;

  • how to design a QValueActor object, i.e. an actor that estimates the action values and picks the action with the highest estimated return;

  • how to collect data from your environment efficiently and store it in a replay buffer;

  • how to use multi-step, a simple preprocessing step for off-policy algorithms;

  • and finally how to evaluate your model.

Prerequisites: we encourage you to get familiar with torchrl through the PPO tutorial first.

DQN

DQN (Deep Q-Learning) was the founding work in deep reinforcement learning.

On a high level, the algorithm is quite simple: Q-learning consists in learning a table of state-action values such that, when encountering any particular state, we know which action to pick just by searching for the one with the highest value. This simple setting requires the actions and states to be discrete, otherwise a lookup table cannot be built.
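The tabular setting just described can be sketched in a few lines of plain Python. The toy two-state MDP below (its states, actions and reward) is made up purely for illustration and is not part of the tutorial:

```python
# Hypothetical toy MDP: 2 discrete states, 2 discrete actions.
n_states, n_actions = 2, 2
q_table = [[0.0] * n_actions for _ in range(n_states)]
alpha, gamma = 0.5, 0.9  # learning rate and discount factor

def q_update(s, a, r, s_next):
    # Classic Q-learning update: move Q(s, a) toward the bootstrapped target.
    target = r + gamma * max(q_table[s_next])
    q_table[s][a] += alpha * (target - q_table[s][a])

def greedy_action(s):
    # At decision time, pick the action with the highest estimated value.
    return max(range(n_actions), key=lambda a: q_table[s][a])

# One observed transition: in state 0, action 1 yields reward 1, leads to state 1.
q_update(s=0, a=1, r=1.0, s_next=1)
print(greedy_action(0))  # action 1 now has the highest value in state 0
```

Once states or actions stop being discrete, this lookup table cannot be enumerated anymore, which is exactly the gap DQN fills with a neural network.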

DQN uses a neural network that encodes a map from the state-action space to a value (scalar) space, which amortizes the cost of storing and exploring all the possible state-action combinations: if a state has not been seen in the past, we can still pass it, together with the various actions available, through our neural network and get a value estimate for each of them.

We will solve the classic control problem of the cart pole. From the Gymnasium documentation from where this environment is retrieved:

A pole is attached by an un-actuated joint to a cart, which moves along a
frictionless track. The pendulum is placed upright on the cart and the goal
is to balance the pole by applying forces in the left and right direction
on the cart.

We do not aim at giving a SOTA implementation of the algorithm, but rather to provide a high-level illustration of TorchRL's features in the context of this algorithm.

import multiprocessing
import os
import tempfile
import uuid
import warnings

import torch
from tensordict.nn import TensorDictSequential
from torch import nn
from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector
from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    ParallelEnv,
    RewardScaling,
    StepCounter,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    CatFrames,
    Compose,
    GrayScale,
    ObservationNorm,
    Resize,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.modules import DuelingCnnDQNet, EGreedyModule, QValueActor

from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.record.loggers.csv import CSVLogger
from torchrl.trainers import (
    LogReward,
    Recorder,
    ReplayBufferTrainer,
    Trainer,
    UpdateWeights,
)


def is_notebook() -> bool:
    try:
        shell = get_ipython().__class__.__name__
        if shell == "ZMQInteractiveShell":
            return True  # Jupyter notebook or qtconsole
        elif shell == "TerminalInteractiveShell":
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False  # Probably standard Python interpreter

Let us first start with the various pieces of our algorithm:

  • An environment;

  • A policy (and the associated modules that we group under the "model" umbrella);

  • A data collector, which makes the policy play in the environment and delivers training data;

  • A replay buffer to store the training data;

  • A loss module, which computes the objective function to train our policy to maximise the return;

  • An optimizer, which performs the parameter updates given our loss.

Additional modules include a logger, a recorder (which executes the policy in "eval" mode) and a target network updater. With all these components in place, it is easy to see how one could misplace or misuse one of them in the training script. The trainer is there to orchestrate everything for you!

Building the environment

First let's write a helper function that will output an environment. As usual, the "raw" environment may be too simple to be used in practice and we'll need some data transformation to expose its output to the policy.

We will be using the following transforms:

  • StepCounter to count the number of steps in each trajectory;

  • ToTensorImage will convert a [W, H, C] uint8 tensor into a floating point tensor of shape [C, W, H] in the [0, 1] space;

  • RewardScaling to reduce the scale of the return;

  • GrayScale will turn our image into grayscale;

  • Resize will resize the image to a 64x64 format;

  • CatFrames will concatenate an arbitrary number of successive frames (N=4) in a single tensor along the channel dimension. This is useful as a single image does not carry information about the motion of the cartpole. Some memory about past observations and actions is needed, either via a recurrent neural network or by using a stack of frames;

  • ObservationNorm will normalize our observations given some custom summary statistics.

In practice, our environment builder has two arguments:

  • parallel: determines whether multiple environments have to be run in parallel. We stack the transforms after the ParallelEnv to take advantage of vectorization of the operations on device, although this would technically work with every single environment attached to its own set of transforms.

  • obs_norm_sd will contain the normalizing constants for the ObservationNorm transform.

def make_env(
    parallel=False,
    obs_norm_sd=None,
    num_workers=1,
):
    if obs_norm_sd is None:
        obs_norm_sd = {"standard_normal": True}
    if parallel:

        def maker():
            return GymEnv(
                "CartPole-v1",
                from_pixels=True,
                pixels_only=True,
                device=device,
            )

        base_env = ParallelEnv(
            num_workers,
            EnvCreator(maker),
            # Don't create a sub-process if we have only one worker
            serial_for_single=True,
            mp_start_method=mp_context,
        )
    else:
        base_env = GymEnv(
            "CartPole-v1",
            from_pixels=True,
            pixels_only=True,
            device=device,
        )

    env = TransformedEnv(
        base_env,
        Compose(
            StepCounter(),  # to count the steps of each trajectory
            ToTensorImage(),
            RewardScaling(loc=0.0, scale=0.1),
            GrayScale(),
            Resize(64, 64),
            CatFrames(4, in_keys=["pixels"], dim=-3),
            ObservationNorm(in_keys=["pixels"], **obs_norm_sd),
        ),
    )
    return env

Compute normalizing constants

To normalize images, we don't want to normalize each pixel independently with a full [C, W, H] normalizing mask, but with a simpler set of normalizing constants (loc and scale parameters) of shape [C, 1, 1]. We will be using the reduce_dim argument of init_stats() to instruct which dimensions must be reduced, and the keep_dims parameter to ensure that not all dimensions disappear in the process:

def get_norm_stats():
    test_env = make_env()
    test_env.transform[-1].init_stats(
        num_iter=1000, cat_dim=0, reduce_dim=[-1, -2, -4], keep_dims=(-1, -2)
    )
    obs_norm_sd = test_env.transform[-1].state_dict()
    # let's check that normalizing constants have a size of ``[C, 1, 1]`` where
    # ``C=4`` (because of :class:`~torchrl.envs.CatFrames`).
    print("state dict of the observation norm:", obs_norm_sd)
    test_env.close()
    del test_env
    return obs_norm_sd

Building the model (Deep Q-network)

The following function builds a DuelingCnnDQNet object, which is a simple CNN followed by a two-layer MLP. The only trick used here is that the action values (i.e. the left and right action values) are computed using

\[\mathbb{v} = b(obs) + v(obs) - \mathbb{E}[v(obs)]\]

where \(\mathbb{v}\) is our vector of action values, \(b\) is a \(\mathbb{R}^n \rightarrow 1\) function and \(v\) is a \(\mathbb{R}^n \rightarrow \mathbb{R}^m\) function, for \(n = \# obs\) and \(m = \# actions\).

Our network is wrapped in a QValueActor, which will read the state-action values, pick the one with the maximum value and write all those results in the input tensordict.TensorDict.

def make_model(dummy_env):
    cnn_kwargs = {
        "num_cells": [32, 64, 64],
        "kernel_sizes": [6, 4, 3],
        "strides": [2, 2, 1],
        "activation_class": nn.ELU,
        # This can be used to reduce the size of the last layer of the CNN
        # "squeeze_output": True,
        # "aggregator_class": nn.AdaptiveAvgPool2d,
        # "aggregator_kwargs": {"output_size": (1, 1)},
    }
    mlp_kwargs = {
        "depth": 2,
        "num_cells": [
            64,
            64,
        ],
        "activation_class": nn.ELU,
    }
    net = DuelingCnnDQNet(
        dummy_env.action_spec.shape[-1], 1, cnn_kwargs, mlp_kwargs
    ).to(device)
    net.value[-1].bias.data.fill_(init_bias)

    actor = QValueActor(net, in_keys=["pixels"], spec=dummy_env.action_spec).to(device)
    # init actor: because the model is composed of lazy conv/linear layers,
    # we must pass a fake batch of data through it to instantiate them.
    tensordict = dummy_env.fake_tensordict()
    actor(tensordict)

    # we join our actor with an EGreedyModule for data collection
    exploration_module = EGreedyModule(
        spec=dummy_env.action_spec,
        annealing_num_steps=total_frames,
        eps_init=eps_greedy_val,
        eps_end=eps_greedy_val_env,
    )
    actor_explore = TensorDictSequential(actor, exploration_module)

    return actor, actor_explore

Collecting and storing data

Replay buffers

Replay buffers play a central role in off-policy RL algorithms such as DQN. They constitute the dataset we will be sampling from during training.

Here, we will use a regular sampling strategy, although a prioritized replay buffer could improve the performance significantly.

We place the storage on disk using the LazyMemmapStorage class. This storage is created in a lazy manner: it will only be instantiated once the first batch of data is passed to it.

The only requirement of this storage is that the data passed to it at write time must always have the same shape.

def get_replay_buffer(buffer_size, n_optim, batch_size):
    replay_buffer = TensorDictReplayBuffer(
        batch_size=batch_size,
        storage=LazyMemmapStorage(buffer_size),
        prefetch=n_optim,
    )
    return replay_buffer

Data collector

As in PPO and DDPG, we will be using a data collector as a dataloader in the outer loop.

We choose the following configuration: we will be running a series of parallel environments synchronously in parallel in different collectors, themselves running in parallel but asynchronously.

Note

This feature is only available when running the code within the "spawn" start method of the Python multiprocessing library. If this tutorial is run directly as a script (thereby using the "fork" method), we will be using a regular SyncDataCollector.

The advantage of this configuration is that we can balance the amount of compute that is executed in batch with what we want to be executed asynchronously. We encourage the reader to experiment with how the collection speed is impacted by modifying the number of collectors (i.e. the number of environment constructors passed to the collector) and the number of environments executed in parallel in each collector (controlled by the num_workers hyperparameter).

The collector's devices are fully parametrizable through the device (general), policy_device, env_device and storing_device arguments. The storing_device argument will modify the location of the data being collected: if the batches that we are gathering have a considerable size, we may want to store them on a different location than the device where the computation is happening. For asynchronous data collectors such as ours, different storing devices mean that the data we collect won't sit on the same device each time, which is something our training loop must account for. For simplicity, we set the devices to the same value for all sub-collectors.
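As a sketch of what a heterogeneous setup could look like (the device strings below are assumptions about the available hardware, not part of this tutorial), the keyword arguments discussed above could be collected in a dictionary and unpacked into the collector constructor in place of the homogeneous settings we use here:

```python
# Illustrative heterogeneous device configuration for a data collector.
# The argument names are the ones described above; the device strings
# assume a machine with a single GPU.
collector_device_kwargs = {
    "policy_device": "cuda:0",   # run the policy forward pass on GPU
    "env_device": "cpu",         # step the environments on CPU
    "storing_device": "cpu",     # keep the (large) collected batches on CPU
}
# e.g. MultiaSyncDataCollector(..., **collector_device_kwargs)
```

With such a configuration, the training loop would need to move sampled batches to the compute device before each optimization step.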

def get_collector(
    stats,
    num_collectors,
    actor_explore,
    frames_per_batch,
    total_frames,
    device,
):
    # We can't use nested child processes with mp_start_method="fork"
    if is_fork:
        cls = SyncDataCollector
        env_arg = make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
    else:
        cls = MultiaSyncDataCollector
        env_arg = [
            make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
        ] * num_collectors
    data_collector = cls(
        env_arg,
        policy=actor_explore,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        # this is the default behavior: the collector runs in ``"random"`` (or explorative) mode
        exploration_type=ExplorationType.RANDOM,
        # We set the all the devices to be identical. Below is an example of
        # heterogeneous devices
        device=device,
        storing_device=device,
        split_trajs=False,
        postproc=MultiStep(gamma=gamma, n_steps=5),
    )
    return data_collector

Loss function

Building our loss function is straightforward: we only need to provide the model and a bunch of hyperparameters to the DQNLoss class.

Target parameters

Many off-policy RL algorithms use the concept of "target parameters" when it comes to estimating the value of the next state or state-action pair. The target parameters are lagged copies of the model parameters. Because their predictions mismatch those of the current model configuration, they help learning by putting a pessimistic bound on the value being estimated. This is a powerful trick (known as "double Q-Learning") that is ubiquitous in similar algorithms.
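The "lagged copy" behaviour of a soft (Polyak) target update can be illustrated numerically. The sketch below uses plain tensors standing in for parameters (the eps=0.995 value matches the one passed to SoftUpdate in this tutorial); it is an illustration of the update rule, not of the torchrl API:

```python
import torch

# Soft target update rule: target <- eps * target + (1 - eps) * current.
eps = 0.995
current = torch.ones(3)   # stands in for the online (trained) parameters
target = torch.zeros(3)   # stands in for the lagged target parameters

for _ in range(10):
    target = eps * target + (1 - eps) * current

# After 10 updates the target has only moved a small fraction of the way
# toward the online parameters: it lags behind, as intended.
print(target[0].item())
```

Analytically, after n steps the target sits at 1 - eps**n of the way, so a small 1 - eps corresponds to a slowly moving, pessimistic value estimator.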

def get_loss_module(actor, gamma):
    loss_module = DQNLoss(actor, delay_value=True)
    loss_module.make_value_estimator(gamma=gamma)
    target_updater = SoftUpdate(loss_module, eps=0.995)
    return loss_module, target_updater

Hyperparameters

Let's start with our hyperparameters. The following setting should work well in practice, and the performance of the algorithm should hopefully not be too sensitive to slight variations of these.

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)

Optimizer

# the learning rate of the optimizer
lr = 2e-3
# weight decay
wd = 1e-5
# the beta parameters of Adam
betas = (0.9, 0.999)
# Optimization steps per batch collected (aka UPD or updates per data)
n_optim = 8

DQN parameters

gamma decay factor

gamma = 0.99

Smooth target network update decay parameter. This loosely corresponds to a 1/tau interval with hard target network update.

tau = 0.02

Data collection and replay buffer

Note

Values to be used for proper training have been commented.

Total frames collected in the environment. In other implementations, the user defines a maximum number of episodes. This is harder to do with our data collectors, since they return batches of N collected frames, where N is a constant. However, one can easily get the same restriction on the number of episodes by breaking the training loop when a certain number of episodes has been collected.
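Concretely, such an episode budget could look like the hedged sketch below, where fake_batches stands in for the collector's iterator (names are illustrative) and episode endings are counted from the boolean "done" flags carried by each batch:

```python
import torch

# Stand-in for a data collector yielding fixed-size batches of frames;
# each batch carries a boolean "done" mask marking episode terminations.
fake_batches = [
    {"done": torch.tensor([False, True, False, False])},
    {"done": torch.tensor([False, False, True, True])},
    {"done": torch.tensor([True, False, False, False])},
]

max_episodes = 3
episodes_collected = 0
batches_used = 0
for batch in fake_batches:
    batches_used += 1
    episodes_collected += batch["done"].sum().item()
    if episodes_collected >= max_episodes:
        break  # same effect as an episode limit, with a frame-based collector
```

Here the loop stops after the second batch, once three episode terminations have been seen, even though the collector counts frames rather than episodes.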

total_frames = 5_000  # 500000

Random frames used to initialize the replay buffer.

init_random_frames = 100  # 1000

Frames in each batch collected.

frames_per_batch = 32  # 128

Frames sampled from the replay buffer at each optimization step.

batch_size = 32  # 256

Size of the replay buffer in terms of frames.

buffer_size = min(total_frames, 100000)

Number of environments run in parallel in each data collector.

num_workers = 2  # 8
num_collectors = 2  # 4

Environment and exploration

We set the initial and final value of the epsilon factor in epsilon-greedy exploration. Since our policy is deterministic, exploration is crucial: without it, the only source of randomness would be the environment reset.

eps_greedy_val = 0.1
eps_greedy_val_env = 0.005

To speed up learning, we set the bias of the last layer of our value network to a predefined value (this is not mandatory).

init_bias = 2.0

Note

For fast rendering of the tutorial, the total_frames hyperparameter was set to a very low number. To get a reasonable performance, use a greater value, e.g. 500000.

Building a Trainer

TorchRL's Trainer class constructor takes the following keyword-only arguments:

  • collector

  • loss_module

  • optimizer

  • logger: a logger that records the training metrics;

  • total_frames: this parameter defines the lifespan of the trainer.

  • frame_skip: when a frame-skip is used, the collector must be made aware of it in order to accurately count the number of frames collected, etc. Making the trainer aware of this parameter is not mandatory but helps to have a fairer comparison between settings where the total number of frames (budget) is fixed but the frame-skip is variable.

stats = get_norm_stats()
test_env = make_env(parallel=False, obs_norm_sd=stats)
# Get model
actor, actor_explore = make_model(test_env)
loss_module, target_net_updater = get_loss_module(actor, gamma)

collector = get_collector(
    stats=stats,
    num_collectors=num_collectors,
    actor_explore=actor_explore,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    device=device,
)
optimizer = torch.optim.Adam(
    loss_module.parameters(), lr=lr, weight_decay=wd, betas=betas
)
exp_name = f"dqn_exp_{uuid.uuid1()}"
tmpdir = tempfile.TemporaryDirectory()
logger = CSVLogger(exp_name=exp_name, log_dir=tmpdir.name)
warnings.warn(f"log dir: {logger.experiment.log_dir}")
state dict of the observation norm: OrderedDict([('standard_normal', tensor(True)), ('loc', tensor([[[0.9895]],

        [[0.9895]],

        [[0.9895]],

        [[0.9895]]])), ('scale', tensor([[[0.0737]],

        [[0.0737]],

        [[0.0737]],

        [[0.0737]]]))])

We can control how often scalars should be logged. Here we set this to a low value as our training loop is short:

log_interval = 500

trainer = Trainer(
    collector=collector,
    total_frames=total_frames,
    frame_skip=1,
    loss_module=loss_module,
    optimizer=optimizer,
    logger=logger,
    optim_steps_per_batch=n_optim,
    log_interval=log_interval,
)

Registering hooks

Registering hooks can be achieved in two separate ways:

  • If the hook has it, the register() method is the first choice. One just needs to provide the trainer as input and the hook will be registered with a default name at a default location. For some hooks, the registration can be quite complex: ReplayBufferTrainer requires 3 hooks (extend, sample and update_priority), which can be cumbersome to implement.

buffer_hook = ReplayBufferTrainer(
    get_replay_buffer(buffer_size, n_optim, batch_size=batch_size),
    flatten_tensordicts=True,
)
buffer_hook.register(trainer)
weight_updater = UpdateWeights(collector, update_weights_interval=1)
weight_updater.register(trainer)
recorder = Recorder(
    record_interval=100,  # log every 100 optimization steps
    record_frames=1000,  # maximum number of frames in the record
    frame_skip=1,
    policy_exploration=actor_explore,
    environment=test_env,
    exploration_type=ExplorationType.DETERMINISTIC,
    log_keys=[("next", "reward")],
    out_keys={("next", "reward"): "rewards"},
    log_pbar=True,
)
recorder.register(trainer)

The exploration module epsilon factor is also annealed:

trainer.register_op("post_steps", actor_explore[1].step, frames=frames_per_batch)
  • Any callable (including TrainerHookBase subclasses) can be registered using register_op(). In this case, a location ("post_optim" in the example below) must be explicitly passed. This method gives more control over the location of the hook but it also requires more understanding of the trainer mechanism. Check the trainer documentation for a detailed description of the trainer hooks.

trainer.register_op("post_optim", target_net_updater.step)

We can log the training rewards too. Note that this is of limited interest with CartPole, as rewards are always 1. The discounted sum of rewards is maximised not by getting higher rewards but by keeping the cart-pole alive for longer. This will be reflected by the total_rewards value displayed in the progress bar.

log_reward = LogReward(log_pbar=True)
log_reward.register(trainer)

Note

It is possible to link multiple optimizers to the trainer if needed. In this case, each optimizer will be tied to a field in the loss dictionary. Check the OptimizerHook to learn more.

Here we are, ready to train our algorithm! A simple call to trainer.train() and we'll be getting our results logged in.

trainer.train()
  0%|          | 0/5000 [00:00<?, ?it/s]
r_training: 0.3688, rewards: 0.1000, total_rewards: 0.9434:   1%|          | 32/5000 [00:07<20:39,  4.01it/s]
r_training: 0.3566, rewards: 0.1000, total_rewards: 0.9434:   2%|▏         | 96/5000 [00:08<05:18, 15.40it/s]
...
r_training: 0.3991, rewards: 0.1000, total_rewards: 0.9434:  59%|█████▉    | 2944/5000 [00:40<00:22, 90.57it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 0.9434:  59%|█████▉    | 2944/5000 [00:40<00:22, 90.57it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 0.9434:  60%|█████▉    | 2976/5000 [00:40<00:22, 89.29it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 0.9434:  60%|█████▉    | 2976/5000 [00:40<00:22, 89.29it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 0.9434:  60%|██████    | 3008/5000 [00:40<00:21, 91.42it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 0.9434:  60%|██████    | 3008/5000 [00:40<00:21, 91.42it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 0.9434:  61%|██████    | 3040/5000 [00:41<00:21, 90.81it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 0.9434:  61%|██████    | 3040/5000 [00:41<00:21, 90.81it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 0.9434:  61%|██████▏   | 3072/5000 [00:41<00:21, 91.65it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 0.9434:  61%|██████▏   | 3072/5000 [00:41<00:21, 91.65it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 0.9434:  62%|██████▏   | 3104/5000 [00:41<00:20, 93.16it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 0.9434:  62%|██████▏   | 3104/5000 [00:41<00:20, 93.16it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 0.9434:  63%|██████▎   | 3136/5000 [00:42<00:19, 94.37it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 0.9434:  63%|██████▎   | 3136/5000 [00:42<00:19, 94.37it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 0.9434:  63%|██████▎   | 3168/5000 [00:42<00:19, 93.29it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 0.9434:  63%|██████▎   | 3168/5000 [00:42<00:19, 93.29it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 0.9434:  64%|██████▍   | 3200/5000 [00:42<00:18, 94.90it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 0.9434:  64%|██████▍   | 3200/5000 [00:42<00:18, 94.90it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 0.9434:  65%|██████▍   | 3232/5000 [00:50<02:12, 13.38it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556:  65%|██████▍   | 3232/5000 [00:50<02:12, 13.38it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556:  65%|██████▌   | 3264/5000 [00:50<01:37, 17.83it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  65%|██████▌   | 3264/5000 [00:50<01:37, 17.83it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  66%|██████▌   | 3296/5000 [00:50<01:13, 23.32it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  66%|██████▌   | 3296/5000 [00:50<01:13, 23.32it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  67%|██████▋   | 3328/5000 [00:51<00:55, 29.87it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  67%|██████▋   | 3328/5000 [00:51<00:55, 29.87it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  67%|██████▋   | 3360/5000 [00:51<00:44, 36.95it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  67%|██████▋   | 3360/5000 [00:51<00:44, 36.95it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  68%|██████▊   | 3392/5000 [00:52<00:36, 44.30it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  68%|██████▊   | 3392/5000 [00:52<00:36, 44.30it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  68%|██████▊   | 3424/5000 [00:52<00:30, 51.78it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  68%|██████▊   | 3424/5000 [00:52<00:30, 51.78it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  69%|██████▉   | 3456/5000 [00:52<00:26, 58.47it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  69%|██████▉   | 3456/5000 [00:52<00:26, 58.47it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  70%|██████▉   | 3488/5000 [00:53<00:23, 64.95it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556:  70%|██████▉   | 3488/5000 [00:53<00:23, 64.95it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556:  70%|███████   | 3520/5000 [00:53<00:21, 69.15it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  70%|███████   | 3520/5000 [00:53<00:21, 69.15it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  71%|███████   | 3552/5000 [00:54<00:19, 72.87it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  71%|███████   | 3552/5000 [00:54<00:19, 72.87it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  72%|███████▏  | 3584/5000 [00:54<00:18, 75.22it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  72%|███████▏  | 3584/5000 [00:54<00:18, 75.22it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  72%|███████▏  | 3616/5000 [00:54<00:17, 78.72it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  72%|███████▏  | 3616/5000 [00:54<00:17, 78.72it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  73%|███████▎  | 3648/5000 [00:55<00:16, 80.90it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556:  73%|███████▎  | 3648/5000 [00:55<00:16, 80.90it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556:  74%|███████▎  | 3680/5000 [00:55<00:15, 82.97it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  74%|███████▎  | 3680/5000 [00:55<00:15, 82.97it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  74%|███████▍  | 3712/5000 [00:55<00:15, 84.10it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  74%|███████▍  | 3712/5000 [00:55<00:15, 84.10it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  75%|███████▍  | 3744/5000 [00:56<00:14, 83.85it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  75%|███████▍  | 3744/5000 [00:56<00:14, 83.85it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  76%|███████▌  | 3776/5000 [00:56<00:14, 84.22it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  76%|███████▌  | 3776/5000 [00:56<00:14, 84.22it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  76%|███████▌  | 3808/5000 [00:56<00:14, 84.97it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  76%|███████▌  | 3808/5000 [00:56<00:14, 84.97it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  77%|███████▋  | 3840/5000 [00:57<00:13, 85.76it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  77%|███████▋  | 3840/5000 [00:57<00:13, 85.76it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  77%|███████▋  | 3872/5000 [00:57<00:12, 88.48it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  77%|███████▋  | 3872/5000 [00:57<00:12, 88.48it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  78%|███████▊  | 3904/5000 [00:58<00:12, 89.26it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  78%|███████▊  | 3904/5000 [00:58<00:12, 89.26it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  79%|███████▊  | 3936/5000 [00:58<00:11, 90.04it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  79%|███████▊  | 3936/5000 [00:58<00:11, 90.04it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  79%|███████▉  | 3968/5000 [00:58<00:11, 89.77it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556:  79%|███████▉  | 3968/5000 [00:58<00:11, 89.77it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556:  80%|████████  | 4000/5000 [00:59<00:11, 89.92it/s]
r_training: 0.3869, rewards: 0.1000, total_rewards: 5.5556:  80%|████████  | 4000/5000 [00:59<00:11, 89.92it/s]
r_training: 0.3869, rewards: 0.1000, total_rewards: 5.5556:  81%|████████  | 4032/5000 [00:59<00:10, 88.32it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  81%|████████  | 4032/5000 [00:59<00:10, 88.32it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  81%|████████▏ | 4064/5000 [00:59<00:10, 86.08it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  81%|████████▏ | 4064/5000 [00:59<00:10, 86.08it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  82%|████████▏ | 4096/5000 [01:00<00:10, 84.34it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  82%|████████▏ | 4096/5000 [01:00<00:10, 84.34it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  83%|████████▎ | 4128/5000 [01:00<00:10, 85.10it/s]
r_training: 0.3688, rewards: 0.1000, total_rewards: 5.5556:  83%|████████▎ | 4128/5000 [01:00<00:10, 85.10it/s]
r_training: 0.3688, rewards: 0.1000, total_rewards: 5.5556:  83%|████████▎ | 4160/5000 [01:00<00:09, 87.36it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  83%|████████▎ | 4160/5000 [01:00<00:09, 87.36it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  84%|████████▍ | 4192/5000 [01:01<00:08, 90.26it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  84%|████████▍ | 4192/5000 [01:01<00:08, 90.26it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  84%|████████▍ | 4224/5000 [01:01<00:08, 92.19it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  84%|████████▍ | 4224/5000 [01:01<00:08, 92.19it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  85%|████████▌ | 4256/5000 [01:01<00:08, 92.86it/s]
r_training: 0.3688, rewards: 0.1000, total_rewards: 5.5556:  85%|████████▌ | 4256/5000 [01:01<00:08, 92.86it/s]
r_training: 0.3688, rewards: 0.1000, total_rewards: 5.5556:  86%|████████▌ | 4288/5000 [01:02<00:07, 90.60it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  86%|████████▌ | 4288/5000 [01:02<00:07, 90.60it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556:  86%|████████▋ | 4320/5000 [01:02<00:07, 92.06it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  86%|████████▋ | 4320/5000 [01:02<00:07, 92.06it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  87%|████████▋ | 4352/5000 [01:03<00:06, 93.60it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  87%|████████▋ | 4352/5000 [01:03<00:06, 93.60it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  88%|████████▊ | 4384/5000 [01:03<00:06, 92.15it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  88%|████████▊ | 4384/5000 [01:03<00:06, 92.15it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  88%|████████▊ | 4416/5000 [01:03<00:06, 92.21it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  88%|████████▊ | 4416/5000 [01:03<00:06, 92.21it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  89%|████████▉ | 4448/5000 [01:04<00:05, 92.90it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  89%|████████▉ | 4448/5000 [01:04<00:05, 92.90it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  90%|████████▉ | 4480/5000 [01:04<00:05, 89.09it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  90%|████████▉ | 4480/5000 [01:04<00:05, 89.09it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  90%|█████████ | 4512/5000 [01:04<00:05, 87.57it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556:  90%|█████████ | 4512/5000 [01:04<00:05, 87.57it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556:  91%|█████████ | 4544/5000 [01:05<00:05, 87.44it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556:  91%|█████████ | 4544/5000 [01:05<00:05, 87.44it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556:  92%|█████████▏| 4576/5000 [01:05<00:04, 86.47it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  92%|█████████▏| 4576/5000 [01:05<00:04, 86.47it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  92%|█████████▏| 4608/5000 [01:05<00:04, 85.84it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  92%|█████████▏| 4608/5000 [01:05<00:04, 85.84it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  93%|█████████▎| 4640/5000 [01:06<00:04, 85.98it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  93%|█████████▎| 4640/5000 [01:06<00:04, 85.98it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  93%|█████████▎| 4672/5000 [01:06<00:03, 87.97it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556:  93%|█████████▎| 4672/5000 [01:06<00:03, 87.97it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556:  94%|█████████▍| 4704/5000 [01:07<00:03, 87.37it/s]
r_training: 0.3869, rewards: 0.1000, total_rewards: 5.5556:  94%|█████████▍| 4704/5000 [01:07<00:03, 87.37it/s]
r_training: 0.3869, rewards: 0.1000, total_rewards: 5.5556:  95%|█████████▍| 4736/5000 [01:07<00:02, 90.46it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  95%|█████████▍| 4736/5000 [01:07<00:02, 90.46it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  95%|█████████▌| 4768/5000 [01:07<00:02, 93.67it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  95%|█████████▌| 4768/5000 [01:07<00:02, 93.67it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  96%|█████████▌| 4800/5000 [01:08<00:02, 93.89it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556:  96%|█████████▌| 4800/5000 [01:08<00:02, 93.89it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556:  97%|█████████▋| 4832/5000 [01:08<00:01, 91.12it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  97%|█████████▋| 4832/5000 [01:08<00:01, 91.12it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  97%|█████████▋| 4864/5000 [01:08<00:01, 89.91it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  97%|█████████▋| 4864/5000 [01:08<00:01, 89.91it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  98%|█████████▊| 4896/5000 [01:09<00:01, 88.75it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  98%|█████████▊| 4896/5000 [01:09<00:01, 88.75it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  99%|█████████▊| 4928/5000 [01:09<00:00, 90.36it/s]
r_training: 0.4173, rewards: 0.1000, total_rewards: 5.5556:  99%|█████████▊| 4928/5000 [01:09<00:00, 90.36it/s]
r_training: 0.4173, rewards: 0.1000, total_rewards: 5.5556:  99%|█████████▉| 4960/5000 [01:09<00:00, 93.68it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556:  99%|█████████▉| 4960/5000 [01:09<00:00, 93.68it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 100%|█████████▉| 4992/5000 [01:10<00:00, 91.37it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 100%|█████████▉| 4992/5000 [01:10<00:00, 91.37it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: : 5024it [01:10, 90.51it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556: : 5024it [01:10, 90.51it/s]

We can now quickly inspect the CSV files containing the results.

def print_csv_files_in_folder(folder_path):
    """
    Finds all CSV files in a folder and prints the first 10 lines of each file.

    Args:
        folder_path (str): The relative path to the folder.

    """
    csv_files = []
    output_str = ""
    for dirpath, _, filenames in os.walk(folder_path):
        for file in filenames:
            if file.endswith(".csv"):
                csv_files.append(os.path.join(dirpath, file))
    for csv_file in csv_files:
        output_str += f"File: {csv_file}\n"
        with open(csv_file, "r") as f:
            for i, line in enumerate(f):
                if i == 10:
                    break
                output_str += line.strip() + "\n"
        output_str += "\n"
    print(output_str)


print_csv_files_in_folder(logger.experiment.log_dir)
File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/r_training.csv
512,0.3566153347492218
1024,0.39912936091423035
1536,0.39912936091423035
2048,0.39912936091423035
2560,0.42945271730422974
3072,0.40213119983673096
3584,0.39912933111190796
4096,0.42945271730422974
4608,0.42945271730422974

File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/optim_steps.csv
512,128.0
1024,256.0
1536,384.0
2048,512.0
2560,640.0
3072,768.0
3584,896.0
4096,1024.0
4608,1152.0

File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/loss.csv
512,0.47876793146133423
1024,0.18667784333229065
1536,0.1948033571243286
2048,0.22345909476280212
2560,0.2145865112543106
3072,0.47586697340011597
3584,0.28343674540519714
4096,0.3203103542327881
4608,0.3053428530693054

File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/grad_norm_0.csv
512,5.5816755294799805
1024,2.9089717864990234
1536,3.4687838554382324
2048,2.8756051063537598
2560,2.7815587520599365
3072,6.685841083526611
3584,3.793360948562622
4096,3.469670295715332
4608,3.317387104034424

File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/rewards.csv
3232,0.10000000894069672

File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/total_rewards.csv
3232,5.555555820465088
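The scalar files above all follow the same plain `step,value` layout, so they can be loaded back (e.g., for plotting) with the standard `csv` module. A minimal sketch; the `load_scalar_csv` helper is hypothetical and not part of TorchRL:

```python
import csv
import io


def load_scalar_csv(text):
    """Parse a two-column ``step,value`` scalar log into (int, float) tuples."""
    reader = csv.reader(io.StringIO(text))
    return [(int(step), float(value)) for step, value in reader]


# Example with the layout shown above (values are illustrative):
sample = "512,0.3566\n1024,0.3991\n"
print(load_scalar_csv(sample))  # [(512, 0.3566), (1024, 0.3991)]
```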

Conclusion and possible improvements

In this tutorial, we have learned:

  • how to write a Trainer, including building its components and registering them in the trainer;

  • how to code a DQN algorithm, including how to create a policy that picks the action with the highest value with QValueNetwork;

  • how to build a multiprocessed data collector;

Possible improvements to this tutorial could include:

  • A prioritized replay buffer could also be used. This will give a higher priority to samples that have the worst value accuracy. Learn more in the replay buffer section of the documentation.

  • A distributional loss (see the documentation for more information).

  • More fancy exploration techniques, such as noisy layers and the like.
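To make the first suggestion concrete: proportional prioritized replay samples transitions with probability proportional to |TD-error|^alpha and corrects the induced bias with importance-sampling weights. This is a minimal pure-Python sketch of the sampling rule, not TorchRL's replay-buffer API; the function name and the defaults (alpha=0.7, beta=0.5) are illustrative:

```python
import random


def prioritized_sample(td_errors, batch_size, alpha=0.7, beta=0.5, eps=1e-6):
    """Sample indices proportionally to |TD-error|**alpha and return the
    matching importance-sampling weights, normalized so the max weight is 1."""
    priorities = [(abs(d) + eps) ** alpha for d in td_errors]
    total = sum(priorities)
    probs = [p / total for p in priorities]
    n = len(td_errors)
    # Transitions with large TD-error are drawn more often...
    indices = random.choices(range(n), weights=probs, k=batch_size)
    # ...and down-weighted in the loss to keep the update unbiased.
    weights = [(n * probs[i]) ** (-beta) for i in indices]
    max_w = max((n * p) ** (-beta) for p in probs)
    return indices, [w / max_w for w in weights]


random.seed(0)
idx, w = prioritized_sample([0.1, 2.0, 0.05, 1.5], batch_size=2)
```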

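Regarding the distributional loss: distributional DQN (e.g., C51) learns a categorical distribution over returns on a fixed support of atoms instead of a scalar Q-value; the greedy action is still chosen by the expectation of that distribution. A minimal sketch of that collapse step, with a hypothetical `expected_q` helper and an assumed support of [-10, 10]:

```python
def expected_q(probs, v_min=-10.0, v_max=10.0):
    """Collapse a categorical value distribution (C51-style) to its
    expectation over len(probs) evenly spaced atoms in [v_min, v_max]."""
    n = len(probs)
    atoms = [v_min + i * (v_max - v_min) / (n - 1) for i in range(n)]
    return sum(p * z for p, z in zip(probs, atoms))


# A distribution concentrated on the top atom has expectation v_max:
print(expected_q([0.0, 0.0, 0.0, 0.0, 1.0]))  # 10.0
```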
Total running time of the script: (2 minutes 40.957 seconds)

Estimated memory usage: 1267 MB

Gallery generated by Sphinx-Gallery
