跳到內容

JAX 模型開發指南

tpu-inference 提供了一個靈活的框架,用於在 Flax NNX 中實現基於 Transformer 的架構。

整合新模型型別所需的元件包括:- 定義模型架構並實現任何自定義層 - 實現權重載入邏輯 - (可選) 新增量化支援 - 將新模型註冊到 tpu-inference 中

程式碼組織

在開始模型開發之前,熟悉程式碼組織會很有幫助。

tpu_inference
├── layers
   ├── jax # Provide pre-implemented building blocks for tpu-inference models.
       ├── attention_interface.py # Core interfaces used for applying attention.
       ├── base.py
       ├── layers.py
       ├── transformer_block.py
       ├── sharding.py
       ├── rope.py
       ├── glossary.md
       ├── attention
           ├── attention.py # Pre-implemented attention layer.
           └── deepseek_v3_attention.py
       └── moe
            ├── moe.py
            └── deepseek_v3_moe.py
   └── common # Functionalities shared between torchax and jax implementations.
└── models
   ├── common
      └── model_loader.py
   └── jax  # Contains model files for each type of model family.
       ├── deepseek_v3.py
       ├── llama3.py
       ├── qwen3.py
       └── utils

模型實現

實現新模型需要建立一個專用模型檔案(例如,deepseek_v3.py),其中包含以下元件:- 定義架構的模型類。- 前向傳播實現和 logits 計算。- 權重載入邏輯,用於將 HuggingFace 權重匯入模型定義。

定義模型架構

模型檔案旨在包含定義基於 Transformer 的架構所需的所有資訊。每個模型檔案都包含一個具有以下建構函式介面的模型類:

class NewModel(nnx.Module):
  def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
               mesh: jax.sharding.Mesh)

建構函式應設定架構配置(例如,num_layers、hidden_size)並初始化模型層。可以使用 flax NNX 從頭開始定義層(例如,Llama3),或者可以利用 tpu-inference 來匯入或擴充套件常用的層型別(例如,EmbedderRMSNormMoEAttentionDenseFFWTransformerBlock)。

實現前向傳播

前向傳播包含將模型建構函式中定義的層進行拼接的邏輯,並應使用以下介面:

def __call__(
   self,
   kv_caches: List[jax.Array],
   input_ids: jax.Array,
   attention_metadata: AttentionMetadata,
   *args,
) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]

此介面的關鍵假設是上下文由模型外部管理(模型負責在自注意力之後更新 KV 快取張量是例外),這與 vLLM 中的情況一致。(有關 AttentionMetadata 如何準備的更多詳細資訊,請參閱 vLLM 的 Block 排程和管理設計tpu_jax_runner.py)。預期返回的輸出包含更新的 KV 快取、最終層隱藏狀態以及(可選)輔助的最終隱藏狀態殘差(用於投機解碼)。

除了前向傳播邏輯之外,每個模型都需要實現一個使用以下介面生成 logits 的方法:def compute_logits(self, hidden_states: jax.Array) -> jax.Array:

權重載入

開源模型的權重在命名和引數形狀方面並不 universally standard。因此,有必要實現按模型載入權重的邏輯,以正確地將開源權重匯入到相應的模型引數中。為此,每個模型都必須實現一個具有以下介面的 load_weights 方法:def load_weights(self, rng: PRNGKey)

權重載入邏輯通常由幾個類別的步驟組成:- 將 HuggingFace 權重載入到迭代器中(參見 model_weights_generator)- 定義載入的權重名稱與實現權重名稱之間的對映。- 定義要應用於載入引數的張量變換的對映。(這些變換可以包括轉置或重塑載入的張量)。- 執行特定於模型的載入邏輯(例如,拆分載入的權重張量並載入到多個引數中)。- (可選)支援載入預量化模型。

有關如何實現權重載入的示例,請參考 deepseek_v3.pyllama4.py

量化支援

許多大型 LLM,如 DeepSeek-V3,使用量化來減少硬體需求和提高效能。tpu-inference 程式碼庫使用 Qwix 來載入預量化模型和/或對載入的模型權重應用額外的量化設定。在 tpu-inference 中,對於預量化檢查點如何生成沒有假設(因此您可以自由選擇流行的工具),只要結果以 HuggingFace Safetensor 格式儲存並遵循以下指南。有關如何在 tpu-inference 上使用 Qwix 進行推理執行的更多詳細資訊,請參閱 通用 readme

請注意,您可能需要在此處更新 TPU 上支援的量化型別列表:這裡。如果 HuggingFace 量化配置中的 quant_method 不是支援的型別之一,vLLM 將觸發驗證錯誤。HuggingFace 量化配置

為了演示,在本節中,我們將引用 deepseek_v3.py 來獲取實現細節。

載入預量化檢查點並應用量化規則

要正確載入預量化檢查點,需要執行以下步驟:- 使用 Qwix 配置定義量化設定,該配置可以作為 yaml 檔案(例如,int8_default.yaml)公開,或者在程式碼中設定。開源模型的量化設定通常在其各自的 HuggingFace 量化配置中釋出(例如,DeepSeek-R1)。(有關支援的 Qwix 量化選項的更多資訊,請參閱 Qwix 文件)。- 在 Qwix 配置中將 use_abstract_model 設定為 True,以便在載入權重之前對 NNX 模型圖進行量化。- 如果預量化模型包含反量化標度,請更新權重載入邏輯以儲存它們。如果載入模型的權重需要應用變換,請確保反量化標度也進行相應變換。標度維度可以透過 HuggingFace 配置中的 weight_block_size 來確定,並在 權重載入邏輯中設定。標度維度也可以與 Safetensor 檔案 進行交叉引用。

反之,如果檢查點未預量化,則不需要自定義模型載入程式碼,應在 Qwix 配置中將 use_abstract_model 設定為 False

請注意,Qwix 量化設定是事實上的標準,將覆蓋載入權位使用的資料型別(即使提供了預量化權重)。

模型註冊

一旦實現了新的模型型別,就必須將其新增到 model_loader.py 的模型登錄檔中。

警告

根據 vLLM 的驗證流程,模型必須註冊為一個受支援的 HuggingFace 模型名稱(有關更多詳細資訊,請參閱 此處)。

要將外部 Jax NNX 模型實現整合到 tpu-inference 中,請參閱 專用文件