torch.utils.module_tracker¶
此工具可用于跟踪当前在 torch.nn.Module 层次结构中的位置。
它可以在其他跟踪工具中使用,以便能够轻松地将测量的数值与用户友好的名称相关联。这目前特别用于 FlopCounterMode 中。
- class torch.utils.module_tracker.ModuleTracker[source][source]¶
ModuleTracker是一个上下文管理器,用于在执行过程中跟踪 nn.Module 层级结构, 以便其他系统可以查询当前正在执行的 Module(或其反向传播正在执行)。你可以访问此上下文管理器的
parents属性,以获取当前通过其 fqn(全限定名,也用作 state_dict 中的键)执行的所有 Modules 的集合。 你可以访问is_bw属性来判断你当前是否处于反向传播中或不是。请注意
parents永远不会为空,并且始终包含“Global”键。标志is_bw在前向传播之后将保持为True,直到执行另一个模块。如果你需要更精确的功能,请提交一个请求此功能的问题。添加从 fqn 到模块实例的映射 是可能的但尚未实现,如果你需要它,请提交一个请求此功能的问题。示例用法
mod = torch.nn.Linear(2, 2) with ModuleTracker() as tracker: # Access anything during the forward pass def my_linear(m1, m2, bias): print(f"Current modules: {tracker.parents}") return torch.mm(m1, m2.t()) + bias torch.nn.functional.linear = my_linear mod(torch.rand(2, 2))