torchax 在 vLLM 中是如何使用的¶
作者:Siyuan Liu, Hongmin Fan, Han Qi
最後更新:2025 年 9 月 26 日
什麼是 torchax¶
torchax 是一個提供 JAX 和 PyTorch 互操作性的庫。這意味著,您現在可以在同一個程序中執行 PyTorch 和 JAX,並且可以在 JAX 支援的所有硬體上執行(包括 NVidia GPU 和 Google TPU)。
它可以被看作是:* JAX 的 PyTorch 前端,或者,* PyTorch 的 JAX 後端。
使用 torchax,您可以
- 透過少至 2 行程式碼的修改,在 TPU 上執行 PyTorch 程式碼。
- 從 PyTorch 函式中呼叫 JAX 函式,並傳遞 jax.Array。
- 從 JAX 函式中呼叫 PyTorch 函式,並傳遞 torch.Tensor。
- 使用 JAX 的功能,如 jax.grad、optax 和 GSPMD 來訓練 PyTorch 模型。
- 使用 PyTorch 模型作為特徵提取器,並將其與 JAX 模型一起使用。
其工作原理是擁有一個 torch.Tensor 的子類,該子類包含一個 jax.Array,並實現了該張量應該支援的所有 torch 運算子。
有關 torchax 工作原理的更多詳細資訊,請參閱此頁面。
TPU JAX worker¶
vllm Worker 透過 init_device、determine_available_memory、execute_model、compile_or_warm_up_model、profile 等通用方法與 vllm 的 LLM Engine 進行互動。
每個 worker 都有一個 runner - 主要用於後端特定實現。主要包括以下內容
- 模型初始化 & 權重載入
- 確定 KV 快取塊的數量(程式碼指標)
- 捕獲計算圖:使用不同的輸入形狀執行模型,以執行所有可能的計算圖,避免在服務期間進行編譯。
- 根據排程器輸出執行模型(預處理模型輸入,執行模型,為每個請求生成取樣 token)
Jax TPU worker 是在 tpu_common 中引入的一個新的 Worker 實現,它負責呼叫使用 Jax 或 Torch 實現的模型。
Jax worker 與 torch 模型之間的互動¶
當 Jax worker 例項化時,它使用 get_model 函式來獲取一個代表模型的 callable。該 callable 是一個純函式(沒有狀態),因此權重、KV 快取以及輸入都將作為輸入傳遞給該函式。
當 worker 執行模型時,它會呼叫模型函式。模型函式接受 Jax Arrays 作為輸入。
純函式¶
JAX 的轉換和編譯僅設計用於功能上純粹的 Python 函式:所有輸入資料都透過函式引數傳遞,所有結果都透過函式結果輸出。純函式在輸入相同時總是返回相同的結果。
如果在計算中使用了某些陣列,但它們不是函式輸入,它們將在計算圖中被內聯為常量。
PyTorch 的 forward 函式不將模型權重作為輸入引數 -> 需要使用 torch.func.functional_call。
正如我們在上面的官方文件中看到的,functional_call 允許將權重作為輸入傳遞(而不是從模型物件的屬性中讀取權重)。
import torch
import torch.nn as nn
from torch.func import functional_call, grad
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
def compute_loss(params, x, t): # params is the weights as a dict of Tensos
y = functional_call(model, params, x)
return nn.functional.mse_loss(y, t)
grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)
KV 快取¶
透過 functional call,我們可以將權重作為函式輸入傳遞,現在,我們仍然需要處理 KV 快取。KV 快取是我們的 model_fn 的輸入/輸出。然而,傳統上 vllm 上游將其作為模型屬性寫入。
我們可以在呼叫模型之前手動將 KV 快取放在那裡,並返回它們。
看看下面的說明性示例
def f(cache, x):
cache += x
cache = jnp.zeros(3)
x = jnp.ones(3)
jax.jit(f)(cache, x)
# prints 0
print(cache)
上面的程式碼模擬了對 cache 變數的就地更新,我們可以看到這種更改並未傳播到 jax.jit 區域之外。
def functional_f(cache, x):
cache += x
return cache
cache = jnp.zeros(3)
x = jnp.ones(3)
updated = jax.jit(functional_f)(cache, x)
cache = updated #<-- write back the update
# prints 1
print(cache)
訣竅在於讓 jax.jit 中的函式返回更新的 KV 快取,然後我們重新分配修改。
使用上述技術的程式碼位於:tpu_inference/models/torchax/torchax_wrapper.py#L55-L83 如下所示
心智模型¶
torchax 就是 JAX。
torchax 的工作原理是提供一個 PyTorch 前端;因此,每個 PyTorch 運算子最終都成為作用於 JAX 陣列的 JAX 函式。因此,我們這裡採取的方法是:1. 使用處理 JAX 模型相同的 JAX worker。2. 使用 torchax 使 torch.nn.Module 對 worker 看起來像一個 JAX 模型。3. 對 KV 快取和 Attention kernel(如 RaggedPagedAttention)使用基於 JAX 的方法。
模型執行虛擬碼
inputs : jax.Array = prepare_inputs() # shared by jax model and torchax model
inputs_torch : torch.Tensor = torch_view(inputs) # a torch.Tensor subclass that holds an jax.Array
outputs_torch : torch.Tensor = torch.func.functional_call(torch_model, weights, inputs_torch) # kv caches are handled in VllmModelWrapper
outputs = jax_view(outputs_torch)
# ...
# sampler logic implemented based on jax, shared by jax model and torchax model
Attention kernel 呼叫虛擬碼




