JAX 模型開發指南¶
tpu-inference 提供了一個靈活的框架,用於在 Flax NNX 中實現基於 Transformer 的架構。
整合新模型型別所需的元件包括:- 定義模型架構並實現任何自定義層 - 實現權重載入邏輯 - (可選) 新增量化支援 - 將新模型註冊到 tpu-inference 中
程式碼組織¶
在開始模型開發之前,熟悉程式碼組織會很有幫助。
tpu_inference
├── layers
│ ├── jax # Provide pre-implemented building blocks for tpu-inference models.
│ │ ├── attention_interface.py # Core interfaces used for applying attention.
│ │ ├── base.py
│ │ ├── layers.py
│ │ ├── transformer_block.py
│ │ ├── sharding.py
│ │ ├── rope.py
│ │ ├── glossary.md
│ │ ├── attention
│ │ │ ├── attention.py # Pre-implemented attention layer.
│ │ │ └── deepseek_v3_attention.py
│ │ └── moe
│ │ ├── moe.py
│ │ └── deepseek_v3_moe.py
│ └── common # Functionalities shared between torchax and jax implementations.
└── models
├── common
│ └── model_loader.py
└── jax # Contains model files for each type of model family.
├── deepseek_v3.py
├── llama3.py
├── qwen3.py
└── utils
- 新 Jax 模型型別的註冊應在
tpu_inference/models/common/model_loader.py中執行。 - 新的 Jax 模型定義應新增到
tpu_inference/models/jax。 - 常用的層(例如,嵌入層、前饋層)可以從
tpu_inference/layers/jax匯入。 - 特定於模型的層實現應新增到
tpu_inference/layers/<layer_type>/<model_type>_<layer_type>.py(例如,attention/deepseek_v3_attention.py,moe/deepseek_v3_moe.py)。 - 自定義 (Qwix) 量化配置(yaml 檔案)應儲存在
tpu_inference/models/jax/utils/quantization/configs。
模型實現¶
實現新模型需要建立一個專用模型檔案(例如,deepseek_v3.py),其中包含以下元件:- 定義架構的模型類。- 前向傳播實現和 logits 計算。- 權重載入邏輯,用於將 HuggingFace 權重匯入模型定義。
定義模型架構¶
模型檔案旨在包含定義基於 Transformer 的架構所需的所有資訊。每個模型檔案都包含一個具有以下建構函式介面的模型類:
class NewModel(nnx.Module):
def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
mesh: jax.sharding.Mesh)
建構函式應設定架構配置(例如,num_layers、hidden_size)並初始化模型層。可以使用 flax NNX 從頭開始定義層(例如,Llama3),或者可以利用 tpu-inference 來匯入或擴充套件常用的層型別(例如,Embedder、RMSNorm、MoE、Attention、DenseFFW、TransformerBlock)。
實現前向傳播¶
前向傳播包含將模型建構函式中定義的層進行拼接的邏輯,並應使用以下介面:
def __call__(
self,
kv_caches: List[jax.Array],
input_ids: jax.Array,
attention_metadata: AttentionMetadata,
*args,
) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]
此介面的關鍵假設是上下文由模型外部管理(模型負責在自注意力之後更新 KV 快取張量是例外),這與 vLLM 中的情況一致。(有關 AttentionMetadata 如何準備的更多詳細資訊,請參閱 vLLM 的 Block 排程和管理設計 和 tpu_jax_runner.py)。預期返回的輸出包含更新的 KV 快取、最終層隱藏狀態以及(可選)輔助的最終隱藏狀態殘差(用於投機解碼)。
除了前向傳播邏輯之外,每個模型都需要實現一個使用以下介面生成 logits 的方法:def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
權重載入¶
開源模型的權重在命名和引數形狀方面並不 universally standard。因此,有必要實現按模型載入權重的邏輯,以正確地將開源權重匯入到相應的模型引數中。為此,每個模型都必須實現一個具有以下介面的 load_weights 方法:def load_weights(self, rng: PRNGKey)
權重載入邏輯通常由幾個類別的步驟組成:- 將 HuggingFace 權重載入到迭代器中(參見 model_weights_generator)- 定義載入的權重名稱與實現權重名稱之間的對映。- 定義要應用於載入引數的張量變換的對映。(這些變換可以包括轉置或重塑載入的張量)。- 執行特定於模型的載入邏輯(例如,拆分載入的權重張量並載入到多個引數中)。- (可選)支援載入預量化模型。
有關如何實現權重載入的示例,請參考 deepseek_v3.py 或 llama4.py。
量化支援¶
許多大型 LLM,如 DeepSeek-V3,使用量化來減少硬體需求和提高效能。tpu-inference 程式碼庫使用 Qwix 來載入預量化模型和/或對載入的模型權重應用額外的量化設定。在 tpu-inference 中,對於預量化檢查點如何生成沒有假設(因此您可以自由選擇流行的工具),只要結果以 HuggingFace Safetensor 格式儲存並遵循以下指南。有關如何在 tpu-inference 上使用 Qwix 進行推理執行的更多詳細資訊,請參閱 通用 readme。
請注意,您可能需要在此處更新 TPU 上支援的量化型別列表:這裡。如果 HuggingFace 量化配置中的 quant_method 不是支援的型別之一,vLLM 將觸發驗證錯誤。HuggingFace 量化配置。
為了演示,在本節中,我們將引用 deepseek_v3.py 來獲取實現細節。
載入預量化檢查點並應用量化規則¶
要正確載入預量化檢查點,需要執行以下步驟:- 使用 Qwix 配置定義量化設定,該配置可以作為 yaml 檔案(例如,int8_default.yaml)公開,或者在程式碼中設定。開源模型的量化設定通常在其各自的 HuggingFace 量化配置中釋出(例如,DeepSeek-R1)。(有關支援的 Qwix 量化選項的更多資訊,請參閱 Qwix 文件)。- 在 Qwix 配置中將 use_abstract_model 設定為 True,以便在載入權重之前對 NNX 模型圖進行量化。- 如果預量化模型包含反量化標度,請更新權重載入邏輯以儲存它們。如果載入模型的權重需要應用變換,請確保反量化標度也進行相應變換。標度維度可以透過 HuggingFace 配置中的 weight_block_size 來確定,並在 權重載入邏輯中設定。標度維度也可以與 Safetensor 檔案 進行交叉引用。
反之,如果檢查點未預量化,則不需要自定義模型載入程式碼,應在 Qwix 配置中將 use_abstract_model 設定為 False。
請注意,Qwix 量化設定是事實上的標準,將覆蓋載入權位使用的資料型別(即使提供了預量化權重)。
模型註冊¶
一旦實現了新的模型型別,就必須將其新增到 model_loader.py 的模型登錄檔中。
警告
根據 vLLM 的驗證流程,模型必須註冊為一個受支援的 HuggingFace 模型名稱(有關更多詳細資訊,請參閱 此處)。
要將外部 Jax NNX 模型實現整合到 tpu-inference 中,請參閱 專用文件。