MPS 后端¶
mps 设备通过Metal编程框架在MacOS设备上启用高性能GPU训练。它引入了一种新的设备,用于将机器学习计算图和原语映射到高效的Metal性能着色器图框架,并分别使用Metal性能着色器框架提供的优化内核。
新的MPS后端扩展了PyTorch生态系统,并为现有脚本提供了在GPU上设置和运行操作的能力。
要开始,请将您的Tensor和Module移动到mps设备:
# Check that MPS is available
if not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
print("MPS not available because the current PyTorch install was not "
"built with MPS enabled.")
else:
print("MPS not available because the current MacOS version is not 12.3+ "
"and/or you do not have an MPS-enabled device on this machine.")
else:
mps_device = torch.device("mps")
# Create a Tensor directly on the mps device
x = torch.ones(5, device=mps_device)
# Or
x = torch.ones(5, device="mps")
# Any operation happens on the GPU
y = x * 2
# Move your model to mps just like any other device
model = YourFavoriteNet()
model.to(mps_device)
# Now every call runs on the GPU
pred = model(x)