將自定義JAX模型作為外掛整合¶
本指南將引導您完成為TPU推理實現基本JAX模型的步驟。
1. 引入您的模型程式碼¶
本指南假定您的模型是為JAX編寫的。
2. 使您的程式碼與vLLM相容¶
為確保與vLLM相容,您的模型必須滿足以下要求:
初始化程式碼
模型中的所有vLLM模組在其建構函式中都必須包含一個vllm_config引數。這將包含所有vllm相關的配置以及模型配置。
初始化程式碼應如下所示:
class LlamaForCausalLM(nnx.Module):
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
mesh: Mesh) -> None:
self.vllm_config = vllm_config
self.rng = nnx.Rngs(rng_key)
self.mesh = mesh
self.model = LlamaModel(
vllm_config=vllm_config,
rng=self.rng,
mesh=mesh,
)
計算程式碼
模型的正向傳播應在__call__方法中,該方法至少必須包含以下引數:
def __call__(
self,
kv_caches: List[jax.Array],
input_ids: jax.Array,
attention_metadata: AttentionMetadata,
) -> Tuple[List[jax.Array], jax.Array]:
…
有關參考,請檢視我們的Llama實現。
3. 實現權重載入邏輯¶
您現在需要在您的*ForCausalLM類中實現load_weights方法。此方法應從HuggingFace的檢查點檔案(或相容的本地檢查點)載入權重,並將其分配給您模型中相應的層。
4. 註冊您的模型¶
TPU推理依賴於模型登錄檔來確定如何執行每個模型。預註冊架構的列表可以在這裡找到。
如果您的模型不在該列表中,您必須將其註冊到TPU推理。您可以使用外掛(類似於vLLM的外掛)來載入外部模型,而無需修改TPU推理的程式碼庫。
您的外掛結構應如下所示:
setup.py構建指令碼應遵循與vLLM外掛相同的指導。
要註冊模型,請在your_code/__init__.py中使用以下程式碼:
from tpu_inference.logger import init_logger
from tpu_inference.models.common.model_loader import register_model
logger = init_logger(__name__)
def register():
from .your_code import YourModelForCausalLM
register_model("YourModelForCausalLM", YourModelForCausalLM)
5. 安裝並執行您的模型¶
確保您在與vllm/tpu推理相同的Python環境中pip install .您的模型。然後執行您的模型: