Note
Go to the end to download the full example code.
TorchRL trainer: A DQN example¶
Author: Vincent Moens
TorchRL provides a generic Trainer class to handle your training loop. The trainer executes a nested loop, where the outer loop is the data collection and the inner loop consumes this data, or some data retrieved from the replay buffer, to train the model.
At various points in this training loop, hooks can be attached and executed at given intervals.
In this tutorial, we will be using the trainer class to train a DQN algorithm to solve the CartPole task from scratch.
Key learnings:
Building a trainer with its essential components: data collector, loss module, replay buffer and optimizer;
Adding hooks to a trainer, such as loggers, target network updaters and the like.
The trainer is fully customizable and offers a large set of functionalities. The tutorial is organized around its construction.
We will first detail how to build each component with the library, and then put the pieces together using the Trainer class.
Along the road, we will also focus on some other aspects of the library:
how to build an environment in TorchRL, including transforms (e.g. data normalization, frame concatenation, resizing and turning to grayscale) and parallel execution. Unlike what we did in the DDPG tutorial, we will normalize the pixels and not the state vector;
how to design a QValueActor object, i.e. an actor that estimates the action values and picks up the action with the highest estimated return;
how to collect data from your environment efficiently and store it in a replay buffer;
how to use multi-step, a simple preprocessing step for off-policy algorithms;
and finally how to evaluate your model.
Prerequisite: We encourage you to get familiar with torchrl through the PPO tutorial first.
DQN¶
DQN (Deep Q-Learning) was a founding work in deep reinforcement learning.
On a high level, the algorithm is quite simple: Q-learning consists in learning a table of state-action values in such a way that, when encountering any particular state, we know which action to pick just by searching for the one with the highest value. This simple setting requires the actions and states to be discrete, otherwise a lookup table cannot be built.
DQN uses a neural network that maps a state-action space onto a value (scalar) space, which amortizes the cost of storing and exploring all the possible state-action combinations: if a state has not been seen in the past, we can still pass it through our neural network together with the various actions available and get an interpolated value for each of them.
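As a refresher (a sketch of the standard temporal-difference objective, which this tutorial does not spell out explicitly), the network \(Q_\theta\) is regressed towards a bootstrapped target built with the lagged "target" parameters \(\bar{\theta}\) introduced further down:

\(Q_\theta(s, a) \approx r + \gamma \max_{a'} Q_{\bar{\theta}}(s', a')\)

where \(r\) is the reward, \(\gamma\) the discount factor, and \((s, a, s')\) a state-action-next-state transition.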
We will be solving the classic control problem of the cart-pole. From the Gymnasium documentation:
A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The pendulum is placed upright on the cart and the goal is to balance the pole by applying forces in the left and right direction on the cart.
We do not aim at giving a SOTA implementation of the algorithm, but rather to provide a high-level illustration of TorchRL features in the context of this algorithm.
import os
import tempfile
import uuid
import warnings

import torch
from tensordict.nn import TensorDictSequential
from torch import multiprocessing, nn
from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector
from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    ParallelEnv,
    RewardScaling,
    StepCounter,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    CatFrames,
    Compose,
    GrayScale,
    ObservationNorm,
    Resize,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.modules import DuelingCnnDQNet, EGreedyModule, QValueActor
from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.record.loggers.csv import CSVLogger
from torchrl.trainers import (
    LogReward,
    Recorder,
    ReplayBufferTrainer,
    Trainer,
    UpdateWeights,
)
def is_notebook() -> bool:
    try:
        shell = get_ipython().__class__.__name__
        if shell == "ZMQInteractiveShell":
            return True  # Jupyter notebook or qtconsole
        elif shell == "TerminalInteractiveShell":
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False  # Probably standard Python interpreter
Let us start by preparing the components needed by the algorithm:
an environment;
a policy (and the associated modules that we group under the "model" umbrella);
a data collector, which makes the policy play in the environment and delivers training data;
a replay buffer, used to store the training data;
a loss module, which computes the objective function to train our policy to maximise the return;
an optimizer, which performs the parameter updates given our loss.
Additional modules include a logger, a recorder (which executes the policy in "eval" mode) and a target network updater. With all these components in place, it is easy to see how one could misplace or misuse one of them in the training script. The trainer is there to orchestrate everything for you!
Building the environment¶
First let's write a helper function that will output an environment. As usual, the "raw" environment may be too simple to be used in practice and we'll need some data transformations to expose its output to the policy.
We will be using the following transforms:
StepCounter to count the number of steps in each trajectory;
ToTensorImage will convert a [W, H, C] uint8 tensor into a floating-point tensor of shape [C, W, H] with values in the [0, 1] range;
RewardScaling to reduce the scale of the return;
GrayScale will turn our image into grayscale;
Resize will resize the image into a 64x64 format;
CatFrames will concatenate an arbitrary number of successive frames (N=4) into a single tensor along the channel dimension. This is useful because a single image does not carry information about the motion of the cart-pole. Some memory about past observations and actions is needed, either via a recurrent neural network or using a stack of frames;
ObservationNorm, which will normalize our observations given some custom summary statistics.
In practice, our environment builder has two arguments:
parallel: determines whether multiple environments have to be run in parallel. We stack the transforms after the ParallelEnv to take advantage of the vectorization of the operations on device, although this would technically work with every single environment attached to its own set of transforms;
obs_norm_sd, which will contain the normalizing constants for the ObservationNorm transform.
def make_env(
    parallel=False,
    obs_norm_sd=None,
    num_workers=1,
):
    if obs_norm_sd is None:
        obs_norm_sd = {"standard_normal": True}
    if parallel:

        def maker():
            return GymEnv(
                "CartPole-v1",
                from_pixels=True,
                pixels_only=True,
                device=device,
            )

        base_env = ParallelEnv(
            num_workers,
            EnvCreator(maker),
            # Don't create a sub-process if we have only one worker
            serial_for_single=True,
            mp_start_method=mp_context,
        )
    else:
        base_env = GymEnv(
            "CartPole-v1",
            from_pixels=True,
            pixels_only=True,
            device=device,
        )

    env = TransformedEnv(
        base_env,
        Compose(
            StepCounter(),  # to count the steps of each trajectory
            ToTensorImage(),
            RewardScaling(loc=0.0, scale=0.1),
            GrayScale(),
            Resize(64, 64),
            CatFrames(4, in_keys=["pixels"], dim=-3),
            ObservationNorm(in_keys=["pixels"], **obs_norm_sd),
        ),
    )
    return env
Compute normalizing constants¶
To normalize images, we don't want to normalize each pixel independently with a full [C, W, H] normalizing mask, but with simpler [C, 1, 1] shaped sets of normalizing constants (loc and scale parameters). We will be using the reduce_dim argument of init_stats() to instruct which dimensions must be reduced, and the keep_dims parameter to ensure that not all dimensions disappear in the process:
def get_norm_stats():
    test_env = make_env()
    test_env.transform[-1].init_stats(
        num_iter=1000, cat_dim=0, reduce_dim=[-1, -2, -4], keep_dims=(-1, -2)
    )
    obs_norm_sd = test_env.transform[-1].state_dict()
    # let's check that normalizing constants have a size of ``[C, 1, 1]`` where
    # ``C=4`` (because of :class:`~torchrl.envs.CatFrames`).
    print("state dict of the observation norm:", obs_norm_sd)
    test_env.close()
    del test_env
    return obs_norm_sd
Building the model (Deep Q-network)¶
The following function builds a DuelingCnnDQNet object, which is a simple CNN followed by a two-layer MLP. The only trick used here is that the action values (i.e. the left and right action values) are computed using

\(\mathbb{v} = b(obs) + v(obs) - \mathbb{E}[v(obs)]\)

where \(\mathbb{v}\) is our vector of action values, \(b\) is a \(\mathbb{R}^n \rightarrow 1\) function and \(v\) is a \(\mathbb{R}^n \rightarrow \mathbb{R}^m\) function, for \(n = \# obs\) and \(m = \# actions\).
Our network is wrapped in a QValueActor, which will read the state-action values, pick up the one with the maximum value and write all these results in the input tensordict.TensorDict.
def make_model(dummy_env):
    cnn_kwargs = {
        "num_cells": [32, 64, 64],
        "kernel_sizes": [6, 4, 3],
        "strides": [2, 2, 1],
        "activation_class": nn.ELU,
        # This can be used to reduce the size of the last layer of the CNN
        # "squeeze_output": True,
        # "aggregator_class": nn.AdaptiveAvgPool2d,
        # "aggregator_kwargs": {"output_size": (1, 1)},
    }
    mlp_kwargs = {
        "depth": 2,
        "num_cells": [
            64,
            64,
        ],
        "activation_class": nn.ELU,
    }
    net = DuelingCnnDQNet(
        dummy_env.action_spec.shape[-1], 1, cnn_kwargs, mlp_kwargs
    ).to(device)
    net.value[-1].bias.data.fill_(init_bias)

    actor = QValueActor(net, in_keys=["pixels"], spec=dummy_env.action_spec).to(device)
    # init actor: because the model is composed of lazy conv/linear layers,
    # we must pass a fake batch of data through it to instantiate them.
    tensordict = dummy_env.fake_tensordict()
    actor(tensordict)

    # we join our actor with an EGreedyModule for data collection
    exploration_module = EGreedyModule(
        spec=dummy_env.action_spec,
        annealing_num_steps=total_frames,
        eps_init=eps_greedy_val,
        eps_end=eps_greedy_val_env,
    )
    actor_explore = TensorDictSequential(actor, exploration_module)

    return actor, actor_explore
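To make the tensordict plumbing concrete, here is a small smoke test (a sketch, not part of the original tutorial; it assumes the hyperparameters defined further down and that the normalization statistics have been computed):

# Hypothetical check of the actor's input/output keys.
stats = get_norm_stats()
dummy_env = make_env(obs_norm_sd=stats)
actor, actor_explore = make_model(dummy_env)
td = dummy_env.fake_tensordict()
actor(td)
# QValueActor reads "pixels" and writes the greedy action along with the
# action values ("action", "action_value", "chosen_action_value").
print(td["action"].shape, td["chosen_action_value"].shape)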
Collecting and storing data¶
Replay buffers¶
Replay buffers play a central role in off-policy RL algorithms such as DQN. They constitute the dataset we will be sampling from during training.
Here, we will use a regular sampling strategy, although a prioritized replay buffer could improve the performance significantly.
We place the storage on disk using the LazyMemmapStorage class. This storage is created in a lazy manner: it will only be instantiated once the first batch of data is passed to it.
The only requirement of this storage is that the data passed to it at write time must always have the same shape.
def get_replay_buffer(buffer_size, n_optim, batch_size):
    replay_buffer = TensorDictReplayBuffer(
        batch_size=batch_size,
        storage=LazyMemmapStorage(buffer_size),
        prefetch=n_optim,
    )
    return replay_buffer
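To illustrate the lazy instantiation mentioned above, here is a quick sketch (not part of the original tutorial): the memory-mapped storage only materializes on the first write.

from tensordict import TensorDict

rb = get_replay_buffer(buffer_size=100, n_optim=1, batch_size=4)
# Nothing is allocated yet; the storage is created on the first extend().
fake_batch = TensorDict({"pixels": torch.zeros(8, 4, 64, 64)}, batch_size=[8])
rb.extend(fake_batch)
sample = rb.sample()  # a TensorDict with batch_size [4]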
Data collector¶
As in PPO and DDPG, we will be using a data collector as a dataloader in the outer loop.
We choose the following configuration: we will be running a series of parallel environments synchronously in parallel in different collectors, themselves running in parallel but asynchronously.
Note
This feature is only available when running the code within the "spawn" start method of the Python multiprocessing library. If this tutorial is run directly as a script (thereby using the "fork" method), we will be using a regular SyncDataCollector.
The advantage of this configuration is that we can balance the amount of compute that is executed in batch with what we want to be executed asynchronously. We encourage the reader to experiment how the collection speed is impacted by modifying the number of collectors (i.e. the number of environment constructors passed to the collector) and the number of environments executed in parallel in each collector (controlled by the num_workers hyperparameter).
Collector's devices are fully parametrizable through the device (general), policy_device, env_device and storing_device arguments. The storing_device argument will modify the location of the data being collected: if the batches that we are gathering have a considerable size, we may want to store them on a different location than the device where the computation is happening. For asynchronous data collectors such as ours, different storing devices mean that the data that we collect won't sit on the same device each time, which is something that our training loop must account for. For simplicity, we set the devices to the same value for all sub-collectors.
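For reference, a heterogeneous setup might look like the following (a sketch under the assumption that a CUDA device is available; this configuration is not used in this tutorial):

# Hypothetical: run the policy on GPU, keep envs and stored data on CPU.
het_collector = MultiaSyncDataCollector(
    [make_env] * num_collectors,
    policy=actor_explore,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    policy_device="cuda:0",
    env_device="cpu",
    storing_device="cpu",
)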
def get_collector(
    stats,
    num_collectors,
    actor_explore,
    frames_per_batch,
    total_frames,
    device,
):
    # We can't use nested child processes with mp_start_method="fork"
    if is_fork:
        cls = SyncDataCollector
        env_arg = make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
    else:
        cls = MultiaSyncDataCollector
        env_arg = [
            make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
        ] * num_collectors
    data_collector = cls(
        env_arg,
        policy=actor_explore,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        # this is the default behaviour: the collector runs in ``"random"``
        # (or explorative) mode
        exploration_type=ExplorationType.RANDOM,
        # We set all the devices to be identical. Above is an example of
        # heterogeneous devices
        device=device,
        storing_device=device,
        split_trajs=False,
        postproc=MultiStep(gamma=gamma, n_steps=5),
    )
    return data_collector
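Under the hood, the trainer will iterate over this collector in its outer loop. A manual equivalent would look roughly like this (a sketch, not part of the original tutorial):

# Hypothetical manual outer loop; the Trainer automates all of this.
def manual_loop(collector, replay_buffer):
    for batch in collector:
        replay_buffer.extend(batch.reshape(-1))  # flatten the batch dims
        # ... inner loop: sample from the buffer and run several
        # optimization steps on each collected batch
        collector.update_policy_weights_()  # sync weights to the workers
    collector.shutdown()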
Loss function¶
Building our loss function is straightforward: we only need to provide the model and a bunch of hyperparameters to the DQNLoss class.
Target parameters¶
Many off-policy RL algorithms use the concept of "target parameters" when it comes to estimating the value of the next state or state-action pair. The target parameters are lagged copies of the model parameters. Because their predictions mismatch those of the current model configuration, they help learning by putting a pessimistic bound on the value being estimated. This is a powerful trick (known as "Double DQN") that is ubiquitous in similar algorithms.
def get_loss_module(actor, gamma):
    loss_module = DQNLoss(actor, delay_value=True)
    loss_module.make_value_estimator(gamma=gamma)
    target_updater = SoftUpdate(loss_module, eps=0.995)
    return loss_module, target_updater
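For intuition (an explanatory note, not part of the original tutorial): with eps=0.995, each call to target_updater.step() nudges the target parameters \(\bar{\theta}\) towards the online parameters \(\theta\) following

\(\bar{\theta} \leftarrow 0.995 \, \bar{\theta} + 0.005 \, \theta\)

so the target network tracks the online network with a time constant of roughly \(1 / (1 - 0.995) = 200\) optimization steps.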
Hyper-parameters¶
Let's start with our hyperparameters. The following setting should work well in practice, and the performance of the algorithm should hopefully not be too sensitive to slight variations of these.
is_fork = multiprocessing.get_start_method() == "fork"
# start method for the ParallelEnv workers: use "fork" when the main
# process does, "spawn" otherwise (forked processes cannot be nested)
mp_context = "fork" if is_fork else "spawn"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
Optimizer¶
# the learning rate of the optimizer
lr = 2e-3
# weight decay
wd = 1e-5
# the beta parameters of Adam
betas = (0.9, 0.999)
# Optimization steps per batch collected (aka UPD or updates per data)
n_optim = 8
DQN parameters¶
gamma decay factor
gamma = 0.99
Smooth target network update decay parameter. This loosely corresponds to a 1/tau interval with hard target network updates: with tau = 0.02, a soft update is roughly equivalent to a hard update every 1/0.02 = 50 steps.
tau = 0.02
Data collection and replay buffer¶
Note
Values to be used for proper training have been commented out.
Total frames collected in the environment. In other implementations, the user defines a maximum number of episodes. This is harder to do with our data collectors, since they return batches of N collected frames, where N is a constant. However, one can easily get the same restriction on the number of episodes by breaking the training loop when a certain number of episodes has been collected.
total_frames = 5_000  # 500000
Random frames used to initialize the replay buffer.
init_random_frames = 100  # 1000
Frames in each batch collected.
frames_per_batch = 32  # 128
Frames sampled from the replay buffer at each optimization step.
batch_size = 32  # 256
Size of the replay buffer in terms of frames.
buffer_size = min(total_frames, 100000)
Number of environments run in parallel in each data collector.
num_workers = 2  # 8
num_collectors = 2  # 4
Environment and exploration¶
We set the initial and final value of the epsilon factor in epsilon-greedy exploration. Since our policy is deterministic, exploration is crucial: without it, the only source of randomness would be the environment reset.
eps_greedy_val = 0.1
eps_greedy_val_env = 0.005
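Side note (our reading of the module's default behavior; the tutorial does not spell it out): EGreedyModule anneals epsilon linearly, one step() call at a time, roughly as

\(\epsilon_t = \max(\epsilon_{end},\ \epsilon_{init} - t \, (\epsilon_{init} - \epsilon_{end}) / N)\)

with \(N\) = annealing_num_steps. The trainer hook registered below calls step() once per collected batch of frames.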
To speed up learning, we set the bias of the last layer of our value network to a predefined value (this is not mandatory).
init_bias = 2.0
Note
For fast rendering of the tutorial, the total_frames hyperparameter was set to a very low number. To get a reasonable performance, use a greater value, e.g. 500000.
Building a Trainer¶
TorchRL's Trainer class constructor takes the following keyword-only arguments:
collector
loss_module
optimizer
logger: a logger can be passed to record the training metrics (here a CSVLogger is used);
total_frames: this parameter defines the lifespan of the trainer;
frame_skip: when a frame-skip is used, the collector must be made aware of it in order to accurately count the number of frames collected, etc. Making the trainer aware of this parameter is not mandatory but helps to have a fairer comparison between settings where the total number of frames (budget) is fixed but the frame-skip is variable.
We create these components below; the trainer itself is instantiated once the logger is ready (see the sketch after the logging note further down).
stats = get_norm_stats()
test_env = make_env(parallel=False, obs_norm_sd=stats)
# Get model
actor, actor_explore = make_model(test_env)
loss_module, target_net_updater = get_loss_module(actor, gamma)
collector = get_collector(
    stats=stats,
    num_collectors=num_collectors,
    actor_explore=actor_explore,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    device=device,
)
optimizer = torch.optim.Adam(
    loss_module.parameters(), lr=lr, weight_decay=wd, betas=betas
)
exp_name = f"dqn_exp_{uuid.uuid1()}"
tmpdir = tempfile.TemporaryDirectory()
logger = CSVLogger(exp_name=exp_name, log_dir=tmpdir.name)
warnings.warn(f"log dir: {logger.experiment.log_dir}")
state dict of the observation norm: OrderedDict([('standard_normal', tensor(True)), ('loc', tensor([[[0.9895]],
[[0.9895]],
[[0.9895]],
[[0.9895]]])), ('scale', tensor([[[0.0737]],
[[0.0737]],
[[0.0737]],
[[0.0737]]]))])
We can control how often the scalars should be logged. Here we set this to a low value as our training loop is short:
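The trainer itself can now be built; what follows is a minimal sketch using the constructor arguments listed earlier (the log_interval value of 500 and the optim_steps_per_batch keyword are our assumptions):

log_interval = 500

trainer = Trainer(
    collector=collector,
    total_frames=total_frames,
    frame_skip=1,
    loss_module=loss_module,
    optimizer=optimizer,
    logger=logger,
    optim_steps_per_batch=n_optim,
    log_interval=log_interval,
)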
Registering hooks¶
Registering hooks can be achieved in two separate ways:
If the hook has it, the register() method is the first choice. One just needs to provide the trainer as input and the hook will be registered with a default name at a default location. For some hooks, the registration can be quite complex: ReplayBufferTrainer requires 3 hooks (extend, sample and update_priority), which can be cumbersome to implement.
buffer_hook = ReplayBufferTrainer(
    get_replay_buffer(buffer_size, n_optim, batch_size=batch_size),
    flatten_tensordicts=True,
)
buffer_hook.register(trainer)
weight_updater = UpdateWeights(collector, update_weights_interval=1)
weight_updater.register(trainer)
recorder = Recorder(
    record_interval=100,  # log every 100 optimization steps
    record_frames=1000,  # maximum number of frames in the record
    frame_skip=1,
    policy_exploration=actor_explore,
    environment=test_env,
    exploration_type=ExplorationType.MODE,
    log_keys=[("next", "reward")],
    out_keys={("next", "reward"): "rewards"},
    log_pbar=True,
)
recorder.register(trainer)
The epsilon factor of the exploration module is also annealed:
trainer.register_op("post_steps", actor_explore[1].step, frames=frames_per_batch)
Any callable (including TrainerHookBase subclasses) can be registered using register_op(). In this case, a location must be explicitly passed (here "post_optim"). This method gives more control over the location of the hook, but it also requires more understanding of the Trainer mechanism. Check the trainer documentation for a detailed description of the trainer hooks.
trainer.register_op("post_optim", target_net_updater.step)
We can log the training rewards too. Note that this is of limited interest with CartPole, as rewards are always 1. The discounted sum of rewards is maximised not by getting higher rewards but by keeping the cart-pole alive for longer. This will be reflected by the total_rewards value displayed in the progress bar.
log_reward = LogReward(log_pbar=True)
log_reward.register(trainer)
Note
It is possible to link multiple optimizers to the trainer if needed. In this case, each optimizer will be tied to a field in the loss dictionary. Check the OptimizerHook to learn more.
Here we are, ready to train our algorithm! A simple call to trainer.train() and we'll be getting our results logged in.
trainer.train()
0%| | 0/5000 [00:00<?, ?it/s]
r_training: 0.3688, rewards: 0.1000, total_rewards: 0.9434: 1%| | 32/5000 [00:07<19:22, 4.27it/s]
...
r_training: 0.4295, rewards: 0.1000, total_rewards: 0.9434: 64%|██████▍ | 3200/5000 [00:42<00:19, 90.15it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 6.6667: 65%|██████▍ | 3232/5000 [00:48<02:05, 14.08it/s]
...
r_training: 0.3991, rewards: 0.1000, total_rewards: 6.6667: 100%|█████████▉| 4992/5000 [01:08<00:00, 90.27it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 6.6667: : 5024it [01:08, 90.73it/s]
We can now quickly check the CSV files with the results.
def print_csv_files_in_folder(folder_path):
    """
    Find all CSV files in a folder and print the first 10 lines of each file.

    Args:
        folder_path (str): The relative path to the folder.
    """
    csv_files = []
    output_str = ""
    for dirpath, _, filenames in os.walk(folder_path):
        for file in filenames:
            if file.endswith(".csv"):
                csv_files.append(os.path.join(dirpath, file))
    for csv_file in csv_files:
        output_str += f"File: {csv_file}\n"
        with open(csv_file, "r") as f:
            for i, line in enumerate(f):
                if i == 10:
                    break
                output_str += line.strip() + "\n"
        output_str += "\n"
    print(output_str)
print_csv_files_in_folder(logger.experiment.log_dir)
File: /tmp/tmprz_tomjh/dqn_exp_1f3ef74e-4ec7-11ef-991e-0242ac110002/scalars/r_training.csv
512,0.38084372878074646
1024,0.37784188985824585
1536,0.41726210713386536
2048,0.36880600452423096
2560,0.39912933111190796
3072,0.39912936091423035
3584,0.42945271730422974
4096,0.42945271730422974
4608,0.39912933111190796
File: /tmp/tmprz_tomjh/dqn_exp_1f3ef74e-4ec7-11ef-991e-0242ac110002/scalars/optim_steps.csv
512,128.0
1024,256.0
1536,384.0
2048,512.0
2560,640.0
3072,768.0
3584,896.0
4096,1024.0
4608,1152.0
File: /tmp/tmprz_tomjh/dqn_exp_1f3ef74e-4ec7-11ef-991e-0242ac110002/scalars/loss.csv
512,0.13857470452785492
1024,0.15926682949066162
1536,0.1994696408510208
2048,0.21946831047534943
2560,0.25826987624168396
3072,0.30737149715423584
3584,0.24386540055274963
4096,0.34079253673553467
4608,0.2449716478586197
File: /tmp/tmprz_tomjh/dqn_exp_1f3ef74e-4ec7-11ef-991e-0242ac110002/scalars/grad_norm_0.csv
512,1.89058256149292
1024,2.0029563903808594
1536,3.236938714981079
2048,2.1101794242858887
2560,2.259946823120117
3072,2.8765692710876465
3584,3.375800609588623
4096,3.7260398864746094
4608,2.8490850925445557
File: /tmp/tmprz_tomjh/dqn_exp_1f3ef74e-4ec7-11ef-991e-0242ac110002/scalars/rewards.csv
3232,0.10000000894069672
File: /tmp/tmprz_tomjh/dqn_exp_1f3ef74e-4ec7-11ef-991e-0242ac110002/scalars/total_rewards.csv
3232,6.666667461395264
Conclusion and possible improvements¶
In this tutorial we have learned:
how to write a Trainer, including building its components and registering them in the trainer;
how to code a DQN algorithm, including how to create a policy that picks up the action with the highest value with QValueNetwork;
how to build a multiprocessed data collector.
Possible improvements to this tutorial could include:
A prioritized replay buffer could also be used. This will give a higher priority to the samples that have the worst value accuracy. Learn more on the replay buffer section of the documentation (see also the sketch after this list);
A distributional loss (see DistributionalDQNLoss for more information);
More fancy exploration techniques, such as NoisyLinear layers and such.
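As a pointer for the first item, a prioritized buffer could be swapped in roughly as follows (a sketch; the alpha/beta values are illustrative, not tuned):

from torchrl.data import TensorDictPrioritizedReplayBuffer

prioritized_rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,  # how strongly the priorities skew the sampling
    beta=0.5,  # importance-sampling correction exponent
    storage=LazyMemmapStorage(buffer_size),
    batch_size=batch_size,
)
# ReplayBufferTrainer would then use its update_priority hook to refresh
# the priorities from the TD error after each optimization step.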
Total running time of the script: (2 minutes 37.294 seconds)
Estimated memory usage: 1018 MB