跳到內容

將自定義JAX模型作為外掛整合

本指南將引導您完成為TPU推理實現基本JAX模型的步驟。

1. 引入您的模型程式碼

本指南假定您的模型是為JAX編寫的。

2. 使您的程式碼與vLLM相容

為確保與vLLM相容,您的模型必須滿足以下要求:

初始化程式碼

模型中的所有vLLM模組在其建構函式中都必須包含一個vllm_config引數。這將包含所有vllm相關的配置以及模型配置。

初始化程式碼應如下所示:

class LlamaForCausalLM(nnx.Module):

    def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
                 mesh: Mesh) -> None:
        self.vllm_config = vllm_config
        self.rng = nnx.Rngs(rng_key)
        self.mesh = mesh

        self.model = LlamaModel(
            vllm_config=vllm_config,
            rng=self.rng,
            mesh=mesh,
        )

計算程式碼

模型的正向傳播應在__call__方法中,該方法至少必須包含以下引數:

def __call__(
    self,
    kv_caches: List[jax.Array],
    input_ids: jax.Array,
    attention_metadata: AttentionMetadata,
) -> Tuple[List[jax.Array], jax.Array]:

有關參考,請檢視我們的Llama實現

3. 實現權重載入邏輯

您現在需要在您的*ForCausalLM類中實現load_weights方法。此方法應從HuggingFace的檢查點檔案(或相容的本地檢查點)載入權重,並將其分配給您模型中相應的層。

4. 註冊您的模型

TPU推理依賴於模型登錄檔來確定如何執行每個模型。預註冊架構的列表可以在這裡找到。

如果您的模型不在該列表中,您必須將其註冊到TPU推理。您可以使用外掛(類似於vLLM的外掛)來載入外部模型,而無需修改TPU推理的程式碼庫。

您的外掛結構應如下所示:

├── setup.py
├── your_code
│   ├── your_code.py
│   └── __init__.py

setup.py構建指令碼應遵循與vLLM外掛相同的指導

要註冊模型,請在your_code/__init__.py中使用以下程式碼:

from tpu_inference.logger import init_logger
from tpu_inference.models.common.model_loader import register_model

logger = init_logger(__name__)

def register():
    from .your_code import YourModelForCausalLM
    register_model("YourModelForCausalLM", YourModelForCausalLM)

5. 安裝並執行您的模型

確保您在與vllm/tpu推理相同的Python環境中pip install .您的模型。然後執行您的模型:

HF_TOKEN=token TPU_BACKEND_TYPE=jax \
  python -m vllm.entrypoints.cli.main serve \
  /path/to/hf_compatible/weights/ \
  --max-model-len=1024 \
  --tensor-parallel-size 8 \
  --max-num-batched-tokens 1024 \
  --max-num-seqs=1 \