Note
Go to the end to download the full example code.
TorchRL trainer: A DQN example¶
Author: Vincent Moens
TorchRL provides a generic Trainer class to handle your training loop. The trainer executes a nested loop where the outer loop is the data collection and the inner loop consumes this data, or some data retrieved from the replay buffer, to train the model.
At various points in this training loop, hooks can be attached and executed at given intervals.
In this tutorial, we will be using the trainer class to train a DQN algorithm to solve the CartPole task from scratch.
Main takeaways:
Building a trainer with its essential components: data collector, loss module, replay buffer and optimizer;
Adding hooks to a trainer, such as loggers, target network updaters and such.
The trainer is fully customisable and offers a large set of functionalities. This tutorial is organised around its construction.
We will first detail how to build each component of the library, and then put the pieces together using the Trainer class.
Along the road, we will also focus on some other aspects of the library:
how to build an environment in TorchRL, including transforms (e.g. data normalization, frame concatenation, resizing and turning to grayscale) and parallel execution. Unlike what we did in the DDPG tutorial, we will normalize the pixels and not the state vector;
how to design a QValueActor object, i.e. an actor that estimates the action values and picks the action with the highest estimated return;
how to collect data from your environment efficiently and store it in a replay buffer;
how to use multi-step, a simple preprocessing step for off-policy algorithms;
and finally how to evaluate your model.
Prerequisites: we encourage you to get familiar with torchrl through the PPO tutorial first.
DQN¶
DQN (Deep Q-Learning) was the founding work of deep reinforcement learning.
At a high level, the algorithm is quite simple: Q-learning consists in learning a table of state-action values such that, when encountering any particular state, we know which action to pick just by searching for the one with the highest value. This simple setting requires the actions and states to be discrete, otherwise a lookup table cannot be built.
DQN uses a neural network that encodes a map from the state-action space to a value (scalar) space, which amortizes the cost of storing and exploring all the possible state-action combinations: if a state has not been seen in the past, we can still pass it, together with the various available actions, through our neural network and get an interpolated value for each of them.
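To make the tabular setting concrete, here is a minimal, hedged sketch of the classic one-step Q-learning update (standard textbook form, not part of this tutorial's code; the sizes and learning rate below are purely illustrative):

import torch

# A tabular Q-learning sketch: one learned value per (state, action) pair,
# updated towards the one-step bootstrap target.
n_states, n_actions = 16, 2   # illustrative sizes for a small discrete problem
q_table = torch.zeros(n_states, n_actions)


def q_update(s, a, r, s_next, done, lr=0.1, gamma=0.99):
    """Move Q(s, a) towards r + gamma * max_a' Q(s', a')."""
    bootstrap = 0.0 if done else gamma * q_table[s_next].max().item()
    q_table[s, a] += lr * (r + bootstrap - q_table[s, a])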
We will solve the classic control problem of the cart pole. From the Gymnasium documentation, from which this environment is retrieved:
![Cart Pole](https://pytorch.org/rl/0.6/_images/cartpole_demo.gif)
We do not aim at giving a SOTA implementation of the algorithm, but rather a high-level illustration of TorchRL's features in the context of this algorithm.
import multiprocessing
import os
import tempfile
import uuid
import warnings

import torch
from tensordict.nn import TensorDictSequential
from torch import nn
from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector
from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer
from torchrl.envs import (
EnvCreator,
ExplorationType,
ParallelEnv,
RewardScaling,
StepCounter,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
CatFrames,
Compose,
GrayScale,
ObservationNorm,
Resize,
ToTensorImage,
TransformedEnv,
)
from torchrl.modules import DuelingCnnDQNet, EGreedyModule, QValueActor
from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.record.loggers.csv import CSVLogger
from torchrl.trainers import (
LogReward,
Recorder,
ReplayBufferTrainer,
Trainer,
UpdateWeights,
)
def is_notebook() -> bool:
try:
shell = get_ipython().__class__.__name__
if shell == "ZMQInteractiveShell":
return True # Jupyter notebook or qtconsole
elif shell == "TerminalInteractiveShell":
return False # Terminal running IPython
else:
return False # Other type (?)
except NameError:
return False # Probably standard Python interpreter
Let us start with the various pieces we need for our algorithm:
an environment;
a policy (and related modules that we group under the "model" umbrella);
a data collector, which makes the policy play in the environment and delivers training data;
a replay buffer to store the training data;
a loss module, which computes the objective function to train our policy to maximise the return;
an optimizer, which performs parameter updates based on our loss.
Additional modules include a logger, a recorder (which executes the policy in "eval" mode) and a target network updater. With all these components in place, it is easy to see how one could misplace or misuse one of them in a training script. The trainer is there to orchestrate everything for you!
Building the environment¶
First let's write a helper function that will output an environment. As usual, the "raw" environment may be too simple to be used in practice and we will need some data transformations to expose its output to the policy.
We will be using the following transforms:
StepCounter to count the number of steps in each trajectory;
ToTensorImage, which converts a [W, H, C] uint8 tensor into a floating-point tensor in the [0, 1] space with shape [C, W, H];
RewardScaling to reduce the scale of the return;
GrayScale and Resize to turn the image into a grayscale 64x64 frame;
CatFrames, which concatenates an arbitrary number (N=4) of successive frames into a single tensor along the channel dimension. This is useful because a single image does not carry information about the motion of the cartpole: some memory about past observations and actions is needed, either through a recurrent neural network or by using a stack of frames;
ObservationNorm, which normalizes our observations given some custom summary statistics.
In practice, our environment builder has two arguments:
parallel: determines whether multiple environments have to be run in parallel. We stack the transforms after the ParallelEnv to take advantage of the vectorization of the operations on device, although this would technically work with each environment attached to its own set of transforms;
obs_norm_sd: contains the normalizing constants for the ObservationNorm transform.
def make_env(
parallel=False,
obs_norm_sd=None,
num_workers=1,
):
if obs_norm_sd is None:
obs_norm_sd = {"standard_normal": True}
if parallel:
def maker():
return GymEnv(
"CartPole-v1",
from_pixels=True,
pixels_only=True,
device=device,
)
base_env = ParallelEnv(
num_workers,
EnvCreator(maker),
# Don't create a sub-process if we have only one worker
serial_for_single=True,
mp_start_method=mp_context,
)
else:
base_env = GymEnv(
"CartPole-v1",
from_pixels=True,
pixels_only=True,
device=device,
)
env = TransformedEnv(
base_env,
Compose(
StepCounter(), # to count the steps of each trajectory
ToTensorImage(),
RewardScaling(loc=0.0, scale=0.1),
GrayScale(),
Resize(64, 64),
CatFrames(4, in_keys=["pixels"], dim=-3),
ObservationNorm(in_keys=["pixels"], **obs_norm_sd),
),
)
return env
Compute normalizing constants¶
To normalize images, we do not want to normalize each pixel independently with a full [C, W, H] normalizing mask, but with a simpler set of normalizing constants (loc and scale parameters) of shape [C, 1, 1]. We will use the reduce_dim argument of init_stats() to indicate which dimensions must be reduced, and the keep_dims parameter to ensure that not all dimensions disappear in the process:
def get_norm_stats():
test_env = make_env()
test_env.transform[-1].init_stats(
num_iter=1000, cat_dim=0, reduce_dim=[-1, -2, -4], keep_dims=(-1, -2)
)
obs_norm_sd = test_env.transform[-1].state_dict()
# let's check that normalizing constants have a size of ``[C, 1, 1]`` where
# ``C=4`` (because of :class:`~torchrl.envs.CatFrames`).
print("state dict of the observation norm:", obs_norm_sd)
test_env.close()
del test_env
return obs_norm_sd
构建模型 (Deep Q 网络)¶
以下函数构建一个对象,该对象是一个简单的 CNN,后跟一个两层 MLP。唯一使用的技巧
这里是 action 值(即 left 和 right action value)是
计算方式
其中 是我们的动作值向量,是一个函数,是一个函数,对于 和 。
我们的网络被包装在一个 ,
它将读取 state-action
值,选择具有最大值的那个并写入所有这些结果
在输入 .QValueActor
tensordict.TensorDict
def make_model(dummy_env):
cnn_kwargs = {
"num_cells": [32, 64, 64],
"kernel_sizes": [6, 4, 3],
"strides": [2, 2, 1],
"activation_class": nn.ELU,
# This can be used to reduce the size of the last layer of the CNN
# "squeeze_output": True,
# "aggregator_class": nn.AdaptiveAvgPool2d,
# "aggregator_kwargs": {"output_size": (1, 1)},
}
mlp_kwargs = {
"depth": 2,
"num_cells": [
64,
64,
],
"activation_class": nn.ELU,
}
net = DuelingCnnDQNet(
dummy_env.action_spec.shape[-1], 1, cnn_kwargs, mlp_kwargs
).to(device)
net.value[-1].bias.data.fill_(init_bias)
actor = QValueActor(net, in_keys=["pixels"], spec=dummy_env.action_spec).to(device)
# init actor: because the model is composed of lazy conv/linear layers,
# we must pass a fake batch of data through it to instantiate them.
tensordict = dummy_env.fake_tensordict()
actor(tensordict)
# we join our actor with an EGreedyModule for data collection
exploration_module = EGreedyModule(
spec=dummy_env.action_spec,
annealing_num_steps=total_frames,
eps_init=eps_greedy_val,
eps_end=eps_greedy_val_env,
)
actor_explore = TensorDictSequential(actor, exploration_module)
return actor, actor_explore
Collecting and storing data¶
Replay buffers¶
Replay buffers play a central role in off-policy RL algorithms such as DQN. They constitute the dataset we will be sampling from during training.
Here, we will use a regular sampling strategy, although a prioritized replay buffer could improve the performance significantly.
We place the storage on disk using the LazyMemmapStorage class. This storage is created in a lazy manner: it will only be instantiated once the first batch of data is passed to it.
The only requirement of this storage is that the data passed to it at write time must always have the same shape.
def get_replay_buffer(buffer_size, n_optim, batch_size):
replay_buffer = TensorDictReplayBuffer(
batch_size=batch_size,
storage=LazyMemmapStorage(buffer_size),
prefetch=n_optim,
)
return replay_buffer
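As a quick, illustrative usage sketch (not part of the training script; the ReplayBufferTrainer hook registered later performs these calls for us), the buffer can be extended with any consistently-shaped TensorDict batch and then sampled:

from tensordict import TensorDict

rb = get_replay_buffer(buffer_size=1_000, n_optim=1, batch_size=16)
# Any batch with a consistent shape works; here a dummy batch of 32 stacked frames.
dummy_batch = TensorDict(
    {"pixels": torch.zeros(32, 4, 64, 64)},
    batch_size=[32],
)
rb.extend(dummy_batch)   # write 32 transitions to the memory-mapped storage
sample = rb.sample()     # read ``batch_size`` (16) transitions back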
Data collector¶
As in PPO and DDPG, we will be using a data collector as a dataloader in the outer loop.
We choose the following configuration: we will be running a series of parallel environments synchronously in parallel in different collectors, which themselves run in parallel but asynchronously.
Note
This feature is only available when running the code within the "spawn" start method of the Python multiprocessing library. If this tutorial is run directly as a script (thereby using the "fork" method), we will be using a regular SyncDataCollector.
The advantage of this configuration is that we can balance the amount of compute that is executed in batch with what we want to be executed asynchronously. We encourage the reader to experiment with how the collection speed is impacted by modifying the number of collectors (i.e. the number of environment constructors passed to the collector) and the number of environments executed in parallel in each collector (controlled by the num_workers hyperparameter).
Collector devices are fully parametrizable through the device (general), policy_device, env_device and storing_device arguments. The storing_device argument modifies the location of the data being collected: if the batches that we are gathering have a considerable size, we may want to store them somewhere other than the device where the computation happens. For asynchronous data collectors such as ours, different storing devices mean that the data we collect will not sit on the same device each time, which is something our training loop must account for. For simplicity, we set the devices to the same value for all sub-collectors.
def get_collector(
stats,
num_collectors,
actor_explore,
frames_per_batch,
total_frames,
device,
):
# We can't use nested child processes with mp_start_method="fork"
if is_fork:
cls = SyncDataCollector
env_arg = make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
else:
cls = MultiaSyncDataCollector
env_arg = [
make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
] * num_collectors
data_collector = cls(
env_arg,
policy=actor_explore,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
# this is the default behavior: the collector runs in ``"random"`` (or explorative) mode
exploration_type=ExplorationType.RANDOM,
# We set the all the devices to be identical. Below is an example of
# heterogeneous devices
device=device,
storing_device=device,
split_trajs=False,
postproc=MultiStep(gamma=gamma, n_steps=5),
)
return data_collector
Loss function¶
Building our loss function is straightforward: we only need to provide the model and a bunch of hyperparameters to the DQNLoss class.
Target parameters¶
Many off-policy RL algorithms use the concept of "target parameters" when it comes to estimating the value of the next state or state-action pair. The target parameters are lagged copies of the model parameters. Because their predictions mismatch those of the current model configuration, they help learning by putting a pessimistic bound on the value being estimated. This is a powerful trick (known as "Double Q-Learning") that is ubiquitous in similar algorithms.
def get_loss_module(actor, gamma):
loss_module = DQNLoss(actor, delay_value=True)
loss_module.make_value_estimator(gamma=gamma)
target_updater = SoftUpdate(loss_module, eps=0.995)
return loss_module, target_updater
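With eps=0.995, the SoftUpdate hook keeps the target network as an exponential moving average of the trained one: to a first approximation, every update performs theta_target <- 0.995 * theta_target + (1 - 0.995) * theta, so the target parameters trail the online parameters with a time constant of roughly 1 / (1 - eps) = 200 updates.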
Hyperparameters¶
Let's start with our hyperparameters. The following setting should work well in practice, and the performance of the algorithm should hopefully not be too sensitive to slight variations of these.
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
# Start method handed to ParallelEnv in ``make_env`` (assumed definition,
# mirroring the system default start method).
mp_context = "fork" if is_fork else "spawn"
Optimizer¶
# the learning rate of the optimizer
lr = 2e-3
# weight decay
wd = 1e-5
# the beta parameters of Adam
betas = (0.9, 0.999)
# Optimization steps per batch collected (aka UPD or updates per data)
n_optim = 8
DQN parameters¶
Gamma decay factor
gamma = 0.99
Smooth target network update decay parameter. This loosely corresponds to a 1/tau interval with hard target network updates.
tau = 0.02
Data collection and replay buffer¶
Note
Values to be used for proper training have been commented out.
Total frames collected in the environment. In other implementations, the user defines a maximum number of episodes. This is harder to do with our data collectors, since they return batches of N collected frames where N is a constant. However, one can easily enforce the same limit on the number of episodes by breaking the training loop once a certain number of episodes has been collected (a sketch of this is given right after the hyperparameter below).
total_frames = 5_000 # 500000
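As a hedged sketch of that episode-based alternative (not part of this tutorial's script; the episode budget and the manual loop below are assumptions), one could count finished episodes in each collected batch and stop once a budget is reached:

# Illustrative only: cap training by episode count instead of frame count.
# Batches yielded by TorchRL collectors carry a ("next", "done") flag that
# marks the end of each trajectory.
max_episodes = 500  # hypothetical episode budget
episodes_collected = 0
for batch in collector:
    episodes_collected += batch["next", "done"].sum().item()
    # ... replay-buffer extension and optimization steps would go here ...
    if episodes_collected >= max_episodes:
        break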
Random frames used to initialize the replay buffer.
init_random_frames = 100 # 1000
Frames in each collected batch.
frames_per_batch = 32 # 128
Frames sampled from the replay buffer at each optimization step.
batch_size = 32 # 256
Size of the replay buffer in terms of frames.
buffer_size = min(total_frames, 100000)
Number of environments run in parallel in each data collector.
num_workers = 2 # 8
num_collectors = 2 # 4
Environment and exploration¶
We set the initial and final values of the epsilon factor in the epsilon-greedy exploration. Since our policy is deterministic, exploration is crucial: without it, the only source of randomness would be the environment reset.
eps_greedy_val = 0.1
eps_greedy_val_env = 0.005
To speed up learning, we set the bias of the last layer of our value network to a predefined value (this is not mandatory).
init_bias = 2.0
Note
For fast rendering of the tutorial, the total_frames hyperparameter was set to a very low number. To get a reasonable performance, use a greater value, e.g. 500000.
Building the Trainer¶
The trainer is built with the following keyword-only arguments:
collector;
loss_module;
optimizer;
logger: a logger, such as the CSVLogger used below, that records the training scalars;
total_frames: this parameter defines the lifespan of the trainer;
frame_skip: when a frame-skip is used, the collector must be made aware of it in order to accurately count the number of frames collected, etc. Making the trainer aware of this parameter is not mandatory, but it helps to have a fairer comparison between settings where the total number of frames (budget) is fixed but the frame-skip is variable.
stats = get_norm_stats()
test_env = make_env(parallel=False, obs_norm_sd=stats)
# Get model
actor, actor_explore = make_model(test_env)
loss_module, target_net_updater = get_loss_module(actor, gamma)
collector = get_collector(
stats=stats,
num_collectors=num_collectors,
actor_explore=actor_explore,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
)
optimizer = torch.optim.Adam(
loss_module.parameters(), lr=lr, weight_decay=wd, betas=betas
)
exp_name = f"dqn_exp_{uuid.uuid1()}"
tmpdir = tempfile.TemporaryDirectory()
logger = CSVLogger(exp_name=exp_name, log_dir=tmpdir.name)
warnings.warn(f"log dir: {logger.experiment.log_dir}")
state dict of the observation norm: OrderedDict([('standard_normal', tensor(True)), ('loc', tensor([[[0.9895]],
[[0.9895]],
[[0.9895]],
[[0.9895]]])), ('scale', tensor([[[0.0737]],
[[0.0737]],
[[0.0737]],
[[0.0737]]]))])
We can control how often the scalars should be logged. Here we set this to a low value as our training loop is short:
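A minimal sketch of the log interval and of how the trainer is then assembled from the pieces above (the exact argument set is assumed from the list in the "Building the Trainer" section; optim_steps_per_batch sets how many optimization steps are run per collected batch):

log_interval = 500

trainer = Trainer(
    collector=collector,
    total_frames=total_frames,
    frame_skip=1,
    loss_module=loss_module,
    optimizer=optimizer,
    logger=logger,
    optim_steps_per_batch=n_optim,
    log_interval=log_interval,
)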
Registering hooks¶
Registering hooks can be achieved in two separate ways:
If the hook has one, the register() method is the first choice: one just needs to provide the trainer as input, and the hook will be registered with a default name at a default location. For some hooks, the registration can be quite complex: ReplayBufferTrainer requires 3 hooks (extend, sample and update_priority), which can be cumbersome to implement.
buffer_hook = ReplayBufferTrainer(
get_replay_buffer(buffer_size, n_optim, batch_size=batch_size),
flatten_tensordicts=True,
)
buffer_hook.register(trainer)
weight_updater = UpdateWeights(collector, update_weights_interval=1)
weight_updater.register(trainer)
recorder = Recorder(
record_interval=100, # log every 100 optimization steps
record_frames=1000, # maximum number of frames in the record
frame_skip=1,
policy_exploration=actor_explore,
environment=test_env,
exploration_type=ExplorationType.DETERMINISTIC,
log_keys=[("next", "reward")],
out_keys={("next", "reward"): "rewards"},
log_pbar=True,
)
recorder.register(trainer)
The exploration module's epsilon factor is also annealed:
trainer.register_op("post_steps", actor_explore[1].step, frames=frames_per_batch)
Any callable (including TrainerHookBase subclasses) can be registered using register_op(). In this case, a location must be passed explicitly. This method gives more control over the placement of the hook, but it also requires more understanding of the trainer mechanism. Check the trainer documentation for a detailed description of the trainer hooks.
trainer.register_op("post_optim", target_net_updater.step)
We can log the training rewards too. Note that this is of limited interest with CartPole, as rewards are always 1. The discounted sum of rewards is maximised not by getting higher rewards but by keeping the cart-pole alive for longer. This will be reflected by the total_rewards value displayed in the progress bar.
log_reward = LogReward(log_pbar=True)
log_reward.register(trainer)
Here we are, ready to train our algorithm! A simple call to trainer.train() and we'll be getting our results logged in.
trainer.train()
0%| | 0/5000 [00:00<?, ?it/s]
1%| | 32/5000 [00:07<20:39, 4.01it/s]
r_training: 0.3688, rewards: 0.1000, total_rewards: 0.9434: 1%| | 32/5000 [00:07<20:39, 4.01it/s]
...
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 100%|█████████▉| 4992/5000 [01:10<00:00, 91.37it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556: : 5024it [01:10, 90.51it/s]
We can now quickly check the CSVs with the results.
def print_csv_files_in_folder(folder_path):
"""
Finds all CSV files in a folder and prints the first 10 lines of each file.
Args:
folder_path (str): The relative path to the folder.
"""
csv_files = []
output_str = ""
for dirpath, _, filenames in os.walk(folder_path):
for file in filenames:
if file.endswith(".csv"):
csv_files.append(os.path.join(dirpath, file))
for csv_file in csv_files:
output_str += f"File: {csv_file}\n"
with open(csv_file, "r") as f:
for i, line in enumerate(f):
if i == 10:
break
output_str += line.strip() + "\n"
output_str += "\n"
print(output_str)
print_csv_files_in_folder(logger.experiment.log_dir)
File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/r_training.csv
512,0.3566153347492218
1024,0.39912936091423035
1536,0.39912936091423035
2048,0.39912936091423035
2560,0.42945271730422974
3072,0.40213119983673096
3584,0.39912933111190796
4096,0.42945271730422974
4608,0.42945271730422974
File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/optim_steps.csv
512,128.0
1024,256.0
1536,384.0
2048,512.0
2560,640.0
3072,768.0
3584,896.0
4096,1024.0
4608,1152.0
File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/loss.csv
512,0.47876793146133423
1024,0.18667784333229065
1536,0.1948033571243286
2048,0.22345909476280212
2560,0.2145865112543106
3072,0.47586697340011597
3584,0.28343674540519714
4096,0.3203103542327881
4608,0.3053428530693054
File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/grad_norm_0.csv
512,5.5816755294799805
1024,2.9089717864990234
1536,3.4687838554382324
2048,2.8756051063537598
2560,2.7815587520599365
3072,6.685841083526611
3584,3.793360948562622
4096,3.469670295715332
4608,3.317387104034424
File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/rewards.csv
3232,0.10000000894069672
File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/total_rewards.csv
3232,5.555555820465088
Conclusion and possible improvements¶
In this tutorial we have learned:
how to write a Trainer, including building its components and registering them in the trainer;
how to code a DQN algorithm, including how to create a policy that picks the action with the highest value with
QValueNetwork
;
how to build a multiprocessed data collector.
Possible improvements to this tutorial could include:
A prioritized replay buffer could also be used. This gives a higher priority to the samples with the worst value accuracy. Learn more in the replay buffer section of the documentation; a sketch of this change is given below.
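As a hedged sketch of that improvement (assuming the TensorDictPrioritizedReplayBuffer class and its alpha / beta / priority_key arguments; the exponent values are illustrative), the replay-buffer helper above could be swapped for something along these lines:

from torchrl.data import TensorDictPrioritizedReplayBuffer


def get_prioritized_replay_buffer(buffer_size, n_optim, batch_size):
    # Samples transitions with probability proportional to their TD-error
    # (exponent ``alpha``) and corrects the induced bias with
    # importance-sampling weights (exponent ``beta``).
    return TensorDictPrioritizedReplayBuffer(
        alpha=0.7,
        beta=0.5,
        priority_key="td_error",
        storage=LazyMemmapStorage(buffer_size),
        batch_size=batch_size,
        prefetch=n_optim,
    )

The ReplayBufferTrainer hook would then keep the priorities up to date through its update_priority call.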
Total running time of the script: (2 minutes 40.957 seconds)
Estimated memory usage: 1267 MB