
Exporting TorchRL modules

Author: Vincent Moens

Note

To run this tutorial in a notebook, add an installation cell at the beginning containing:

!pip install tensordict
!pip install torchrl
!pip install "gymnasium[atari,accept-rom-license]<1.0.0"

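Introduction

Learning a policy has little value if that policy cannot be deployed in real-world settings. As shown in other tutorials, TorchRL has a strong focus on modularity and composability. Thanks to tensordict, the components of the library can be written in the most generic way possible, because each signature is abstracted to a set of operations on an input TensorDict. This may give the impression that the library is only usable for training. Typical low-level execution hardware (edge devices, robots, Arduino, Raspberry Pi) does not execute Python code, let alone have PyTorch, tensordict or torchrl installed.

Fortunately, PyTorch provides a full ecosystem of solutions for exporting code and trained models to devices and hardware, and TorchRL is fully equipped to interact with it. You can choose from a varied set of backends, including ONNX and AOTInductor, which are used in this tutorial's examples. This tutorial gives a quick overview of how to isolate a trained model and ship it to hardware as a standalone executable.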

Key learnings:

  • exporting any TorchRL module after training;

  • using various backends;

  • testing your exported model.

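Fast recap: a simple TorchRL training loop

In this section, we reproduce the training loop from the previous Getting Started tutorial. It is slightly adapted to work with Atari games as rendered by the gymnasium library. We stick to the DQN example, and later show how to use a policy that outputs a distribution over values instead.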

import time
from pathlib import Path

import numpy as np

import torch

from tensordict.nn import (
    TensorDictModule as Mod,
    TensorDictSequential,
    TensorDictSequential as Seq,
)

from torch.optim import Adam

from torchrl._utils import timeit
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

from torchrl.envs import (
    Compose,
    GrayScale,
    GymEnv,
    Resize,
    set_exploration_type,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)

from torchrl.modules import ConvNet, EGreedyModule, QValueModule

from torchrl.objectives import DQNLoss, SoftUpdate

torch.manual_seed(0)

env = TransformedEnv(
    GymEnv("ALE/Pong-v5", categorical_action_encoding=True),
    Compose(
        ToTensorImage(), Resize(84, interpolation="nearest"), GrayScale(), StepCounter()
    ),
)
env.set_seed(0)

value_mlp = ConvNet.default_atari_dqn(num_actions=env.action_spec.space.n)
value_net = Mod(value_mlp, in_keys=["pixels"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_explore = Seq(policy, exploration_module)

init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(
    env,
    policy_explore,
    frames_per_batch=frames_per_batch,
    total_frames=-1,
    init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))

loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters())
updater = SoftUpdate(loss, eps=0.99)

total_count = 0
total_episodes = 0
t0 = time.time()
for data in collector:
    # Write data in replay buffer
    rb.extend(data)
    max_length = rb[:]["next", "step_count"].max()
    if len(rb) > init_rand_steps:
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update target params
            updater.step()
        # Update exploration factor and counters once per collected batch,
        # so that frames and episodes are not counted optim_steps times
        exploration_module.step(data.numel())
        total_count += data.numel()
        total_episodes += data["next", "done"].sum()
    if max_length > 200:
        break

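Exporting a TensorDictModule-based policy

TensorDict allowed us to build a policy with great flexibility. We started from a regular Module that outputs action values given an observation. We then appended a QValueModule, which reads these values and computes an action with a heuristic (here, an argmax over the action values).

There is, however, a small technical catch in our case. The environment (the actual Atari game) does not return grayscale 84x84 images: it returns raw, screen-size color frames. The transforms we attached to the environment make sure the images can be read by the model. From the training perspective, the boundary between environment and model is blurry. At execution time, things are much clearer: the model itself should transform the input data (images) into a format that our CNN can process.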
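The sketch below makes this requirement concrete in plain PyTorch. It assumes a raw uint8 frame of shape [210, 160, 3], as returned by Pong. The function name `preprocess` is hypothetical, and this is not TorchRL's transform code. The grayscale weights (0.2989, 0.587, 0.114) are the ones that appear in the exported graph further down.

```python
import torch
import torch.nn.functional as F


def preprocess(pixels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: raw uint8 [H, W, 3] frame -> float [1, 84, 84] grayscale image."""
    # ToTensorImage-like step: channels-last uint8 -> channels-first float in [0, 1]
    x = pixels.permute(2, 0, 1).float() / 255.0
    # Resize-like step: F.interpolate expects a batch dimension, so add and remove one
    x = F.interpolate(x.unsqueeze(0), size=(84, 84), mode="nearest").squeeze(0)
    # GrayScale-like step: weighted sum of the R, G and B channels
    r, g, b = x.unbind(0)
    gray = 0.2989 * r + 0.587 * g + 0.114 * b
    # Keep a leading channel dimension so that the CNN receives [1, 84, 84]
    return gray.unsqueeze(0)


frame = torch.randint(0, 256, (210, 160, 3), dtype=torch.uint8)
print(preprocess(frame).shape)  # torch.Size([1, 84, 84])
```

Every step is an ordinary tensor operation, which is why such transforms can live inside the exported model rather than in the environment.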

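Once again, tensordict comes to the rescue. Most of TorchRL's local (non-recursive) transforms can be used either as environment transforms or as preprocessing blocks within a Module instance. Let's see how we can prepend them to our policy: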

policy_transform = TensorDictSequential(
    # The last transform is a StepCounter, which we don't need for preprocessing
    env.transform[:-1],
    # Using the explorative version of the policy for didactic purposes, see below.
    policy_explore.requires_grad_(False),
)

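We create a fake input and pass it to torch.export.export(), together with the policy. This produces an ExportedProgram: a "raw" graph that reads our input tensor and outputs an action. The graph holds no reference to TorchRL or tensordict modules.

It is good practice to call select_out_keys() to tell the model which outputs we want (in case the policy returns more than one tensor).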

fake_td = env.base_env.fake_tensordict()
pixels = fake_td["pixels"]
with set_exploration_type("DETERMINISTIC"):
    exported_policy = torch.export.export(
        # Select only the "action" output key
        policy_transform.select_out_keys("action"),
        args=(),
        kwargs={"pixels": pixels},
        strict=False,
    )

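The representation of the policy is quite informative. The first operations are the preprocessing transforms: a permute, a div, an unsqueeze and a nearest-neighbour resize, followed by a weighted sum of the color channels for grayscale conversion. Then come the convolutional and MLP layers, and finally the argmax that selects the action.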

print("Deterministic policy")
exported_policy.graph_module.print_readable()
Deterministic policy
class GraphModule(torch.nn.Module):
    def forward(self, p_module_1_module_0_module_0_module_0_0_weight: "f32[32, 1, 8, 8]", p_module_1_module_0_module_0_module_0_0_bias: "f32[32]", p_module_1_module_0_module_0_module_0_2_weight: "f32[64, 32, 4, 4]", p_module_1_module_0_module_0_module_0_2_bias: "f32[64]", p_module_1_module_0_module_0_module_0_4_weight: "f32[64, 64, 3, 3]", p_module_1_module_0_module_0_module_0_4_bias: "f32[64]", p_module_1_module_0_module_0_module_1_0_weight: "f32[512, 3136]", p_module_1_module_0_module_0_module_1_0_bias: "f32[512]", p_module_1_module_0_module_0_module_1_2_weight: "f32[6, 512]", p_module_1_module_0_module_0_module_1_2_bias: "f32[6]", b_module_1_module_1_eps_init: "f32[]", b_module_1_module_1_eps_end: "f32[]", b_module_1_module_1_eps: "f32[]", kwargs_pixels: "u8[210, 160, 3]"):
         # File: /pytorch/rl/torchrl/envs/transforms/transforms.py:308 in forward, code: data = self._apply_transform(data)
        permute: "u8[3, 210, 160]" = torch.ops.aten.permute.default(kwargs_pixels, [-1, -3, -2]);  kwargs_pixels = None
        div: "f32[3, 210, 160]" = torch.ops.aten.div.Tensor(permute, 255);  permute = None
        to: "f32[3, 210, 160]" = torch.ops.aten.to.dtype(div, torch.float32);  div = None
        unsqueeze: "f32[1, 3, 210, 160]" = torch.ops.aten.unsqueeze.default(to, 0);  to = None
        upsample_nearest2d: "f32[1, 3, 84, 84]" = torch.ops.aten.upsample_nearest2d.vec(unsqueeze, [84, 84], None);  unsqueeze = None
        squeeze: "f32[3, 84, 84]" = torch.ops.aten.squeeze.dim(upsample_nearest2d, 0);  upsample_nearest2d = None
        unbind = torch.ops.aten.unbind.int(squeeze, -3);  squeeze = None
        getitem: "f32[84, 84]" = unbind[0]
        getitem_1: "f32[84, 84]" = unbind[1]
        getitem_2: "f32[84, 84]" = unbind[2];  unbind = None
        mul: "f32[84, 84]" = torch.ops.aten.mul.Tensor(getitem, 0.2989);  getitem = None
        mul_1: "f32[84, 84]" = torch.ops.aten.mul.Tensor(getitem_1, 0.587);  getitem_1 = None
        add: "f32[84, 84]" = torch.ops.aten.add.Tensor(mul, mul_1);  mul = mul_1 = None
        mul_2: "f32[84, 84]" = torch.ops.aten.mul.Tensor(getitem_2, 0.114);  getitem_2 = None
        add_1: "f32[84, 84]" = torch.ops.aten.add.Tensor(add, mul_2);  add = mul_2 = None
        to_1: "f32[84, 84]" = torch.ops.aten.to.dtype(add_1, torch.float32);  add_1 = None
        unsqueeze_1: "f32[1, 84, 84]" = torch.ops.aten.unsqueeze.default(to_1, -3);  to_1 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/conv.py:554 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d: "f32[32, 20, 20]" = torch.ops.aten.conv2d.default(unsqueeze_1, p_module_1_module_0_module_0_module_0_0_weight, p_module_1_module_0_module_0_module_0_0_bias, [4, 4]);  unsqueeze_1 = p_module_1_module_0_module_0_module_0_0_weight = p_module_1_module_0_module_0_module_0_0_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu: "f32[32, 20, 20]" = torch.ops.aten.relu.default(conv2d);  conv2d = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/conv.py:554 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d_1: "f32[64, 9, 9]" = torch.ops.aten.conv2d.default(relu, p_module_1_module_0_module_0_module_0_2_weight, p_module_1_module_0_module_0_module_0_2_bias, [2, 2]);  relu = p_module_1_module_0_module_0_module_0_2_weight = p_module_1_module_0_module_0_module_0_2_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu_1: "f32[64, 9, 9]" = torch.ops.aten.relu.default(conv2d_1);  conv2d_1 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/conv.py:554 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d_2: "f32[64, 7, 7]" = torch.ops.aten.conv2d.default(relu_1, p_module_1_module_0_module_0_module_0_4_weight, p_module_1_module_0_module_0_module_0_4_bias);  relu_1 = p_module_1_module_0_module_0_module_0_4_weight = p_module_1_module_0_module_0_module_0_4_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu_2: "f32[64, 7, 7]" = torch.ops.aten.relu.default(conv2d_2);  conv2d_2 = None

         # File: /pytorch/rl/torchrl/modules/models/utils.py:86 in forward, code: value = value.flatten(-self.ndims_in, -1)
        flatten: "f32[3136]" = torch.ops.aten.flatten.using_ints(relu_2, -3);  relu_2 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear: "f32[512]" = torch.ops.aten.linear.default(flatten, p_module_1_module_0_module_0_module_1_0_weight, p_module_1_module_0_module_0_module_1_0_bias);  flatten = p_module_1_module_0_module_0_module_1_0_weight = p_module_1_module_0_module_0_module_1_0_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu_3: "f32[512]" = torch.ops.aten.relu.default(linear);  linear = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear_1: "f32[6]" = torch.ops.aten.linear.default(relu_3, p_module_1_module_0_module_0_module_1_2_weight, p_module_1_module_0_module_0_module_1_2_bias);  relu_3 = p_module_1_module_0_module_0_module_1_2_weight = p_module_1_module_0_module_0_module_1_2_bias = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/actors.py:616 in forward, code: action = self.action_func_mapping[self.action_space](action_values)
        argmax: "i64[]" = torch.ops.aten.argmax.default(linear_1, -1)
        to_2: "i64[]" = torch.ops.aten.to.dtype(argmax, torch.int64);  argmax = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/actors.py:621 in forward, code: chosen_action_value = action_value_func(action_values, action)
        unsqueeze_2: "i64[1]" = torch.ops.aten.unsqueeze.default(to_2, -1)
        gather: "f32[1]" = torch.ops.aten.gather.default(linear_1, -1, unsqueeze_2);  linear_1 = unsqueeze_2 = gather = None
        return (to_2,)
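The tail of this graph (argmax, then gather) is the QValueModule heuristic. Stripped of tensordict, it amounts roughly to the following plain-PyTorch sketch. `q_value_action` is a hypothetical name, not the actual QValueModule implementation.

```python
import torch


def q_value_action(action_values: torch.Tensor):
    """Hypothetical sketch of the greedy Q-value heuristic for a discrete action space."""
    # Greedy action: index of the largest action value along the last dimension
    action = action_values.argmax(dim=-1)
    # Value of the chosen action, gathered back from the action-value vector
    chosen_value = action_values.gather(-1, action.unsqueeze(-1))
    return action, chosen_value


q_values = torch.tensor([0.1, 2.0, -1.0, 0.5, 0.0, 1.5])  # 6 actions, as in Pong
q_action, q_chosen = q_value_action(q_values)
print(q_action, q_chosen)  # tensor(1) tensor([2.])
```

In the exported graph above, the gathered value is computed but never returned (`gather = None`), because select_out_keys("action") kept only the action.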



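Lastly, we can execute the policy with a dummy input. For a single image, the output should be an integer between 0 and 5, representing one of the 6 actions available in the game.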

output = exported_policy.module()(pixels=pixels)
print("Exported module output", output)
Exported module output tensor(1)

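More details about exporting TensorDictModule instances can be found in the tensordict documentation.

Note

Exporting modules that take and output nested keys is perfectly fine. Each keyword argument is the "_".join(key) version of the key: the ("group0", "agent0", "obs") key corresponds to the "group0_agent0_obs" keyword argument. Colliding keys (e.g., ("group0_agent0", "obs") and ("group0", "agent0_obs")) can lead to undefined behavior and should be avoided at all costs. Key names must also always produce valid keyword arguments, i.e., they should not contain special characters such as spaces or commas.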
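The naming rule and its two pitfalls can be checked before exporting. The sketch below uses hypothetical helpers (`key_to_kwarg`, `check_export_keys`) that are not part of TorchRL or tensordict. It assumes keys are either strings or tuples of strings.

```python
from collections import Counter


def key_to_kwarg(key) -> str:
    """Hypothetical helper: flatten a (possibly nested) key with "_".join(key)."""
    return key if isinstance(key, str) else "_".join(key)


def check_export_keys(keys) -> list:
    """Hypothetical pre-export check for invalid and colliding keyword names."""
    names = [key_to_kwarg(key) for key in keys]
    # Every flattened name must be usable as a Python keyword argument
    invalid = [name for name in names if not name.isidentifier()]
    if invalid:
        raise ValueError(f"keys do not produce valid keyword arguments: {invalid}")
    # Two different nested keys must not flatten to the same name
    collisions = [name for name, count in Counter(names).items() if count > 1]
    if collisions:
        raise ValueError(f"colliding keyword arguments: {collisions}")
    return names


print(check_export_keys([("group0", "agent0", "obs"), "pixels"]))
# ['group0_agent0_obs', 'pixels']
```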

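torch.export has many other features, which we explore further below. First, a short digression on two topics: exploration and stochastic policies during test-time inference, and recurrent policies.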

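Working with stochastic policies

Above, we used the set_exploration_type context manager to control the behavior of the policy. A policy can be stochastic: for example, it outputs a distribution over the action space, as in PPO and other on-policy algorithms. It can also be explorative, with an exploration module appended such as E-Greedy, additive Gaussian or Ornstein-Uhlenbeck noise. In either case, we may or may not want that exploration strategy in the exported version.

Fortunately, the export utilities understand this context manager. The policy's Python code is traced once at export time, so the exploration type active during export decides which branch is recorded in the graph. As long as the export happens within the right context manager, the exported policy should behave as requested. To demonstrate this, let us try another exploration type: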

with set_exploration_type("RANDOM"):
    exported_stochastic_policy = torch.export.export(
        policy_transform.select_out_keys("action"),
        args=(),
        kwargs={"pixels": pixels},
        strict=False,
    )

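Unlike the previous version, our exported policy now contains a random module at the end of the call stack. The last operations do three things:

  • draw a uniform random number and compare it to epsilon to build a random mask;

  • draw a random integer between 0 and 5 (randint(0, 6) excludes its upper bound);

  • select either the network output or the random action, depending on the mask.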

print("Stochastic policy")
exported_stochastic_policy.graph_module.print_readable()
Stochastic policy
class GraphModule(torch.nn.Module):
    def forward(self, p_module_1_module_0_module_0_module_0_0_weight: "f32[32, 1, 8, 8]", p_module_1_module_0_module_0_module_0_0_bias: "f32[32]", p_module_1_module_0_module_0_module_0_2_weight: "f32[64, 32, 4, 4]", p_module_1_module_0_module_0_module_0_2_bias: "f32[64]", p_module_1_module_0_module_0_module_0_4_weight: "f32[64, 64, 3, 3]", p_module_1_module_0_module_0_module_0_4_bias: "f32[64]", p_module_1_module_0_module_0_module_1_0_weight: "f32[512, 3136]", p_module_1_module_0_module_0_module_1_0_bias: "f32[512]", p_module_1_module_0_module_0_module_1_2_weight: "f32[6, 512]", p_module_1_module_0_module_0_module_1_2_bias: "f32[6]", b_module_1_module_1_eps_init: "f32[]", b_module_1_module_1_eps_end: "f32[]", b_module_1_module_1_eps: "f32[]", kwargs_pixels: "u8[210, 160, 3]"):
         # File: /pytorch/rl/torchrl/envs/transforms/transforms.py:308 in forward, code: data = self._apply_transform(data)
        permute: "u8[3, 210, 160]" = torch.ops.aten.permute.default(kwargs_pixels, [-1, -3, -2]);  kwargs_pixels = None
        div: "f32[3, 210, 160]" = torch.ops.aten.div.Tensor(permute, 255);  permute = None
        to: "f32[3, 210, 160]" = torch.ops.aten.to.dtype(div, torch.float32);  div = None
        unsqueeze: "f32[1, 3, 210, 160]" = torch.ops.aten.unsqueeze.default(to, 0);  to = None
        upsample_nearest2d: "f32[1, 3, 84, 84]" = torch.ops.aten.upsample_nearest2d.vec(unsqueeze, [84, 84], None);  unsqueeze = None
        squeeze: "f32[3, 84, 84]" = torch.ops.aten.squeeze.dim(upsample_nearest2d, 0);  upsample_nearest2d = None
        unbind = torch.ops.aten.unbind.int(squeeze, -3);  squeeze = None
        getitem: "f32[84, 84]" = unbind[0]
        getitem_1: "f32[84, 84]" = unbind[1]
        getitem_2: "f32[84, 84]" = unbind[2];  unbind = None
        mul: "f32[84, 84]" = torch.ops.aten.mul.Tensor(getitem, 0.2989);  getitem = None
        mul_1: "f32[84, 84]" = torch.ops.aten.mul.Tensor(getitem_1, 0.587);  getitem_1 = None
        add: "f32[84, 84]" = torch.ops.aten.add.Tensor(mul, mul_1);  mul = mul_1 = None
        mul_2: "f32[84, 84]" = torch.ops.aten.mul.Tensor(getitem_2, 0.114);  getitem_2 = None
        add_1: "f32[84, 84]" = torch.ops.aten.add.Tensor(add, mul_2);  add = mul_2 = None
        to_1: "f32[84, 84]" = torch.ops.aten.to.dtype(add_1, torch.float32);  add_1 = None
        unsqueeze_1: "f32[1, 84, 84]" = torch.ops.aten.unsqueeze.default(to_1, -3);  to_1 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/conv.py:554 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d: "f32[32, 20, 20]" = torch.ops.aten.conv2d.default(unsqueeze_1, p_module_1_module_0_module_0_module_0_0_weight, p_module_1_module_0_module_0_module_0_0_bias, [4, 4]);  unsqueeze_1 = p_module_1_module_0_module_0_module_0_0_weight = p_module_1_module_0_module_0_module_0_0_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu: "f32[32, 20, 20]" = torch.ops.aten.relu.default(conv2d);  conv2d = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/conv.py:554 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d_1: "f32[64, 9, 9]" = torch.ops.aten.conv2d.default(relu, p_module_1_module_0_module_0_module_0_2_weight, p_module_1_module_0_module_0_module_0_2_bias, [2, 2]);  relu = p_module_1_module_0_module_0_module_0_2_weight = p_module_1_module_0_module_0_module_0_2_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu_1: "f32[64, 9, 9]" = torch.ops.aten.relu.default(conv2d_1);  conv2d_1 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/conv.py:554 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d_2: "f32[64, 7, 7]" = torch.ops.aten.conv2d.default(relu_1, p_module_1_module_0_module_0_module_0_4_weight, p_module_1_module_0_module_0_module_0_4_bias);  relu_1 = p_module_1_module_0_module_0_module_0_4_weight = p_module_1_module_0_module_0_module_0_4_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu_2: "f32[64, 7, 7]" = torch.ops.aten.relu.default(conv2d_2);  conv2d_2 = None

         # File: /pytorch/rl/torchrl/modules/models/utils.py:86 in forward, code: value = value.flatten(-self.ndims_in, -1)
        flatten: "f32[3136]" = torch.ops.aten.flatten.using_ints(relu_2, -3);  relu_2 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear: "f32[512]" = torch.ops.aten.linear.default(flatten, p_module_1_module_0_module_0_module_1_0_weight, p_module_1_module_0_module_0_module_1_0_bias);  flatten = p_module_1_module_0_module_0_module_1_0_weight = p_module_1_module_0_module_0_module_1_0_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu_3: "f32[512]" = torch.ops.aten.relu.default(linear);  linear = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear_1: "f32[6]" = torch.ops.aten.linear.default(relu_3, p_module_1_module_0_module_0_module_1_2_weight, p_module_1_module_0_module_0_module_1_2_bias);  relu_3 = p_module_1_module_0_module_0_module_1_2_weight = p_module_1_module_0_module_0_module_1_2_bias = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/actors.py:616 in forward, code: action = self.action_func_mapping[self.action_space](action_values)
        argmax: "i64[]" = torch.ops.aten.argmax.default(linear_1, -1)
        to_2: "i64[]" = torch.ops.aten.to.dtype(argmax, torch.int64);  argmax = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/actors.py:621 in forward, code: chosen_action_value = action_value_func(action_values, action)
        unsqueeze_2: "i64[1]" = torch.ops.aten.unsqueeze.default(to_2, -1)
        gather: "f32[1]" = torch.ops.aten.gather.default(linear_1, -1, unsqueeze_2);  linear_1 = unsqueeze_2 = gather = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/exploration.py:167 in forward, code: cond = torch.rand(action_tensordict.shape, device=out.device) < eps
        rand: "f32[]" = torch.ops.aten.rand.default([], device = device(type='cpu'), pin_memory = False)
        lt: "b8[]" = torch.ops.aten.lt.Tensor(rand, b_module_1_module_1_eps);  rand = b_module_1_module_1_eps = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/exploration.py:169 in forward, code: cond = expand_as_right(cond, out)
        expand: "b8[]" = torch.ops.aten.expand.default(lt, []);  lt = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/exploration.py:193 in forward, code: r = spec.rand()
        randint: "i64[]" = torch.ops.aten.randint.low(0, 6, [], device = device(type='cpu'), pin_memory = False)

         # File: /pytorch/rl/torchrl/modules/tensordict_module/exploration.py:196 in forward, code: out = torch.where(cond, r, out)
        where: "i64[]" = torch.ops.aten.where.self(expand, randint, to_2);  expand = randint = to_2 = None
        return (where,)



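Handling recurrent policies

Another typical use case is a recurrent policy, which outputs an action together with one or more recurrent states. LSTM and GRU are CuDNN-based modules, which means that they behave differently from regular Module instances (export utilities may not trace them well). Fortunately, TorchRL provides a Python implementation of these modules that can be swapped in for the CuDNN version when needed.

To show this, let us write a prototypical policy that relies on an RNN: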

from tensordict.nn import TensorDictModule
from torchrl.envs import BatchSizeTransform
from torchrl.modules import LSTMModule, MLP

lstm = LSTMModule(
    input_size=32,
    num_layers=2,
    hidden_size=256,
    in_keys=["observation", "hidden0", "hidden1"],
    out_keys=["intermediate", "hidden0", "hidden1"],
)

If the LSTM module is backed by CuDNN (LSTM) rather than by Python code, the make_python_based() method switches it to the Python implementation:

lstm = lstm.make_python_based()
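To make the distinction concrete, the sketch below writes one LSTM cell step using only elementary tensor operations (linear, chunk, sigmoid, tanh), which is the kind of Python-level code that export tools can trace. It is an illustration under our own assumptions, not TorchRL's implementation: the helper name python_lstm_cell is hypothetical, and the gate order (input, forget, cell candidate, output) is the one used by torch.nn.LSTMCell, against which the result is checked.

```python
import torch
import torch.nn.functional as F


def python_lstm_cell(x, h, c, w_ih, w_hh, b_ih, b_hh):
    """One LSTM step built from elementary ops (hypothetical helper).

    Gate order follows torch.nn.LSTMCell: input (i), forget (f),
    cell candidate (g), output (o).
    """
    # Both projections are summed, then split into the four gates
    gates = F.linear(x, w_ih, b_ih) + F.linear(h, w_hh, b_hh)
    i, f, g, o = gates.chunk(4, dim=-1)
    i, f, g, o = torch.sigmoid(i), torch.sigmoid(f), torch.tanh(g), torch.sigmoid(o)
    c_next = f * c + i * g  # new cell state
    h_next = o * torch.tanh(c_next)  # new hidden state
    return h_next, c_next


# Check the hand-written step against the built-in cell
cell = torch.nn.LSTMCell(input_size=32, hidden_size=256)
x = torch.randn(1, 32)
h0 = torch.randn(1, 256)
c0 = torch.randn(1, 256)

with torch.no_grad():
    h_ref, c_ref = cell(x, (h0, c0))
    h_py, c_py = python_lstm_cell(
        x, h0, c0, cell.weight_ih, cell.weight_hh, cell.bias_ih, cell.bias_hh
    )

print(torch.allclose(h_ref, h_py, atol=1e-5), torch.allclose(c_ref, c_py, atol=1e-5))
```

The same sequence of linear, chunk, sigmoid, tanh and mul nodes appears in the exported recurrent-policy graph of this section, once per LSTM layer.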

Let us now create the policy. We combine two layers that change the shape of the input (unsqueeze/squeeze operations) with the LSTM and an MLP.

recurrent_policy = TensorDictSequential(
    # Unsqueeze the first dim of all tensors to make LSTMCell happy
    BatchSizeTransform(reshape_fn=lambda x: x.unsqueeze(0)),
    lstm,
    TensorDictModule(
        MLP(in_features=256, out_features=5, num_cells=[64, 64]),
        in_keys=["intermediate"],
        out_keys=["action"],
    ),
    # Squeeze the first dim of all tensors to get the original shape back
    BatchSizeTransform(reshape_fn=lambda x: x.squeeze(0)),
)

As before, we select the relevant keys:

recurrent_policy.select_out_keys("action", "hidden0", "hidden1")
print("recurrent policy input keys:", recurrent_policy.in_keys)
print("recurrent policy output keys:", recurrent_policy.out_keys)
recurrent policy input keys: ['observation', 'hidden0', 'hidden1', 'is_init']
recurrent policy output keys: ['action', 'hidden0', 'hidden1']

We are now ready to export. To do so, we build fake inputs and pass them to export():

fake_obs = torch.randn(32)
fake_hidden0 = torch.randn(2, 256)
fake_hidden1 = torch.randn(2, 256)

# Tensor indicating whether the state is the first of a sequence
fake_is_init = torch.zeros((), dtype=torch.bool)

exported_recurrent_policy = torch.export.export(
    recurrent_policy,
    args=(),
    kwargs={
        "observation": fake_obs,
        "hidden0": fake_hidden0,
        "hidden1": fake_hidden1,
        "is_init": fake_is_init,
    },
    strict=False,
)
print("Recurrent policy graph:")
exported_recurrent_policy.graph_module.print_readable()
Recurrent policy graph:
class GraphModule(torch.nn.Module):
    def forward(self, p_module_1_lstm_weight_ih_l0: "f32[1024, 32]", p_module_1_lstm_weight_hh_l0: "f32[1024, 256]", p_module_1_lstm_bias_ih_l0: "f32[1024]", p_module_1_lstm_bias_hh_l0: "f32[1024]", p_module_1_lstm_weight_ih_l1: "f32[1024, 256]", p_module_1_lstm_weight_hh_l1: "f32[1024, 256]", p_module_1_lstm_bias_ih_l1: "f32[1024]", p_module_1_lstm_bias_hh_l1: "f32[1024]", p_module_2_module_0_weight: "f32[64, 256]", p_module_2_module_0_bias: "f32[64]", p_module_2_module_2_weight: "f32[64, 64]", p_module_2_module_2_bias: "f32[64]", p_module_2_module_4_weight: "f32[5, 64]", p_module_2_module_4_bias: "f32[5]", kwargs_observation: "f32[32]", kwargs_hidden0: "f32[2, 256]", kwargs_hidden1: "f32[2, 256]", kwargs_is_init: "b8[]"):
         # File: /pytorch/rl/env/lib/python3.10/site-packages/tensordict/nn/sequence.py:540 in forward, code: tensordict_exec = self._run_module(module, tensordict_exec, **kwargs)
        unsqueeze: "f32[1, 32]" = torch.ops.aten.unsqueeze.default(kwargs_observation, 0);  kwargs_observation = None
        unsqueeze_1: "f32[1, 2, 256]" = torch.ops.aten.unsqueeze.default(kwargs_hidden0, 0);  kwargs_hidden0 = None
        unsqueeze_2: "f32[1, 2, 256]" = torch.ops.aten.unsqueeze.default(kwargs_hidden1, 0);  kwargs_hidden1 = None
        unsqueeze_3: "b8[1]" = torch.ops.aten.unsqueeze.default(kwargs_is_init, 0);  kwargs_is_init = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:743 in forward, code: tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1)
        unsqueeze_4: "f32[1, 1, 32]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
        unsqueeze_5: "f32[1, 1, 2, 256]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 1);  unsqueeze_1 = None
        unsqueeze_6: "f32[1, 1, 2, 256]" = torch.ops.aten.unsqueeze.default(unsqueeze_2, 1);  unsqueeze_2 = None
        unsqueeze_7: "b8[1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 1);  unsqueeze_3 = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:745 in forward, code: is_init = tensordict_shaped["is_init"].squeeze(-1)
        squeeze: "b8[1]" = torch.ops.aten.squeeze.dim(unsqueeze_7, -1)

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:772 in forward, code: is_init_expand = expand_as_right(is_init, hidden0)
        unsqueeze_8: "b8[1, 1]" = torch.ops.aten.unsqueeze.default(squeeze, -1);  squeeze = None
        unsqueeze_9: "b8[1, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_8, -1);  unsqueeze_8 = None
        unsqueeze_10: "b8[1, 1, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_9, -1);  unsqueeze_9 = None
        expand: "b8[1, 1, 2, 256]" = torch.ops.aten.expand.default(unsqueeze_10, [1, 1, 2, 256]);  unsqueeze_10 = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:773 in forward, code: hidden0 = torch.where(is_init_expand, 0, hidden0)
        where: "f32[1, 1, 2, 256]" = torch.ops.aten.where.ScalarSelf(expand, 0, unsqueeze_5);  unsqueeze_5 = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:774 in forward, code: hidden1 = torch.where(is_init_expand, 0, hidden1)
        where_1: "f32[1, 1, 2, 256]" = torch.ops.aten.where.ScalarSelf(expand, 0, unsqueeze_6);  expand = unsqueeze_6 = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:775 in forward, code: val, hidden0, hidden1 = self._lstm(
        slice_1: "f32[1, 1, 2, 256]" = torch.ops.aten.slice.Tensor(where, 0, 0, 9223372036854775807);  where = None
        select: "f32[1, 2, 256]" = torch.ops.aten.select.int(slice_1, 1, 0);  slice_1 = None
        slice_2: "f32[1, 1, 2, 256]" = torch.ops.aten.slice.Tensor(where_1, 0, 0, 9223372036854775807);  where_1 = None
        select_1: "f32[1, 2, 256]" = torch.ops.aten.select.int(slice_2, 1, 0);  slice_2 = None
        transpose: "f32[2, 1, 256]" = torch.ops.aten.transpose.int(select, -3, -2);  select = None
        transpose_1: "f32[2, 1, 256]" = torch.ops.aten.transpose.int(select_1, -3, -2);  select_1 = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:317 in forward, code: return self._lstm(input, hx)
        unbind = torch.ops.aten.unbind.int(transpose);  transpose = None
        getitem: "f32[1, 256]" = unbind[0]
        getitem_1: "f32[1, 256]" = unbind[1];  unbind = None
        unbind_1 = torch.ops.aten.unbind.int(transpose_1);  transpose_1 = None
        getitem_2: "f32[1, 256]" = unbind_1[0]
        getitem_3: "f32[1, 256]" = unbind_1[1];  unbind_1 = None
        unbind_2 = torch.ops.aten.unbind.int(unsqueeze_4, 1)
        getitem_4: "f32[1, 32]" = unbind_2[0];  unbind_2 = None
        linear: "f32[1, 1024]" = torch.ops.aten.linear.default(getitem_4, p_module_1_lstm_weight_ih_l0, p_module_1_lstm_bias_ih_l0);  getitem_4 = p_module_1_lstm_weight_ih_l0 = p_module_1_lstm_bias_ih_l0 = None
        linear_1: "f32[1, 1024]" = torch.ops.aten.linear.default(getitem, p_module_1_lstm_weight_hh_l0, p_module_1_lstm_bias_hh_l0);  getitem = p_module_1_lstm_weight_hh_l0 = p_module_1_lstm_bias_hh_l0 = None
        add: "f32[1, 1024]" = torch.ops.aten.add.Tensor(linear, linear_1);  linear = linear_1 = None
        chunk = torch.ops.aten.chunk.default(add, 4, 1);  add = None
        getitem_5: "f32[1, 256]" = chunk[0]
        getitem_6: "f32[1, 256]" = chunk[1]
        getitem_7: "f32[1, 256]" = chunk[2]
        getitem_8: "f32[1, 256]" = chunk[3];  chunk = None
        sigmoid: "f32[1, 256]" = torch.ops.aten.sigmoid.default(getitem_5);  getitem_5 = None
        sigmoid_1: "f32[1, 256]" = torch.ops.aten.sigmoid.default(getitem_6);  getitem_6 = None
        tanh: "f32[1, 256]" = torch.ops.aten.tanh.default(getitem_7);  getitem_7 = None
        sigmoid_2: "f32[1, 256]" = torch.ops.aten.sigmoid.default(getitem_8);  getitem_8 = None
        mul: "f32[1, 256]" = torch.ops.aten.mul.Tensor(getitem_2, sigmoid_1);  getitem_2 = sigmoid_1 = None
        mul_1: "f32[1, 256]" = torch.ops.aten.mul.Tensor(sigmoid, tanh);  sigmoid = tanh = None
        add_1: "f32[1, 256]" = torch.ops.aten.add.Tensor(mul, mul_1);  mul = mul_1 = None
        tanh_1: "f32[1, 256]" = torch.ops.aten.tanh.default(add_1)
        mul_2: "f32[1, 256]" = torch.ops.aten.mul.Tensor(sigmoid_2, tanh_1);  sigmoid_2 = tanh_1 = None
        linear_2: "f32[1, 1024]" = torch.ops.aten.linear.default(mul_2, p_module_1_lstm_weight_ih_l1, p_module_1_lstm_bias_ih_l1);  p_module_1_lstm_weight_ih_l1 = p_module_1_lstm_bias_ih_l1 = None
        linear_3: "f32[1, 1024]" = torch.ops.aten.linear.default(getitem_1, p_module_1_lstm_weight_hh_l1, p_module_1_lstm_bias_hh_l1);  getitem_1 = p_module_1_lstm_weight_hh_l1 = p_module_1_lstm_bias_hh_l1 = None
        add_2: "f32[1, 1024]" = torch.ops.aten.add.Tensor(linear_2, linear_3);  linear_2 = linear_3 = None
        chunk_1 = torch.ops.aten.chunk.default(add_2, 4, 1);  add_2 = None
        getitem_9: "f32[1, 256]" = chunk_1[0]
        getitem_10: "f32[1, 256]" = chunk_1[1]
        getitem_11: "f32[1, 256]" = chunk_1[2]
        getitem_12: "f32[1, 256]" = chunk_1[3];  chunk_1 = None
        sigmoid_3: "f32[1, 256]" = torch.ops.aten.sigmoid.default(getitem_9);  getitem_9 = None
        sigmoid_4: "f32[1, 256]" = torch.ops.aten.sigmoid.default(getitem_10);  getitem_10 = None
        tanh_2: "f32[1, 256]" = torch.ops.aten.tanh.default(getitem_11);  getitem_11 = None
        sigmoid_5: "f32[1, 256]" = torch.ops.aten.sigmoid.default(getitem_12);  getitem_12 = None
        mul_3: "f32[1, 256]" = torch.ops.aten.mul.Tensor(getitem_3, sigmoid_4);  getitem_3 = sigmoid_4 = None
        mul_4: "f32[1, 256]" = torch.ops.aten.mul.Tensor(sigmoid_3, tanh_2);  sigmoid_3 = tanh_2 = None
        add_3: "f32[1, 256]" = torch.ops.aten.add.Tensor(mul_3, mul_4);  mul_3 = mul_4 = None
        tanh_3: "f32[1, 256]" = torch.ops.aten.tanh.default(add_3)
        mul_5: "f32[1, 256]" = torch.ops.aten.mul.Tensor(sigmoid_5, tanh_3);  sigmoid_5 = tanh_3 = None
        stack: "f32[1, 1, 256]" = torch.ops.aten.stack.default([mul_5], 1)
        stack_1: "f32[2, 1, 256]" = torch.ops.aten.stack.default([mul_2, mul_5]);  mul_2 = mul_5 = None
        stack_2: "f32[2, 1, 256]" = torch.ops.aten.stack.default([add_1, add_3]);  add_1 = add_3 = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:775 in forward, code: val, hidden0, hidden1 = self._lstm(
        transpose_2: "f32[1, 2, 256]" = torch.ops.aten.transpose.int(stack_1, 0, 1);  stack_1 = None
        transpose_3: "f32[1, 2, 256]" = torch.ops.aten.transpose.int(stack_2, 0, 1);  stack_2 = None
        stack_3: "f32[1, 1, 2, 256]" = torch.ops.aten.stack.default([transpose_2], 1);  transpose_2 = None
        stack_4: "f32[1, 1, 2, 256]" = torch.ops.aten.stack.default([transpose_3], 1);  transpose_3 = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:788 in forward, code: tensordict.update(tensordict_shaped.reshape(shape))
        reshape: "f32[1, 32]" = torch.ops.aten.reshape.default(unsqueeze_4, [1, 32]);  unsqueeze_4 = None
        reshape_1: "f32[1, 2, 256]" = torch.ops.aten.reshape.default(stack_3, [1, 2, 256]);  stack_3 = None
        reshape_2: "f32[1, 2, 256]" = torch.ops.aten.reshape.default(stack_4, [1, 2, 256]);  stack_4 = None
        reshape_3: "b8[1]" = torch.ops.aten.reshape.default(unsqueeze_7, [1]);  unsqueeze_7 = None
        reshape_4: "f32[1, 256]" = torch.ops.aten.reshape.default(stack, [1, 256]);  stack = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear_4: "f32[1, 64]" = torch.ops.aten.linear.default(reshape_4, p_module_2_module_0_weight, p_module_2_module_0_bias);  p_module_2_module_0_weight = p_module_2_module_0_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:392 in forward, code: return torch.tanh(input)
        tanh_4: "f32[1, 64]" = torch.ops.aten.tanh.default(linear_4);  linear_4 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear_5: "f32[1, 64]" = torch.ops.aten.linear.default(tanh_4, p_module_2_module_2_weight, p_module_2_module_2_bias);  tanh_4 = p_module_2_module_2_weight = p_module_2_module_2_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:392 in forward, code: return torch.tanh(input)
        tanh_5: "f32[1, 64]" = torch.ops.aten.tanh.default(linear_5);  linear_5 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear_6: "f32[1, 5]" = torch.ops.aten.linear.default(tanh_5, p_module_2_module_4_weight, p_module_2_module_4_bias);  tanh_5 = p_module_2_module_4_weight = p_module_2_module_4_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/tensordict/nn/sequence.py:540 in forward, code: tensordict_exec = self._run_module(module, tensordict_exec, **kwargs)
        squeeze_1: "f32[32]" = torch.ops.aten.squeeze.dim(reshape, 0);  reshape = squeeze_1 = None
        squeeze_2: "f32[2, 256]" = torch.ops.aten.squeeze.dim(reshape_1, 0);  reshape_1 = None
        squeeze_3: "f32[2, 256]" = torch.ops.aten.squeeze.dim(reshape_2, 0);  reshape_2 = None
        squeeze_4: "b8[]" = torch.ops.aten.squeeze.dim(reshape_3, 0);  reshape_3 = squeeze_4 = None
        squeeze_5: "f32[256]" = torch.ops.aten.squeeze.dim(reshape_4, 0);  reshape_4 = squeeze_5 = None
        squeeze_6: "f32[5]" = torch.ops.aten.squeeze.dim(linear_6, 0);  linear_6 = None
        return (squeeze_6, squeeze_2, squeeze_3)



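AOTInductor: Export your policy to PyTorch-free C++ binaries

AOTInductor is a PyTorch module that allows you to export your model (policy or other) to PyTorch-free C++ binaries. This is particularly useful when you need to deploy your model on devices or platforms where PyTorch is not available.

Here is an example of how to use AOTInductor to export your policy, inspired by the AOTI documentation: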

from tempfile import TemporaryDirectory

from torch._inductor import aoti_compile_and_package, aoti_load_package

with TemporaryDirectory() as tmpdir:
    path = str(Path(tmpdir) / "model.pt2")
    with torch.no_grad():
        pkg_path = aoti_compile_and_package(
            exported_policy,
            # Specify the generated shared library path
            package_path=path,
        )
    print("pkg_path", pkg_path)

    compiled_module = aoti_load_package(pkg_path)

print(compiled_module(pixels=pixels))

Exporting TorchRL models with ONNX

Note

To execute this part of the script, make sure pytorch onnx is installed:

!pip install onnx-pytorch
!pip install onnxruntime

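You can also find more information about using ONNX in the PyTorch ecosystem here. The following example is based on this documentation.

In this section, we show how to export our model so that it can be executed in a PyTorch-free setting.

There are plenty of resources on the web explaining how ONNX can be used to deploy PyTorch models on various hardware and devices, including Raspberry Pi, NVIDIA TensorRT, iOS and Android.

The Atari game we trained on can be run in isolation, without TorchRL or gymnasium, using the ALE library, which makes it a good example of what ONNX lets us achieve.

Let us see what this API looks like: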

from ale_py import ALEInterface, roms

# Create the interface
ale = ALEInterface()
# Load the pong environment
ale.loadROM(roms.Pong)
ale.reset_game()

# Make a step in the simulator
action = 0
reward = ale.act(action)
screen_obs = ale.getScreenRGB()
print("Observation from ALE simulator:", type(screen_obs), screen_obs.shape)

from matplotlib import pyplot as plt

plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
plt.imshow(screen_obs)
plt.title("Screen rendering of Pong game.")
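As the exported graph of policy_transform shows (grayscale weighting and unsqueeze before the first convolution), the image transforms are part of the exported policy, so it can consume the raw ALE screen directly. If you ever need to reproduce that preprocessing outside of PyTorch, the NumPy sketch below approximates the ToTensorImage, Resize(84, interpolation="nearest") and GrayScale pipeline used by env. It is an approximation under stated assumptions: the nearest-neighbour index rule floor(dst * in / out) and the grayscale weights (0.2989, 0.587, 0.114) are assumed to match the PyTorch/TorchRL defaults, and preprocess_screen is a hypothetical name.

```python
import numpy as np


def preprocess_screen(screen: np.ndarray, size: int = 84) -> np.ndarray:
    """Approximate ToTensorImage -> Resize(nearest) -> GrayScale (hypothetical helper).

    Takes an (H, W, 3) uint8 screen and returns a (1, size, size) float32
    array with values in [0, 1].
    """
    # ToTensorImage: HWC uint8 -> CHW float in [0, 1]
    img = screen.astype(np.float32).transpose(2, 0, 1) / 255.0
    h, w = img.shape[-2:]
    # Nearest-neighbour resize, assuming the index rule floor(dst * in / out)
    rows = np.floor(np.arange(size) * (h / size)).astype(np.int64)
    cols = np.floor(np.arange(size) * (w / size)).astype(np.int64)
    img = img[:, rows][:, :, cols]
    # GrayScale: weighted sum over the RGB channels (assumed ITU-R 601 weights)
    weights = np.array([0.2989, 0.587, 0.114], dtype=np.float32)
    gray = np.tensordot(weights, img, axes=([0], [0]))
    return gray[None].astype(np.float32)


fake_screen = np.random.randint(0, 256, size=(210, 160, 3), dtype=np.uint8)
print(preprocess_screen(fake_screen).shape)  # (1, 84, 84)
```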

Exporting to ONNX is quite similar to the export/AOTI steps above:

import onnxruntime

with set_exploration_type("DETERMINISTIC"):
    # We use torch.onnx.dynamo_export to capture the computation graph from our policy_transform model
    pixels = torch.as_tensor(screen_obs)
    onnx_policy_export = torch.onnx.dynamo_export(policy_transform, pixels=pixels)

We can now save the program to disk and load it:

with TemporaryDirectory() as tmpdir:
    onnx_file_path = str(Path(tmpdir) / "policy.onnx")
    onnx_policy_export.save(onnx_file_path)

    ort_session = onnxruntime.InferenceSession(
        onnx_file_path, providers=["CPUExecutionProvider"]
    )

onnxruntime_input = {ort_session.get_inputs()[0].name: screen_obs}
onnx_policy = ort_session.run(None, onnxruntime_input)

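Running a rollout with ONNX

We now have an ONNX model that runs our policy. Let's compare it with the original TorchRL instance: because it is more lightweight, the ONNX version should be faster than the TorchRL one.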

def onnx_policy(screen_obs: np.ndarray) -> int:
    onnxruntime_input = {ort_session.get_inputs()[0].name: screen_obs}
    onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
    action = int(onnxruntime_outputs[0])
    return action


with timeit("ONNX rollout"):
    num_steps = 1000
    ale.reset_game()
    for _ in range(num_steps):
        screen_obs = ale.getScreenRGB()
        action = onnx_policy(screen_obs)
        reward = ale.act(action)

with timeit("TorchRL version"), torch.no_grad(), set_exploration_type("DETERMINISTIC"):
    env.rollout(num_steps, policy_explore)

print(timeit.print())
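A single timed rollout can be noisy (first-call overheads, environment resets, and the cost of stepping the emulator itself). A slightly more careful comparison warms each policy up and reports time per call over several repeats; the stdlib-only sketch below shows one way to do it. time_per_call is a hypothetical helper, not part of TorchRL, and the two dummy workloads merely stand in for the policies being compared.

```python
import time
from statistics import median
from typing import Callable


def time_per_call(
    fn: Callable[[], object], n_calls: int = 1000, warmup: int = 10, repeats: int = 3
) -> float:
    """Median over `repeats` runs of the mean wall-clock seconds per call of `fn`."""
    for _ in range(warmup):  # exclude one-off costs such as lazy initialisation
        fn()
    per_call = []
    for _ in range(repeats):
        start = time.perf_counter()
        for _ in range(n_calls):
            fn()
        per_call.append((time.perf_counter() - start) / n_calls)
    return median(per_call)


# Dummy workloads standing in for the two policies
fast = time_per_call(lambda: sum(range(10)))
slow = time_per_call(lambda: sum(range(1000)))
print(f"fast: {fast * 1e6:.2f} us/call, slow: {slow * 1e6:.2f} us/call")
```

With the objects defined above, one could for instance time lambda: onnx_policy(screen_obs) against the TorchRL policy_transform called on the same frame (inside torch.no_grad() and set_exploration_type("DETERMINISTIC")), which isolates policy inference from the environment stepping included in the rollout timings.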

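Note that ONNX also offers the possibility of optimizing models directly, but this is beyond the scope of this tutorial.

Conclusion

In this tutorial, we learned how to export TorchRL modules using various backends, such as PyTorch's built-in export functionality, AOTInductor and ONNX. We demonstrated how to export a policy trained on an Atari game and run it in a PyTorch-free environment using the ALE library. We also compared the performance of the original TorchRL instance with that of the exported ONNX model.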

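Key takeaways:

  • Exporting TorchRL modules allows for deployment on devices without PyTorch installed.

  • AOTInductor and ONNX provide alternative backends for exporting models.

  • Optimizing ONNX models can improve performance.

Further reading and learning steps:

  • Check out the official PyTorch documentation on its export functionality, AOTInductor and ONNX for more information.

  • Experiment with deploying exported models on different devices.

  • Explore techniques for optimizing ONNX models to improve performance.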

Total running time of the script: (0 minutes 58.085 seconds)

Estimated memory usage: 4641 MB

