目录

torch.hub

Pytorch Hub 是一个预先训练的模型存储库,旨在促进研究的可重复性。

发布模型

Pytorch Hub 支持发布预训练模型(模型定义和预训练权重) 通过添加一个简单的文件添加到 GitHub 存储库;hubconf.py

hubconf.py可以有多个入口点。每个入口点都定义为一个 python 函数 (示例:您要发布的预训练模型)。

def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...

如何实现入口点?

下面是一个代码片段,如果我们扩展 中的 实现。 在大多数情况下,导入正确的函数就足够了。在这里,我们 只想以扩展版本为例,说明它是如何工作的。 您可以在 pytorch/vision 存储库中查看完整脚本resnet18pytorch/vision/hubconf.pyhubconf.py

dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18

# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
    """ # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model
  • dependenciesvariable 是加载模型所需的包名称列表。请注意,这可能会 与训练模型所需的依赖项略有不同。

  • args并传递给真正的可调用函数。kwargs

  • 该函数的文档字符串用作帮助消息。它解释了模型的作用和作用 是允许的位置/关键字参数。强烈建议在此处添加一些示例。

  • Entrypoint 函数可以返回一个 model(nn.module),也可以返回辅助工具,使用户工作流程更顺畅,例如 tokenizers。

  • 以下划线为前缀的可调用对象被视为辅助函数,不会显示在 .

  • 预训练的权重可以本地存储在 github 存储库中,也可以由 .如果小于 2GB,建议将其附加到项目发布并使用发布中的 URL。 在上面的示例中 handles ,或者,您可以将以下逻辑放在入口点定义中。torchvision.models.resnet.resnet18pretrained

if pretrained:
    # For checkpoint saved in local github repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)

    # For checkpoint saved elsewhere
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))

重要通知

  • 已发布的模型应至少位于分支/标记中。它不能是随机提交。

从 Hub 加载模型

Pytorch Hub 提供了方便的 API 来探索 Hub 中的所有可用模型 通过 显示文档字符串和示例,并使用 .

torch.hub.list(githubforce_reload=Falseskip_validation=False[来源]

列出 指定的 repo 中所有可用的可调用入口点。github

参数
  • githubstring) – 格式为 “repo_owner/repo_name[:tag_name]” 的字符串,带有可选的 标签/分支。如果未指定,则默认分支假定为 if 它存在,否则 . 示例: 'pytorch/vision:0.10'tag_namemainmaster

  • force_reloadbooloptional) – 是否丢弃现有缓存并强制重新下载。 默认值为 。False

  • skip_validationbooloptional) – 如果 , TorchHub 将检查分支或提交 由参数指定正确属于存储库所有者。这将使 对 GitHub API 的请求;您可以通过设置环境变量来指定非默认 GitHub 令牌。默认值为 。FalsegithubGITHUB_TOKENFalse

返回

可用的 callables 入口点

返回类型

列表

>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
torch.hub.help(github模型force_reload=Falseskip_validation=False[来源]

显示 entrypoint 的文档字符串 。model

参数
  • githubstring) – 格式为 <repo_owner/repo_name[:tag_name]> 的字符串,带有可选的 标签/分支。如果未指定,则默认分支假定为 if 它存在,否则 . 示例: 'pytorch/vision:0.10'tag_namemainmaster

  • modelstring) – 在 repo 的hubconf.py

  • force_reloadbooloptional) – 是否丢弃现有缓存并强制重新下载。 默认值为 。False

  • skip_validationbooloptional) – 如果 , TorchHub 将检查分支或提交 由参数指定正确属于存储库所有者。这将使 对 GitHub API 的请求;您可以通过设置环境变量来指定非默认 GitHub 令牌。默认值为 。FalsegithubGITHUB_TOKENFalse

>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
torch.hub.load(repo_or_dirmodel*argssource='github'force_reload=Falseverbose=Trueskip_validation=False**kwargs[来源]

从 github 存储库或本地目录加载模型。

注意:加载模型是典型的用例,但这也可以用于 用于加载其他对象,例如 tokenizers、loss 函数等。

如果是 'github',则预期为 的形式中带有可选的 标签/分支。sourcerepo_or_dirrepo_owner/repo_name[:tag_name]

如果为 'local',则应为 路径。sourcerepo_or_dir

参数
  • repo_or_dirstring) – 如果为 'github', 这应该对应于 format 为 可选标签/分支,例如 'pytorch/vision:0.10'。如果未指定,则 默认分支假定为(如果存在),否则为 。 如果是 'local',则它应该是本地目录的路径。sourcerepo_owner/repo_name[:tag_name]tag_namemainmastersource

  • modelstring) – 在 repo/dir 的 .hubconf.py

  • *args可选) – callable 的相应 args 。model

  • sourcestringoptional) – 'github' 或 'local'。指定如何解释。默认为 'github'。repo_or_dir

  • force_reloadbooloptional) – 是否强制全新下载 GitHub 存储库。如果 ,则没有任何效果。默认值为 。source = 'local'False

  • verbosebooloptional) – 如果 ,将有关击球的消息静音 local caches。请注意,不能显示有关首次下载的消息 温和。如果 ,则没有任何效果。 默认值为 。Falsesource = 'local'True

  • skip_validationbooloptional) – 如果 , TorchHub 将检查分支或提交 由参数指定正确属于存储库所有者。这将使 对 GitHub API 的请求;您可以通过设置环境变量来指定非默认 GitHub 令牌。默认值为 。FalsegithubGITHUB_TOKENFalse

  • **kwargs可选) – 可调用的相应 kwargs 。model

返回

使用给定的 和 调用可调用对象时的输出。model*args**kwargs

>>> # from a github repo
>>> repo = 'pytorch/vision'
>>> model = torch.hub.load(repo, 'resnet50', pretrained=True)
>>> # from a local directory
>>> path = '/some/local/path/pytorch/vision'
>>> model = torch.hub.load(path, 'resnet50', pretrained=True)
torch.hub.download_url_to_file(urldsthash_prefix=Noneprogress=True[来源]

Download 对象复制到本地路径。

参数
  • urlstring) – 要下载的对象的 URL

  • dststring) – 保存对象的完整路径,例如/tmp/temporary_file

  • hash_prefixstringoptional) – 如果不是 None,则 SHA256 下载的文件应以 . 默认值:无hash_prefix

  • progressbooloptional) – 是否向 stderr 显示进度条 默认值:True

>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
torch.hub.load_state_dict_from_url(urlmodel_dir=map_location=progress=Truecheck_hash=Falsefile_name=没有[来源]

在给定的 URL 处加载 Torch 序列化对象。

如果下载的文件是 zip 文件,它将自动 减压。

如果对象已经存在于 model_dir 中,则会对其进行反序列化,并且 返回。 的默认值 是 where,其中 是 返回的目录。model_dir<hub_dir>/checkpointshub_dir

参数
  • urlstring) – 要下载的对象的 URL

  • model_dirstringoptional) – 保存对象的目录

  • map_location可选) – 指定如何重新映射存储位置的函数或字典(参见 torch.load)

  • progressbooloptional) – 是否向 stderr 显示进度条。 默认值:True

  • check_hashbooloptional) – 如果为 True,则 URL 的文件名部分应遵循命名约定,其中前 8 个或更多 文件内容的 SHA256 哈希值的数字。哈希值用于 确保名称唯一并验证文件的内容。 默认值:Falsefilename-<sha256>.ext<sha256>

  • file_namestringoptional) – 下载文件的名称。如果未设置,则将使用 Filename from。url

>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

运行加载的模型:

请注意,和 in 用于实例化模型。加载模型后,如何查找 你可以用这个模型做什么? 建议的工作流程是*args**kwargs

  • dir(model)以查看模型的所有可用方法。

  • help(model.foo)检查运行哪些参数model.foo

为了帮助用户在不来回参考文档的情况下进行探索,我们强烈 建议 Repo 所有者 使函数帮助消息清晰简洁。这也很有帮助 以包含一个最小的工作示例。

我下载的模型保存在哪里?

这些位置按

  • hub.set_dir(<PATH_TO_HUB_DIR>)

  • $TORCH_HOME/hub(如果设置了环境变量)。TORCH_HOME

  • $XDG_CACHE_HOME/torch/hub(如果设置了环境变量)。XDG_CACHE_HOME

  • ~/.cache/torch/hub

torch.hub.get_dir()[来源]

获取用于存储下载模型和重量的torch.hub缓存目录。

如果未调用,则默认路径为 环境变量默认为 。 遵循 Linux 的 X Design Group 规范 filesystem 布局,如果环境 变量。$TORCH_HOME/hub$TORCH_HOME$XDG_CACHE_HOME/torch$XDG_CACHE_HOME~/.cache

torch.hub.set_dir(d[来源]

可选设置用于保存下载的模型和重量的torch.hub目录。

参数

dstring) – 本地文件夹的路径,用于保存下载的模型和权重。

缓存逻辑

默认情况下,我们不会在加载文件后清理文件。如果缓存已存在于 由 返回的目录。

用户可以通过调用 来强制重新加载。这将删除 现有的 GitHub 文件夹和下载的权重,重新初始化新的下载。这很有用 当更新发布到同一分支时,用户可以跟上最新版本。hub.load(..., force_reload=True)

已知限制:

Torch Hub 的工作原理是导入软件包,就像它已安装一样。有一些副作用 通过在 Python 中导入引入。例如,您可以在 Python 缓存中看到新项目,这是正常的 Python 行为。 这也意味着在导入不同的模型时可能会遇到导入错误 来自不同的存储库,如果存储库具有相同的子包名称(通常是 Subpackage)。此类导入错误的解决方法是 从 dict 中删除有问题的子包;更多细节可以 可以在这个 GitHub 问题中找到。sys.modulessys.path_importer_cachemodelsys.modules

这里值得一提的一个已知限制:用户无法加载 同一个 Python 进程中的同一个仓库。这就像使用 在 Python 中同名,这并不好。如果您满足以下条件,Cache 可能会加入派对并给您带来惊喜 实际上试试那个。当然,将它们加载到单独的进程中是完全可以的。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源