注意
转到末尾 以下载完整示例代码。
开始使用您自己的第一个训练循环¶
作者: Vincent Moens
注意
如需在笔记本中运行本教程,请在开头添加一个安装单元,内容为:
!pip install tensordict !pip install torchrl
是时候总结我们在本“入门”系列中迄今所学的全部内容了!
在本教程中,我们将仅使用之前课程中介绍的组件,编写最基础的训练循环。
我们将使用DQN与CartPole环境作为典型示例。
我们将自愿将说明性文字精简至最少,仅将各部分内容链接至相关教程。
构建环境¶
我们将使用一个带有 StepCounter 变换的 Gym 环境。如果您需要复习相关内容,请参阅 环境教程 中介绍的这些特性。
import torch
torch.manual_seed(0)
import time
from torchrl.envs import GymEnv, StepCounter, TransformedEnv
env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
env.set_seed(0)
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
设计策略¶
下一步是构建我们的策略。 我们将创建一个常规的、确定性的执行器版本,用于 损失模块以及 评估阶段。 接下来,我们将为其添加一个探索模块,以支持 推理。
from torchrl.modules import EGreedyModule, MLP, QValueModule
value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64])
value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_explore = Seq(policy, exploration_module)
数据收集器和重放缓冲区¶
数据部分如下:我们需要一个 数据收集器来轻松获取数据批次, 以及一个经验回放缓冲区来存储这些数据以供训练使用。
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(
env,
policy,
frames_per_batch=frames_per_batch,
total_frames=-1,
init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))
from torch.optim import Adam
损失模块和优化器¶
我们按照专用教程中所述构建损失函数,并配置其优化器和目标参数更新器:
from torchrl.objectives import DQNLoss, SoftUpdate
loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters(), lr=0.02)
updater = SoftUpdate(loss, eps=0.99)
日志记录器¶
我们将使用CSV记录器来记录我们的结果,并保存渲染的视频。
from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder
path = "./training_loop"
logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
record_env = TransformedEnv(
GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder
)
训练循环¶
我们不会固定训练迭代次数,而是持续训练网络,直至其达到某一性能指标(此处任意定义为在环境中运行200步——对于CartPole任务,成功即定义为获得更长的轨迹)。
total_count = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
# Write data in replay buffer
rb.extend(data)
max_length = rb[:]["next", "step_count"].max()
if len(rb) > init_rand_steps:
# Optim loop (we do several optim steps
# per batch collected for efficiency)
for _ in range(optim_steps):
sample = rb.sample(128)
loss_vals = loss(sample)
loss_vals["loss"].backward()
optim.step()
optim.zero_grad()
# Update exploration factor
exploration_module.step(data.numel())
# Update target params
updater.step()
if i % 10:
torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
total_count += data.numel()
total_episodes += data["next", "done"].sum()
if max_length > 200:
break
t1 = time.time()
torchrl_logger.info(
f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s."
)
渲染¶
最后,我们尽可能多地运行环境步数,并将视频本地保存(请注意,此时并未进行探索)。
record_env.rollout(max_steps=1000, policy=policy)
video_recorder.dump()
这是完成完整训练循环后,您渲染出的 CartPole 视频效果:

本系列“TorchRL 入门”教程到此结束! 欢迎在 GitHub 上分享您对本教程的反馈。
脚本总运行时间: (0 分钟 20.681 秒)
估计内存使用量: 165 MB