Tracing TensorDictModule
We support tracing the execution of TensorDictModule to create FX graphs. Simply import symbolic_trace from tensordict.prototype.fx instead of torch.fx.
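Note
Support for torch.fx is highly experimental and subject to change. Use it with caution, and please raise an issue if you try it out and encounter problems.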
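For orientation, the following is a minimal sketch of what plain torch.fx tracing produces for an ordinary nn.Module, before any TensorDict is involved. The fixed input size (nn.Linear(100, 1)) and the alias fx_symbolic_trace are assumptions made for this illustration only; they are not part of the tensordict API.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace as fx_symbolic_trace


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Assumption: a fixed input size of 100 keeps the sketch simple.
        self.linear = nn.Linear(100, 1)

    def forward(self, x):
        logits = self.linear(x)
        return logits, torch.sigmoid(logits)


# Plain torch.fx tracing of the inner nn.Module (no TensorDict involved).
plain_graph_module = fx_symbolic_trace(Net())

# Each node records one operation: the input placeholder, the module call,
# the function call, and the final output.
node_ops = [node.op for node in plain_graph_module.graph.nodes]
print(node_ops)

# The traced module is called exactly like the original one.
x = torch.randn(32, 100)
logits, probabilities = plain_graph_module(x)
```

The traced graph here takes a tensor as input. The symbolic_trace from tensordict.prototype.fx differs in that the resulting graph takes a TensorDict and reads its entries by key.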
Tracing a TensorDictModule
We illustrate with an example from the overview: we create a TensorDictModule, trace it, and inspect the graph and the generated code.
>>> import torch
>>> import torch.nn as nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from tensordict.prototype.fx import symbolic_trace
>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.LazyLinear(1)
...
...     def forward(self, x):
...         logits = self.linear(x)
...         return logits, torch.sigmoid(logits)
>>> module = TensorDictModule(
...     Net(),
...     in_keys=["input"],
...     out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
... )
>>> graph_module = symbolic_trace(module)
>>> print(graph_module.graph)
graph():
    %tensordict : [#users=1] = placeholder[target=tensordict]
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%tensordict, input), kwargs = {})
    %linear : [#users=2] = call_module[target=linear](args = (%getitem,), kwargs = {})
    %sigmoid : [#users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    return (linear, sigmoid)
>>> print(graph_module.code)
def forward(self, tensordict):
    getitem = tensordict['input'];  tensordict = None
    linear = self.linear(getitem);  getitem = None
    sigmoid = torch.sigmoid(linear)
    return (linear, sigmoid)
We can check that a forward pass through each module produces the same outputs.
>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> module_out = module(tensordict, tensordict_out=TensorDict())
>>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict())
>>> assert (
...     module_out["outputs", "logits"] == graph_module_out["outputs", "logits"]
... ).all()
>>> assert (
...     module_out["outputs", "probabilities"]
...     == graph_module_out["outputs", "probabilities"]
... ).all()
Tracing a TensorDictSequential
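We can also trace a TensorDictSequential. In this case, the entire execution of the module is traced into a single graph, eliminating intermediate reads and writes on the input TensorDict.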
We demonstrate by tracing the sequential example from the overview.
>>> import torch
>>> import torch.nn as nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> from tensordict.prototype.fx import symbolic_trace
>>> class Net(nn.Module):
...     def __init__(self, input_size=100, hidden_size=50, output_size=10):
...         super().__init__()
...         self.fc1 = nn.Linear(input_size, hidden_size)
...         self.fc2 = nn.Linear(hidden_size, output_size)
...
...     def forward(self, x):
...         x = torch.relu(self.fc1(x))
...         return self.fc2(x)
...
>>> class Masker(nn.Module):
...     def forward(self, x, mask):
...         return torch.softmax(x * mask, dim=1)
>>> net = TensorDictModule(
...     Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
...     Masker(),
...     in_keys=[("intermediate", "x"), ("input", "mask")],
...     out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>> graph_module = symbolic_trace(module)
>>> print(graph_module.code)
def forward(self, tensordict):
    getitem = tensordict[('input', 'x')]
    _0_fc1 = getattr(self, "0").module.fc1(getitem);  getitem = None
    relu = torch.relu(_0_fc1);  _0_fc1 = None
    _0_fc2 = getattr(self, "0").module.fc2(relu);  relu = None
    getitem_1 = tensordict[('input', 'mask')];  tensordict = None
    mul = _0_fc2 * getitem_1;  getitem_1 = None
    softmax = torch.softmax(mul, dim = 1);  mul = None
    return (_0_fc2, softmax)
In this case the generated graph and code are slightly more complex. We can visualize the graph as follows (requires pydot):
>>> from torch.fx.passes.graph_drawer import FxGraphDrawer
>>> g = FxGraphDrawer(graph_module, "sequential")
>>> with open("graph.svg", "wb") as f:
...     f.write(g.get_dot_graph().create_svg())
This produces the following visualization.