目录

torch.hub

Pytorch Hub 是一个预训练模型仓库,旨在促进研究的可重复性。

发布模型

Pytorch Hub支持通过添加一个简单的hubconf.py文件,将预训练模型(模型定义和预训练权重)发布到GitHub仓库。

hubconf.py 可以有多个入口点。每个入口点定义为一个Python函数 (例如:你想发布的预训练模型)。

def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...

如何实现入口函数?

这里是一个代码片段,指定了如果我们扩展resnet18模型的实现时的入口点。 在大多数情况下,在pytorch/vision/hubconf.py中导入正确的函数就足够了。这里 我们只是想用扩展版本作为一个例子来展示它是如何工作的。 你可以在 pytorch/vision repo 中查看完整的脚本。

dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18

# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
    """ # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model
  • dependencies 变量是一个列表,其中包含加载模型所需的包名称。请注意,这可能与训练模型所需的依赖项略有不同。

  • argskwargs 被传递给真正的可调用函数。

  • 函数的文档字符串用作帮助信息。它解释了模型的功能以及允许的位置参数/关键字参数。强烈建议在这里添加一些示例。

  • 入口函数可以返回一个模型(nn.module),或者辅助工具以使用户的 workflow 更加顺畅,例如分词器。

  • 以下划线开头的函数被视为辅助函数,不会出现在torch.hub.list()中。

  • 预训练权重可以存储在GitHub仓库中,或者通过torch.hub.load_state_dict_from_url()加载。如果小于2GB,建议将其附加到项目发布中,并使用发布的URL。 在上面的例子中,torchvision.models.resnet.resnet18处理pretrained,或者你也可以在入口点定义中添加以下逻辑。

if pretrained:
    # For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)

    # For checkpoint saved elsewhere
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))

重要通知

  • 发布的模型至少应在某个分支/标签中。它不能是随机的提交。

从中心加载模型

Pytorch Hub 提供了方便的API来浏览hub中所有可用的模型 通过 torch.hub.list(),通过 torch.hub.help() 显示文档字符串和示例,并使用 torch.hub.load() 加载预训练模型。

torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None)[source]

列出由github指定的仓库中所有可用的可调用入口点。

Parameters
  • github (字符串) – 格式为 “repo_owner/repo_name[:ref]” 的字符串,其中 ref(标签或分支)是可选的。如果未指定 ref, 则默认分支假设为 main (如果存在),否则为 master。 示例:‘pytorch/vision:0.10’

  • force_reload (bool, 可选) – 是否丢弃现有缓存并强制重新下载。 默认值为False

  • skip_validation (bool, 可选) – 如果为 False,torchhub 将检查由 github 参数指定的分支或提交是否正确属于仓库所有者。这将向 GitHub API 发起请求;你可以通过设置 GITHUB_TOKEN 环境变量来指定非默认的 GitHub 访问令牌。默认值为 False

  • 信任仓库 (boolstrNone) –

    "check", True, FalseNone。 此参数在v1.12版本中引入,有助于确保用户 仅运行他们信任的仓库中的代码。

    • 如果 False,将弹出提示询问用户是否信任该仓库。

    • 如果 True,该仓库将被添加到受信任列表中并加载, 而无需明确确认。

    • 如果 "check",该仓库将与缓存中的可信仓库列表进行比对。如果不在该列表中,则会回退到trust_repo=False选项。

    • 如果 None:这将触发一个警告,提示用户将 trust_repo 设置为 FalseTrue"check" 中的一个。此功能仅为了向后兼容而存在,并将在 v2.0 中移除。

    默认值为 None,并在v2.0版本中最终更改为 "check"

Returns

可用的可调用入口函数

Return type

列表

示例

>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)[source]

显示入口点 model 的文档字符串。

Parameters
  • github (str) – 一个具有格式<repo_owner/repo_name[:ref]>的字符串,可选的 ref(一个标签或分支)。如果未指定ref,则假设默认分支为main,如果不存在,则为master。 示例:‘pytorch/vision:0.10’

  • 模型 (字符串) – 由repo中的hubconf.py定义的入口点名称的字符串

  • force_reload (bool, 可选) – 是否丢弃现有缓存并强制重新下载。 默认值为False

  • 跳过验证 (bool, 可选) – 如果为 False,torchhub 将检查由 github 参数指定的引用是否正确属于仓库所有者。这将向 GitHub API 发起请求;你可以通过设置 GITHUB_TOKEN 环境变量来指定非默认的 GitHub 令牌。默认值为 False

  • 信任仓库 (boolstrNone) –

    "check", True, FalseNone。 此参数在v1.12版本中引入,有助于确保用户 仅运行他们信任的仓库中的代码。

    • 如果 False,将弹出提示询问用户是否信任该仓库。

    • 如果 True,该仓库将被添加到受信任列表中并加载, 而无需明确确认。

    • 如果 "check",该仓库将与缓存中的可信仓库列表进行比对。如果不在该列表中,则会回退到trust_repo=False选项。

    • 如果 None:这将触发一个警告,提示用户将 trust_repo 设置为 FalseTrue"check" 中的一个。此功能仅为了向后兼容而存在,并将在 v2.0 中移除。

    默认值为 None,并在v2.0版本中最终更改为 "check"

示例

>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)[source]

从 GitHub 仓库或本地目录加载模型。

注意:加载模型是典型用例,但这也可用于加载其他对象,例如分词器、损失函数等。

如果 source 是 ‘github’,repo_or_dir 应该是 repo_owner/repo_name[:ref] 的形式,并可选地包含 一个引用(一个标签或一个分支)。

如果 source 是 ‘本地’,repo_or_dir 应该是一个指向本地目录的路径。

Parameters
  • repo_or_dir (字符串) – 如果 source 是 ‘github’, 这应该对应于一个格式为 repo_owner/repo_name[:ref] 的 github 仓库,并且可以包含可选的引用(标签或分支),例如 ‘pytorch/vision:0.10’。如果 ref 没有指定, 则默认分支将被假设为 main (如果存在),否则为 master。 如果 source 是 ‘local’,那么它应该是指向本地目录的路径。

  • 模型 (字符串) – 在repo/dir的hubconf.py中定义的可调用项(入口点)的名称。

  • *args (可选) – 用于可调用对象model的相应参数。

  • 来源 (字符串可选) – ‘github’ 或 ‘local’。指定如何解释 repo_or_dir。默认值是 ‘github’。

  • 信任仓库 (boolstrNone) –

    "check", True, FalseNone。 此参数在v1.12版本中引入,有助于确保用户 仅运行他们信任的仓库中的代码。

    • 如果 False,将弹出提示询问用户是否信任该仓库。

    • 如果 True,该仓库将被添加到受信任列表中并加载, 而无需明确确认。

    • 如果 "check",该仓库将与缓存中的可信仓库列表进行比对。如果不在该列表中,则会回退到trust_repo=False选项。

    • 如果 None:这将触发一个警告,提示用户将 trust_repo 设置为 FalseTrue"check" 中的一个。此功能仅为了向后兼容而存在,并将在 v2.0 中移除。

    默认值为 None,并在v2.0版本中最终更改为 "check"

  • force_reload (bool, 可选) – 是否无条件强制下载github仓库。如果 source = 'local',则不会有任何效果。默认值为False

  • 详细模式 (bool, 可选) – 如果为False,则静默显示命中本地缓存的消息。请注意,首次下载的消息无法静默。如果为source = 'local',则不会有任何效果。 默认值为True

  • skip_validation (bool, 可选) – 如果为 False,torchhub 将检查由 github 参数指定的分支或提交是否正确属于仓库所有者。这将向 GitHub API 发起请求;你可以通过设置 GITHUB_TOKEN 环境变量来指定非默认的 GitHub 访问令牌。默认值为 False

  • **kwargs (可选) – 用于可调用 model 的相应参数。

Returns

调用给定的 model 可调用对象时的输出,以及 *args**kwargs

示例

>>> # from a github repo
>>> repo = 'pytorch/vision'
>>> model = torch.hub.load(repo, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1')
>>> # from a local directory
>>> path = '/some/local/path/pytorch/vision'
>>> model = torch.hub.load(path, 'resnet50', weights='ResNet50_Weights.DEFAULT')
torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)[source]

下载给定网址的对象到本地路径。

Parameters
  • url (str) – 要下载对象的URL

  • 目标路径 (字符串) – 对象将被保存的完整路径,例如 /tmp/temporary_file

  • hash_prefix (字符串, 可选) – 如果不为 None,下载文件的 SHA256 哈希值应以 hash_prefix 开头。 默认值: None

  • 进度 (bool, 可选) – 是否在stderr中显示进度条 默认值:True

示例

>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)[source]

加载给定网址上的 Torch 序列化对象。

如果下载的文件是 zip 文件,它将被自动解压缩。

如果对象已经在model_dir中,它会被反序列化并返回。 默认值为model_dir,即<hub_dir>/checkpoints,其中 hub_dir是由get_dir()返回的目录。

Parameters
  • url (str) – 要下载对象的URL

  • 模型目录 (str, 可选) – 保存对象的目录

  • map_location (可选) – 指定如何重映射存储位置的函数或字典(参见 torch.load)

  • 进度 (bool, 可选) – 是否在stderr中显示进度条。 默认值:True

  • check_hash (bool, 可选) – 如果为 True,URL 的文件名部分应遵循命名约定 filename-<sha256>.ext 其中 <sha256> 是文件内容的 SHA256 哈希值的前八个或更多位数字。该哈希用于确保名称唯一并验证文件内容。 默认值:False

  • file_name (字符串, 可选) – 下载文件的名称。如果没有设置,将使用url中的文件名。

  • 仅权重 (bool, 可选) – 如果为True,将只加载权重而不加载复杂的pickle对象。 推荐用于不信任的来源。有关更多详细信息,请参见load()

Return type

字典[字符串, 任意类型]

示例

>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

运行加载的模型:

注意*args**kwargstorch.hub.load()中用于 实例化一个模型。加载模型后,如何找出可以对模型进行的操作? 建议的工作流是

  • dir(model) 查看模型的所有可用方法。

  • help(model.foo) 检查运行所需的参数 model.foo

为了帮助用户在不频繁查阅文档的情况下进行探索,我们强烈建议仓库所有者使函数帮助信息清晰简洁。包含一个最小工作示例也很有帮助。

我下载的模型保存在哪里?

位置按顺序使用

  • 调用 hub.set_dir(<PATH_TO_HUB_DIR>)

  • $TORCH_HOME/hub,如果环境变量 TORCH_HOME 被设置。

  • $XDG_CACHE_HOME/torch/hub,如果环境变量 XDG_CACHE_HOME 被设置。

  • ~/.cache/torch/hub

torch.hub.get_dir()[source]

获取用于存储下载模型和权重的 Torch Hub 缓存目录。

如果set_dir()没有被调用,默认路径是$TORCH_HOME/hub,其中环境变量$TORCH_HOME默认为$XDG_CACHE_HOME/torch$XDG_CACHE_HOME遵循Linux文件系统布局的X设计组规范,如果未设置环境变量,则默认值为~/.cache

torch.hub.set_dir(d)[source]

可选设置用于保存下载模型及权重的Torch Hub目录。

Parameters

d (str) – 保存下载模型和权重的本地文件夹路径。

缓存逻辑

默认情况下,我们在加载后不会清理文件。如果缓存已经存在于由get_dir()返回的目录中,Hub 默认会使用该缓存。

用户可以通过调用hub.load(..., force_reload=True)强制刷新。这将删除现有的GitHub文件夹和已下载的权重,并重新初始化一个新的下载。当更新发布到同一分支时,此操作有助于用户跟上最新版本。

已知限制:

Torch hub 的工作原理是像安装包一样导入。在 Python 中导入会引入一些副作用。 例如,您可以在 Python 缓存中看到新的条目 sys.modulessys.path_importer_cache,这是正常的 Python 行为。 这也意味着当您从不同的仓库导入不同的模型时,可能会遇到导入错误, 如果这些仓库有相同的子包名称(通常是 model 子包)。解决这类导入错误的方法是 从 sys.modules 字典中删除有问题的子包;更多详情可以 在 此 GitHub 问题 中找到。

这里值得一提的一个已知限制是:用户不能在同一个Python进程中加载同一仓库的两个不同分支。这就像在Python中安装两个同名的包,不太好。如果你真的尝试这样做,缓存可能会带来一些意想不到的问题。当然,在不同的进程中加载它们是没有问题的。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源