torch.hub¶
Pytorch Hub 是一个预先训练的模型存储库,旨在促进研究的可重复性。
发布模型¶
Pytorch Hub 支持发布预训练模型(模型定义和预训练权重)
通过添加简单文件添加到 GitHub 存储库;hubconf.py
hubconf.py
可以有多个入口点。每个入口点都定义为一个 python 函数
(示例:您要发布的预训练模型)。
def entrypoint_name(*args, **kwargs):
# args & kwargs are optional, for models which take positional/keyword arguments.
...
如何实现入口点?¶
下面是一个代码片段,如果我们扩展
中的 实现。
在大多数情况下,导入正确的函数就足够了。在这里,我们
只想以扩展版本为例,说明它是如何工作的。
您可以在 pytorch/vision 存储库中查看完整脚本resnet18
pytorch/vision/hubconf.py
hubconf.py
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18
# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
""" # This docstring shows up in hub.help()
Resnet18 model
pretrained (bool): kwargs, load pretrained weights into the model
"""
# Call the model, load pretrained weights
model = _resnet18(pretrained=pretrained, **kwargs)
return model
dependencies
variable 是加载模型所需的包名称列表。请注意,这可能会 与训练模型所需的依赖项略有不同。args
并传递给真正的可调用函数。kwargs
该函数的文档字符串用作帮助消息。它解释了模型的作用和作用 是允许的位置/关键字参数。强烈建议在此处添加一些示例。
Entrypoint 函数可以返回一个 model(nn.module),也可以返回辅助工具,使用户工作流程更顺畅,例如 tokenizers。
预训练的权重可以本地存储在 GitHub 存储库中,也可以由
.如果小于 2GB,建议将其附加到项目发布并使用发布中的 URL。 在上面的示例中 handles ,或者,您可以将以下逻辑放在入口点定义中。
torchvision.models.resnet.resnet18
pretrained
if pretrained:
# For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
dirname = os.path.dirname(__file__)
checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
state_dict = torch.load(checkpoint)
model.load_state_dict(state_dict)
# For checkpoint saved elsewhere
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
重要通知¶
已发布的模型应至少位于分支/标记中。它不能是随机提交。
从 Hub 加载模型¶
Pytorch Hub 提供了方便的 API 来探索 Hub 中的所有可用模型
通过 显示文档字符串和示例,并使用
.
- torch.hub 的list(github, force_reload=False, skip_validation=False, trust_repo=无, verbose=真)[来源]¶
列出 指定的 repo 中所有可用的可调用入口点。
github
- 参数
github (str) – 格式为 “repo_owner/repo_name[:ref]” 的字符串,带有可选的 ref (标签或分支)。如果未指定,则默认分支假定为 if 它存在,否则 . 示例: 'pytorch/vision:0.10'
ref
main
master
force_reload (bool, optional) – 是否丢弃现有缓存并强制重新下载。 默认值为 。
False
skip_validation (bool, optional) – 如果 , TorchHub 将检查分支或提交 由参数指定正确属于存储库所有者。这将使 对 GitHub API 的请求;您可以通过设置环境变量来指定非默认 GitHub 令牌。默认值为 。
False
github
GITHUB_TOKEN
False
trust_repo (bool, str 或 None) –
"check"
或。 该参数在 v1.12 中引入,有助于确保用户 仅从他们信任的存储库运行代码。True
False
None
如果 ,则提示将询问用户是否应将存储库 值得信赖。
False
如果 ,则存储库将被添加到受信任列表并加载 而无需明确确认。
True
如果 ,将根据 缓存中的 trusted repos。如果该列表中不存在该 ID,则 behaviour 将回退到该选项。
"check"
trust_repo=False
如果 :这将引发警告,邀请用户设置为 、 或 。这 仅出于向后兼容性而存在,并将在 2.0 版。
None
trust_repo
False
True
"check"
默认值是,并且最终将更改为 v2.0。
None
"check"
verbose (bool, optional) – 如果 ,将有关击球的消息静音 local caches。请注意,不能显示有关首次下载的消息 温和。默认值为 。
False
True
- 返回
可用的 callables 入口点
- 返回类型
例
>>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True)
- torch.hub 的help(github, model, force_reload=False, skip_validation=False, trust_repo=None)[来源]¶
显示 entrypoint 的文档字符串 。
model
- 参数
github (str) – 格式为 <repo_owner/repo_name[:ref]> 的字符串,带有可选的 ref (标签或分支)。如果未指定,则假定默认分支 如果存在,则为 ,否则为 。 示例: 'pytorch/vision:0.10'
ref
main
master
model (str) – 在存储库的
hubconf.py
force_reload (bool, optional) – 是否丢弃现有缓存并强制重新下载。 默认值为 。
False
skip_validation (bool, optional) – 如果 , TorchHub 将检查 ref 由参数指定正确属于存储库所有者。这将使 对 GitHub API 的请求;您可以通过设置环境变量来指定非默认 GitHub 令牌。默认值为 。
False
github
GITHUB_TOKEN
False
trust_repo (bool, str 或 None) –
"check"
或。 该参数在 v1.12 中引入,有助于确保用户 仅从他们信任的存储库运行代码。True
False
None
如果 ,则提示将询问用户是否应将存储库 值得信赖。
False
如果 ,则存储库将被添加到受信任列表并加载 而无需明确确认。
True
如果 ,将根据 缓存中的 trusted repos。如果该列表中不存在该 ID,则 behaviour 将回退到该选项。
"check"
trust_repo=False
如果 :这将引发警告,邀请用户设置为 、 或 。这 仅出于向后兼容性而存在,并将在 2.0 版。
None
trust_repo
False
True
"check"
默认值是,并且最终将更改为 v2.0。
None
"check"
例
>>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))
- torch.hub 的load(repo_or_dir, model, *args, source='github', trust_repo=无, force_reload=False, verbose=True, skip_validation=False, **kwargs)[来源]¶
从 github 存储库或本地目录加载模型。
注意:加载模型是典型的用例,但这也可以用于 用于加载其他对象,例如 tokenizers、loss 函数等。
如果是 'github',则预期为 的形式中带有可选的 ref (标签或分支)。
source
repo_or_dir
repo_owner/repo_name[:ref]
如果为 'local',则应为 路径。
source
repo_or_dir
- 参数
repo_or_dir (str) – 如果为 'github', 这应该对应于 format 为 可选的 ref(标签或分支),例如 'pytorch/vision:0.10'。如果未指定,则 默认分支假定为(如果存在),否则为 。 如果是 'local',则它应该是本地目录的路径。
source
repo_owner/repo_name[:ref]
ref
main
master
source
model (str) – 在 repo/dir 的 .
hubconf.py
*args (可选) – callable 的相应 args 。
model
source (str, optional) – 'github' 或 'local'。指定如何解释。默认为 'github'。
repo_or_dir
trust_repo (bool, str 或 None) –
"check"
或。 该参数在 v1.12 中引入,有助于确保用户 仅从他们信任的存储库运行代码。True
False
None
如果 ,则提示将询问用户是否应将存储库 值得信赖。
False
如果 ,则存储库将被添加到受信任列表并加载 而无需明确确认。
True
如果 ,将根据 缓存中的 trusted repos。如果该列表中不存在该 ID,则 behaviour 将回退到该选项。
"check"
trust_repo=False
如果 :这将引发警告,邀请用户设置为 、 或 。这 仅出于向后兼容性而存在,并将在 2.0 版。
None
trust_repo
False
True
"check"
默认值是,并且最终将更改为 v2.0。
None
"check"
force_reload (bool, optional) – 是否强制全新下载 GitHub 存储库。如果 ,则没有任何效果。默认值为 。
source = 'local'
False
verbose (bool, optional) – 如果 ,将有关击球的消息静音 local caches。请注意,不能显示有关首次下载的消息 温和。如果 ,则没有任何效果。 默认值为 。
False
source = 'local'
True
skip_validation (bool, optional) – 如果 , TorchHub 将检查分支或提交 由参数指定正确属于存储库所有者。这将使 对 GitHub API 的请求;您可以通过设置环境变量来指定非默认 GitHub 令牌。默认值为 。
False
github
GITHUB_TOKEN
False
**kwargs (可选) – 可调用的相应 kwargs 。
model
- 返回
使用给定的 和 调用可调用对象时的输出。
model
*args
**kwargs
例
>>> # from a github repo >>> repo = "pytorch/vision" >>> model = torch.hub.load( ... repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1" ... ) >>> # from a local directory >>> path = "/some/local/path/pytorch/vision" >>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")
- torch.hub 的download_url_to_file(url, dst, hash_prefix=None, progress=True)[来源]¶
Download 对象复制到本地路径。
- 参数
例
>>> torch.hub.download_url_to_file( ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth", ... "/tmp/temporary_file", ... )
- torch.hub 的load_state_dict_from_url(url, model_dir=无, map_location=无, progress=True, check_hash=False, file_name=无, weights_only=False)[来源]¶
在给定的 URL 处加载 Torch 序列化对象。
如果下载的文件是 zip 文件,它将自动 减压。
如果对象已经存在于 model_dir 中,则会对其进行反序列化,并且 返回。 的默认值 是 where,其中 是 返回的
目录。
model_dir
<hub_dir>/checkpoints
hub_dir
- 参数
url (str) – 要下载的对象的 URL
model_dir (str, optional) – 保存对象的目录
map_location (可选) – 指定如何重新映射存储位置的函数或字典(参见 torch.load)
progress (bool, optional) – 是否向 stderr 显示进度条。 默认值:True
check_hash (bool, optional) – 如果为 True,则 URL 的文件名部分应遵循命名约定,其中前 8 个或更多 文件内容的 SHA256 哈希值的数字。哈希值用于 确保名称唯一并验证文件的内容。 默认值:False
filename-<sha256>.ext
<sha256>
file_name (str, optional) – 下载文件的名称。如果未设置,则将使用 Filename from。
url
weights_only (bool, optional) – 如果为 True,则仅加载权重,而不会加载复杂的腌制对象。 建议用于不受信任的源。有关更多详细信息,请参阅
。
- 返回类型
例
>>> state_dict = torch.hub.load_state_dict_from_url( ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth" ... )
运行加载的模型:¶
请注意,和 in 用于实例化模型。加载模型后,如何查找
你可以用这个模型做什么?
建议的工作流程是
*args
**kwargs
dir(model)
以查看模型的所有可用方法。help(model.foo)
检查运行哪些参数model.foo
为了帮助用户在不来回参考文档的情况下进行探索,我们强烈 建议 Repo 所有者 使函数帮助消息清晰简洁。这也很有帮助 以包含一个最小的工作示例。
我下载的模型保存在哪里?¶
这些位置按
叫
hub.set_dir(<PATH_TO_HUB_DIR>)
$TORCH_HOME/hub
(如果设置了环境变量)。TORCH_HOME
$XDG_CACHE_HOME/torch/hub
(如果设置了环境变量)。XDG_CACHE_HOME
~/.cache/torch/hub
缓存逻辑¶
默认情况下,我们不会在加载文件后清理文件。如果缓存已存在于
由 返回的目录。
用户可以通过调用 来强制重新加载。这将删除
现有的 GitHub 文件夹和下载的权重,重新初始化新的下载。这很有用
当更新发布到同一分支时,用户可以跟上最新版本。hub.load(..., force_reload=True)
已知限制:¶
Torch Hub 的工作原理是导入软件包,就像它已安装一样。有一些副作用
通过在 Python 中导入引入。例如,您可以在 Python 缓存中看到新项目,这是正常的 Python 行为。
这也意味着在导入不同的模型时可能会遇到导入错误
来自不同的存储库,如果存储库具有相同的子包名称(通常是 Subpackage)。此类导入错误的解决方法是
从 dict 中删除有问题的子包;更多细节可以
可以在此 GitHub 问题中找到。sys.modules
sys.path_importer_cache
model
sys.modules
这里值得一提的一个已知限制:用户无法加载 同一个 Python 进程中的同一个仓库。这就像使用 在 Python 中同名,这并不好。如果您满足以下条件,Cache 可能会加入派对并给您带来惊喜 实际上试试那个。当然,将它们加载到单独的进程中是完全可以的。