torch.hub¶
Pytorch Hub 是一个预先训练的模型存储库,旨在促进研究的可重复性。
发布模型¶
Pytorch Hub 支持发布预训练模型(模型定义和预训练权重)
通过添加一个简单的文件添加到 GitHub 存储库;hubconf.py
hubconf.py
可以有多个入口点。每个入口点都定义为一个 python 函数
(示例:您要发布的预训练模型)。
def entrypoint_name(*args, **kwargs):
# args & kwargs are optional, for models which take positional/keyword arguments.
...
如何实现入口点?¶
下面是一个代码片段,如果我们扩展
中的 实现。
在大多数情况下,导入正确的函数就足够了。在这里,我们
只想以扩展版本为例,说明它是如何工作的。
您可以在 pytorch/vision 存储库中查看完整脚本resnet18
pytorch/vision/hubconf.py
hubconf.py
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18
# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
""" # This docstring shows up in hub.help()
Resnet18 model
pretrained (bool): kwargs, load pretrained weights into the model
"""
# Call the model, load pretrained weights
model = _resnet18(pretrained=pretrained, **kwargs)
return model
dependencies
variable 是加载模型所需的包名称列表。请注意,这可能会 与训练模型所需的依赖项略有不同。args
并传递给真正的可调用函数。kwargs
该函数的文档字符串用作帮助消息。它解释了模型的作用和作用 是允许的位置/关键字参数。强烈建议在此处添加一些示例。
Entrypoint 函数可以返回一个 model(nn.module),也可以返回辅助工具,使用户工作流程更顺畅,例如 tokenizers。
预训练的权重可以本地存储在 github 存储库中,也可以由
.如果小于 2GB,建议将其附加到项目发布并使用发布中的 URL。 在上面的示例中 handles ,或者,您可以将以下逻辑放在入口点定义中。
torchvision.models.resnet.resnet18
pretrained
if pretrained:
# For checkpoint saved in local github repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
dirname = os.path.dirname(__file__)
checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
state_dict = torch.load(checkpoint)
model.load_state_dict(state_dict)
# For checkpoint saved elsewhere
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
重要通知¶
已发布的模型应至少位于分支/标记中。它不能是随机提交。
从 Hub 加载模型¶
Pytorch Hub 提供了方便的 API 来探索 Hub 中的所有可用模型
通过 显示文档字符串和示例,并使用
.
-
torch.hub.
list
(github, force_reload=False, skip_validation=False)[来源]¶ 列出 指定的 repo 中所有可用的可调用入口点。
github
- 参数
github (string) – 格式为 “repo_owner/repo_name[:tag_name]” 的字符串,带有可选的 标签/分支。如果未指定,则默认分支假定为 if 它存在,否则 . 示例: 'pytorch/vision:0.10'
tag_name
main
master
force_reload (bool, optional) – 是否丢弃现有缓存并强制重新下载。 默认值为 。
False
skip_validation (bool, optional) – 如果 , TorchHub 将检查分支或提交 由参数指定正确属于存储库所有者。这将使 对 GitHub API 的请求;您可以通过设置环境变量来指定非默认 GitHub 令牌。默认值为 。
False
github
GITHUB_TOKEN
False
- 返回
可用的 callables 入口点
- 返回类型
例
>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
-
torch.hub.
help
(github, 模型, force_reload=False, skip_validation=False)[来源]¶ 显示 entrypoint 的文档字符串 。
model
- 参数
github (string) – 格式为 <repo_owner/repo_name[:tag_name]> 的字符串,带有可选的 标签/分支。如果未指定,则默认分支假定为 if 它存在,否则 . 示例: 'pytorch/vision:0.10'
tag_name
main
master
model (string) – 在 repo 的
hubconf.py
force_reload (bool, optional) – 是否丢弃现有缓存并强制重新下载。 默认值为 。
False
skip_validation (bool, optional) – 如果 , TorchHub 将检查分支或提交 由参数指定正确属于存储库所有者。这将使 对 GitHub API 的请求;您可以通过设置环境变量来指定非默认 GitHub 令牌。默认值为 。
False
github
GITHUB_TOKEN
False
例
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
-
torch.hub.
load
(repo_or_dir、model、*args、source='github'、force_reload=False、verbose=True、skip_validation=False, **kwargs)[来源]¶ 从 github 存储库或本地目录加载模型。
注意:加载模型是典型的用例,但这也可以用于 用于加载其他对象,例如 tokenizers、loss 函数等。
如果是 'github',则预期为 的形式中带有可选的 标签/分支。
source
repo_or_dir
repo_owner/repo_name[:tag_name]
如果为 'local',则应为 路径。
source
repo_or_dir
- 参数
repo_or_dir (string) – 如果为 'github', 这应该对应于 format 为 可选标签/分支,例如 'pytorch/vision:0.10'。如果未指定,则 默认分支假定为(如果存在),否则为 。 如果是 'local',则它应该是本地目录的路径。
source
repo_owner/repo_name[:tag_name]
tag_name
main
master
source
model (string) – 在 repo/dir 的 .
hubconf.py
*args (可选) – callable 的相应 args 。
model
source (string, optional) – 'github' 或 'local'。指定如何解释。默认为 'github'。
repo_or_dir
force_reload (bool, optional) – 是否强制全新下载 GitHub 存储库。如果 ,则没有任何效果。默认值为 。
source = 'local'
False
verbose (bool, optional) – 如果 ,将有关击球的消息静音 local caches。请注意,不能显示有关首次下载的消息 温和。如果 ,则没有任何效果。 默认值为 。
False
source = 'local'
True
skip_validation (bool, optional) – 如果 , TorchHub 将检查分支或提交 由参数指定正确属于存储库所有者。这将使 对 GitHub API 的请求;您可以通过设置环境变量来指定非默认 GitHub 令牌。默认值为 。
False
github
GITHUB_TOKEN
False
**kwargs (可选) – 可调用的相应 kwargs 。
model
- 返回
使用给定的 和 调用可调用对象时的输出。
model
*args
**kwargs
例
>>> # from a github repo >>> repo = 'pytorch/vision' >>> model = torch.hub.load(repo, 'resnet50', pretrained=True) >>> # from a local directory >>> path = '/some/local/path/pytorch/vision' >>> model = torch.hub.load(path, 'resnet50', pretrained=True)
-
torch.hub.
download_url_to_file
(url, dst, hash_prefix=None, progress=True)[来源]¶ Download 对象复制到本地路径。
- 参数
url (string) – 要下载的对象的 URL
dst (string) – 保存对象的完整路径,例如
/tmp/temporary_file
hash_prefix (string, optional) – 如果不是 None,则 SHA256 下载的文件应以 . 默认值:无
hash_prefix
progress (bool, optional) – 是否向 stderr 显示进度条 默认值:True
例
>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
-
torch.hub.
load_state_dict_from_url
(url、model_dir=无、map_location=无、progress=True、check_hash=False、file_name=没有)[来源]¶ 在给定的 URL 处加载 Torch 序列化对象。
如果下载的文件是 zip 文件,它将自动 减压。
如果对象已经存在于 model_dir 中,则会对其进行反序列化,并且 返回。 的默认值 是 where,其中 是 返回的
目录。
model_dir
<hub_dir>/checkpoints
hub_dir
- 参数
url (string) – 要下载的对象的 URL
model_dir (string, optional) – 保存对象的目录
map_location (可选) – 指定如何重新映射存储位置的函数或字典(参见 torch.load)
progress (bool, optional) – 是否向 stderr 显示进度条。 默认值:True
check_hash (bool, optional) – 如果为 True,则 URL 的文件名部分应遵循命名约定,其中前 8 个或更多 文件内容的 SHA256 哈希值的数字。哈希值用于 确保名称唯一并验证文件的内容。 默认值:False
filename-<sha256>.ext
<sha256>
file_name (string, optional) – 下载文件的名称。如果未设置,则将使用 Filename from。
url
例
>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
运行加载的模型:¶
请注意,和 in 用于实例化模型。加载模型后,如何查找
你可以用这个模型做什么?
建议的工作流程是
*args
**kwargs
dir(model)
以查看模型的所有可用方法。help(model.foo)
检查运行哪些参数model.foo
为了帮助用户在不来回参考文档的情况下进行探索,我们强烈 建议 Repo 所有者 使函数帮助消息清晰简洁。这也很有帮助 以包含一个最小的工作示例。
我下载的模型保存在哪里?¶
这些位置按
叫
hub.set_dir(<PATH_TO_HUB_DIR>)
$TORCH_HOME/hub
(如果设置了环境变量)。TORCH_HOME
$XDG_CACHE_HOME/torch/hub
(如果设置了环境变量)。XDG_CACHE_HOME
~/.cache/torch/hub
缓存逻辑¶
默认情况下,我们不会在加载文件后清理文件。如果缓存已存在于
由 返回的目录。
用户可以通过调用 来强制重新加载。这将删除
现有的 GitHub 文件夹和下载的权重,重新初始化新的下载。这很有用
当更新发布到同一分支时,用户可以跟上最新版本。hub.load(..., force_reload=True)
已知限制:¶
Torch Hub 的工作原理是导入软件包,就像它已安装一样。有一些副作用
通过在 Python 中导入引入。例如,您可以在 Python 缓存中看到新项目,这是正常的 Python 行为。
这也意味着在导入不同的模型时可能会遇到导入错误
来自不同的存储库,如果存储库具有相同的子包名称(通常是 Subpackage)。此类导入错误的解决方法是
从 dict 中删除有问题的子包;更多细节可以
可以在这个 GitHub 问题中找到。sys.modules
sys.path_importer_cache
model
sys.modules
这里值得一提的一个已知限制:用户无法加载 同一个 Python 进程中的同一个仓库。这就像使用 在 Python 中同名,这并不好。如果您满足以下条件,Cache 可能会加入派对并给您带来惊喜 实际上试试那个。当然,将它们加载到单独的进程中是完全可以的。