目录

跟踪

概述与用法

注意

实验性功能,风险自担,API 可能会更改。

在TorchX中,应用程序是二进制文件(可执行文件), 因此没有内置的方法来“返回”应用程序的结果。 torchx.runtime.tracking 模块允许应用程序 返回简单的结果(注意关键词“简单”)。跟踪器模块支持的返回类型是有意受限的。例如, 尝试返回训练好的模型权重,这些权重可能有数百GB之大, 是不允许的。此模块不适用于传递大量数据或二进制数据块。

当应用程序作为更高层次的协调努力的一部分启动时(例如工作流、管道、超参数优化),通常情况下,应用程序的结果需要被协调器或其他工作流中的应用程序访问。

假设 App1 和 App2 依次启动,并且 App1 的输出作为 App2 的输入。由于这些是二进制文件,通常情况下,应用程序之间链接输入/输出的方式是通过将 App1 的输出文件路径作为 App2 的输入文件路径传递:

$ app1 --output-file s3://foo/out/app1.out
$ app2 --input-file s3://foo/out/app1.out

虽然这看起来很简单,但仍有一些需要注意的地方:

  1. 文件的格式 app1.out (app1 需要以 app2 能理解的格式写入)

  2. 实际上解析 URL 并写入/读取输出文件

因此,应用程序的主函数最终看起来像这样(为了演示目的,此处为伪代码):

# in app1.py
if __name__ == "__main__":
   accuracy = do_something()
   s3client = ...
   out = {"accuracy": accuracy}

   with open("/tmp/out", "w") as f:
       f = json.dumps(out).encode("utf-8")

   s3client.put(args.output_file, f)

# in app2.py
if __name__ == "__main__":
   s3client = ...
   with open("/tmp/out", "w") as f:
       s3client.get(args.input_file, f)

   with open("/tmp/out", "r") as f:
       in = json.loads(f.read().decode("utf-8"))

   do_something_else(in["accuracy"])

相反,可以使用具有相同 tracker_base 的跟踪器, 该跟踪器可以在不同的应用程序之间使用,使一个应用程序的返回值 可供另一个应用程序使用,而无需链接一个应用程序的输出文件路径 与另一个应用程序的输入文件路径,并处理自定义序列化和文件写入。

# in app1.py
if __name__ == "__main__":
   accuracy = do_something()
   tracker = FsspecResultTracker(args.tracker_base)
   tracker["app1_out"] = {"accuracy": accuracy}

# in app2.py
if __name__ == "__main__":
   tracker = FsspecResultTracker(args.tracker_base)
   app1_accuracy = tracker["app1_out"]
   do_something_else(app1_accuracy)

ResultTracker

基础

class torchx.runtime.tracking.ResultTracker[source]

基础结果跟踪器,应该被子类化以实现具体的跟踪器。 通常每个存储后端都存在一个跟踪器实现。

Usage:

# get and put APIs can be used directly or in map-like API
# the following are equivalent
tracker.put("foo", l2norm=1.2)
tracker["foo"] = {"l2norm": 1.2}

# so are these
tracker.get("foo")["l2norm"] == 1.2
tracker["foo"]["l2norm"] == 1.2

有效的 result 类型为:

  1. 数字:整数、浮点数

  2. 字面量:str(UTF-8 编码时大小限为 1KB)

有效的 key 类型为:

  1. int

  2. str

作为惯例,“斜杠”可以用于键中以存储统计结果。例如,要存储 l2norm 的均值和标准误差:

tracker[key] = {"l2norm/mean" : 1.2, "l2norm/sem": 3.4}
tracker[key]["l2norm/mean"] # returns 1.2
tracker[key]["l2norm/sem"] # returns 3.4

在跟踪器的支持存储范围内,键被假设计为唯一。例如,如果一个跟踪器由本地目录支持,并且key是保存结果的目录中的文件,则

# same key, different backing directory -> results are not overwritten
FsspecResultTracker("/tmp/foo")["1"] = {"l2norm":1.2}
FsspecResultTracker("/tmp/bar")["1"] = {"l2norm":3.4}

跟踪器不是一个中心实体,因此在同一个键上的putget操作之间不会提供强一致性保证(超出存储库提供的保证)。同样,在同一个键上的两个连续的putget操作之间也不会提供强一致性保证。

例如:

tracker[1] = {"l2norm":1.2}
tracker[1] = {"l2norm":3.4}
tracker[1] # NOT GUARANTEED TO BE 3.4!

sleep(1*MIN)
tracker[1] # more likely to be 3.4 but still not guaranteed!

强烈建议使用唯一的 ID 作为键。这个 ID 通常是简单任务的作业 ID,也可以是迭代应用程序(如超参数优化)中实验 ID 和试验编号的组合,或者是作业 ID 和副本/工作者排名的组合。

Fsspec

class torchx.runtime.tracking.FsspecResultTracker(tracker_base: str)[source]

使用fsspec作为底层框架来保存结果的追踪器。

Usage:

from torchx.runtime.tracking import FsspecResultTracker

# PUT: in trainer.py
tracker_base = "/tmp/foobar" # also supports URIs (e.g. "s3://bucket/trainer/123")
tracker = FsspecResultTracker(tracker_base)
tracker["attempt_1/out"] = {"accuracy": 0.233}

# GET: anywhere outside trainer.py
tracker = FsspecResultTracker(tracker_base)
print(tracker["attempt_1/out"]["accuracy"])
0.233

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源