AMD Quark¶
量化可以有效減少記憶體和頻寬使用,加速計算並提高吞吐量,同時將精度損失降到最低。vLLM 可以利用 Quark 這一靈活而強大的量化工具包,生成高效能的量化模型,以便在 AMD GPU 上執行。Quark 專門支援對大型語言模型進行權重、啟用和 KV 快取的量化,並支援 AWQ、GPTQ、Rotation 和 SmoothQuant 等前沿量化演算法。
Quark 安裝¶
在量化模型之前,您需要安裝 Quark。Quark 的最新版本可以透過 pip 安裝
您可以參考 Quark 安裝指南 獲取更多安裝詳情。
此外,安裝 vllm
和 lm-evaluation-harness
用於評估
量化過程¶
安裝 Quark 後,我們將透過一個示例來演示如何使用 Quark。Quark 量化過程可分為以下 5 個步驟
- 載入模型
- 準備校準資料載入器
- 設定量化配置
- 量化模型並匯出
- 在 vLLM 中進行評估
1. 載入模型¶
Quark 使用 Transformers 獲取模型和分詞器。
程式碼
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "meta-llama/Llama-2-70b-chat-hf"
MAX_SEQ_LEN = 512
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, device_map="auto", torch_dtype="auto",
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, model_max_length=MAX_SEQ_LEN)
tokenizer.pad_token = tokenizer.eos_token
2. 準備校準資料載入器¶
Quark 使用 PyTorch Dataloader 載入校準資料。有關如何高效使用校準資料集的更多詳細資訊,請參閱 新增校準資料集。
程式碼
from datasets import load_dataset
from torch.utils.data import DataLoader
BATCH_SIZE = 1
NUM_CALIBRATION_DATA = 512
# Load the dataset and get calibration data.
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
text_data = dataset["text"][:NUM_CALIBRATION_DATA]
tokenized_outputs = tokenizer(text_data, return_tensors="pt",
padding=True, truncation=True, max_length=MAX_SEQ_LEN)
calib_dataloader = DataLoader(tokenized_outputs['input_ids'],
batch_size=BATCH_SIZE, drop_last=True)
3. 設定量化配置¶
我們需要設定量化配置,您可以檢視 Quark 配置指南 獲取更多詳細資訊。這裡我們使用權重、啟用和 KV 快取上的 FP8 每張量(per-tensor)量化,量化演算法為 AutoSmoothQuant。
注意
請注意,量化演算法需要一個 JSON 配置檔案,該檔案位於 Quark Pytorch 示例 的 examples/torch/language_modeling/llm_ptq/models
目錄下。例如,Llama 的 AutoSmoothQuant 配置檔案是 examples/torch/language_modeling/llm_ptq/models/llama/autosmoothquant_config.json
。
程式碼
from quark.torch.quantization import (Config, QuantizationConfig,
FP8E4M3PerTensorSpec,
load_quant_algo_config_from_file)
# Define fp8/per-tensor/static spec.
FP8_PER_TENSOR_SPEC = FP8E4M3PerTensorSpec(observer_method="min_max",
is_dynamic=False).to_quantization_spec()
# Define global quantization config, input tensors and weight apply FP8_PER_TENSOR_SPEC.
global_quant_config = QuantizationConfig(input_tensors=FP8_PER_TENSOR_SPEC,
weight=FP8_PER_TENSOR_SPEC)
# Define quantization config for kv-cache layers, output tensors apply FP8_PER_TENSOR_SPEC.
KV_CACHE_SPEC = FP8_PER_TENSOR_SPEC
kv_cache_layer_names_for_llama = ["*k_proj", "*v_proj"]
kv_cache_quant_config = {name :
QuantizationConfig(input_tensors=global_quant_config.input_tensors,
weight=global_quant_config.weight,
output_tensors=KV_CACHE_SPEC)
for name in kv_cache_layer_names_for_llama}
layer_quant_config = kv_cache_quant_config.copy()
# Define algorithm config by config file.
LLAMA_AUTOSMOOTHQUANT_CONFIG_FILE =
'examples/torch/language_modeling/llm_ptq/models/llama/autosmoothquant_config.json'
algo_config = load_quant_algo_config_from_file(LLAMA_AUTOSMOOTHQUANT_CONFIG_FILE)
EXCLUDE_LAYERS = ["lm_head"]
quant_config = Config(
global_quant_config=global_quant_config,
layer_quant_config=layer_quant_config,
kv_cache_quant_config=kv_cache_quant_config,
exclude=EXCLUDE_LAYERS,
algo_config=algo_config)
4. 量化模型並匯出¶
然後我們可以應用量化。量化後,在匯出之前,我們需要先凍結量化模型。請注意,我們需要以 HuggingFace safetensors
格式匯出模型,您可以參考 HuggingFace 格式匯出 瞭解更多匯出格式詳情。
程式碼
import torch
from quark.torch import ModelQuantizer, ModelExporter
from quark.torch.export import ExporterConfig, JsonExporterConfig
# Apply quantization.
quantizer = ModelQuantizer(quant_config)
quant_model = quantizer.quantize_model(model, calib_dataloader)
# Freeze quantized model to export.
freezed_model = quantizer.freeze(model)
# Define export config.
LLAMA_KV_CACHE_GROUP = ["*k_proj", "*v_proj"]
export_config = ExporterConfig(json_export_config=JsonExporterConfig())
export_config.json_export_config.kv_cache_group = LLAMA_KV_CACHE_GROUP
# Model: Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant
EXPORT_DIR = MODEL_ID.split("/")[1] + "-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant"
exporter = ModelExporter(config=export_config, export_dir=EXPORT_DIR)
with torch.no_grad():
exporter.export_safetensors_model(freezed_model,
quant_config=quant_config, tokenizer=tokenizer)
5. 在 vLLM 中進行評估¶
現在,您可以直接透過 LLM 入口點載入並執行 Quark 量化模型
程式碼
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant",
kv_cache_dtype='fp8',quantization='quark')
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Output: {generated_text!r}")
print("-" * 60)
或者,您可以使用 lm_eval
來評估精度
lm_eval --model vllm \
--model_args pretrained=Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant,kv_cache_dtype='fp8',quantization='quark' \
--tasks gsm8k
Quark 量化指令碼¶
除了上面的 Python API 示例,Quark 還提供了一個 量化指令碼,以更方便地量化大型語言模型。它支援使用各種不同的量化方案和最佳化演算法來量化模型。它可以匯出量化模型並即時執行評估任務。有了這個指令碼,上面的示例可以變成
python3 quantize_quark.py --model_dir meta-llama/Llama-2-70b-chat-hf \
--output_dir /path/to/output \
--quant_scheme w_fp8_a_fp8 \
--kv_cache_dtype fp8 \
--quant_algo autosmoothquant \
--num_calib_data 512 \
--model_export hf_format \
--tasks gsm8k
使用 MXFP4 模型¶
vLLM 支援載入透過 AMD Quark 離線量化的 MXFP4 模型,該模型符合 Open Compute Project (OCP) 規範。
此方案目前僅支援啟用的動態量化。
安裝最新 AMD Quark 版本後的使用示例
MXFP4 中的矩陣乘法執行模擬可以在不原生支援 MXFP4 操作的裝置上執行(例如 AMD Instinct MI325、MI300 和 MI250),使用融合核心(fused kernel)即時將 MXFP4 權重反量化為半精度。這對於使用 vLLM 評估 MXFP4 模型,或者受益於約 4 倍的記憶體節省(與 float16 和 bfloat16 相比)非常有用。
要生成使用 MXFP4 資料型別量化的離線模型,最簡單的方法是使用 AMD Quark 的 量化指令碼,例如