XLA设备上的量化操作(实验功能)¶
本文档概述了如何利用量化操作在XLA设备上启用量化的方法。
XLA 量化操作提供了量化操作的高级抽象(例如,块状 int4 量化矩阵乘法)。这些操作类似于 CUDA 生态系统中的量化 CUDA 内核(示例),在 XLA 框架内提供了类似的功能和性能优势。
注意:目前这被分类为实验性功能。在下一个(2.5)版本中,其API细节将会改变。
如何使用:¶
XLA量化操作可以作为torch op使用,或者作为torch.nn.Module来包裹torch.op。这两种选项给模型开发者提供了灵活性,以便他们选择最佳方式将XLA量化操作集成到他们的解决方案中。
Both torch op and nn.Module are compatible with torch.compile( backend='openxla').
在模型代码中调用XLA量化操作¶
用户可以像调用其他常规的PyTorch操作一样调用XLA量化操作。这为在应用程序中集成XLA量化操作提供了最大的灵活性。这些量化操作既可以在急切模式下工作,也可以在Dynamo中工作,并且可以与常规的PyTorch CPU张量和XLA张量一起使用。
注意 请检查量化操作的文档字符串以获取量化权重的布局。
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_quantized_matmul
N_INPUT_FEATURES=10
N_OUTPUT_FEATURES=20
x = torch.randn((3, N_INPUT_FEATURES), dtype=torch.bfloat16)
w_int = torch.randint(-128, 127, (N_OUTPUT_FEATURES, N_INPUT_FEATURES), dtype=torch.int8)
scaler = torch.randn((N_OUTPUT_FEATURES,), dtype=torch.bfloat16)
# Call with torch CPU tensor (For debugging purpose)
matmul_output = torch.ops.xla.quantized_matmul(x, w_int, scaler)
device = xm.xla_device()
x_xla = x.to(device)
w_int_xla = w_int.to(device)
scaler_xla = scaler.to(device)
# Call with XLA Tensor to run on XLA device
matmul_output_xla = torch.ops.xla.quantized_matmul(x_xla, w_int_xla, scaler_xla)
# Use with torch.compile(backend='openxla')
def f(x, w, s):
return torch.ops.xla.quantized_matmul(x, w, s)
f_dynamo = torch.compile(f, backend="openxla")
dynamo_out_xla = f_dynamo(x_xla, w_int_xla, scaler_xla)
通常会在模型开发者的模型代码中将量化操作包装成一个自定义的nn.Module:
class MyQLinearForXLABackend(torch.nn.Module):
def __init__(self):
self.weight = ...
self.scaler = ...
def load_weight(self, w, scaler):
# Load quantized Linear weights
# Customized way to preprocess the weights
...
self.weight = processed_w
self.scaler = processed_scaler
def forward(self, x):
# Do some random stuff with x
...
matmul_output = torch.ops.xla.quantized_matmul(x, self.weight, self.scaler)
# Do some random stuff with matmul_output
...
模块交换¶
Alternatively, 用户也可以使用包裹XLA量化操作的nn.Module并在模型代码中进行模块替换:
orig_model = MyModel()
# Quantize the model and get quantized weights
q_weights = quantize(orig_model)
# Process the quantized weight to the format that XLA quantized op expects.
q_weights_for_xla = process_for_xla(q_weights)
# Do module swap
q_linear = XlaQuantizedLinear(self.linear.in_features,
self.linear.out_features)
q_linear.load_quantized_weight(q_weights_for_xla)
orig_model.linear = q_linear
支持的量化操作:¶
矩阵乘法¶
权重量化类型 |
激活量化类型 |
数据类型(dtype) |
支持的 |
|---|---|---|---|
逐通道(对称/非对称) |
N/A |
W8A16 |
是的 |
逐通道(对称/非对称) |
N/A |
W4A16 |
是的 |
per-channel |
per-token |
W8A8 |
No |
per-channel |
per-token |
W4A8 |
No |
块状(对称/不对称) |
N/A |
W8A16 |
是的 |
块状(对称/不对称) |
N/A |
W4A16 |
是的 |
块级 |
per-token |
W8A8 |
No |
块级 |
per-token |
W4A8 |
No |
Note W[X]A[Y] 指的是 X-位权重,Y-位激活函数。如果 X/Y 是 4 或 8,它指的是 int4/8。16 对于 bfloat16 格式。
嵌入¶
待添加