跳到內容

torchax 在 vLLM 中是如何使用的

作者:Siyuan Liu, Hongmin Fan, Han Qi

最後更新:2025 年 9 月 26 日

什麼是 torchax

torchax 是一個提供 JAX 和 PyTorch 互操作性的庫。這意味著,您現在可以在同一個程序中執行 PyTorch 和 JAX,並且可以在 JAX 支援的所有硬體上執行(包括 NVidia GPU 和 Google TPU)。

它可以被看作是:* JAX 的 PyTorch 前端,或者,* PyTorch 的 JAX 後端。

使用 torchax,您可以

  • 透過少至 2 行程式碼的修改,在 TPU 上執行 PyTorch 程式碼。
  • 從 PyTorch 函式中呼叫 JAX 函式,並傳遞 jax.Array。
  • 從 JAX 函式中呼叫 PyTorch 函式,並傳遞 torch.Tensor。
  • 使用 JAX 的功能,如 jax.grad、optax 和 GSPMD 來訓練 PyTorch 模型。
  • 使用 PyTorch 模型作為特徵提取器,並將其與 JAX 模型一起使用。

其工作原理是擁有一個 torch.Tensor 的子類,該子類包含一個 jax.Array,並實現了該張量應該支援的所有 torch 運算子。

torchax

有關 torchax 工作原理的更多詳細資訊,請參閱此頁面

TPU JAX worker

vllm Worker 透過 init_device、determine_available_memory、execute_model、compile_or_warm_up_model、profile 等通用方法與 vllm 的 LLM Engine 進行互動。

每個 worker 都有一個 runner - 主要用於後端特定實現。主要包括以下內容

  • 模型初始化 & 權重載入
  • 確定 KV 快取塊的數量(程式碼指標)
  • 捕獲計算圖:使用不同的輸入形狀執行模型,以執行所有可能的計算圖,避免在服務期間進行編譯。
  • 根據排程器輸出執行模型(預處理模型輸入,執行模型,為每個請求生成取樣 token)

Jax TPU worker 是在 tpu_common 中引入的一個新的 Worker 實現,它負責呼叫使用 Jax 或 Torch 實現的模型。

Jax worker 與 torch 模型之間的互動

當 Jax worker 例項化時,它使用 get_model 函式來獲取一個代表模型的 callable。該 callable 是一個純函式(沒有狀態),因此權重、KV 快取以及輸入都將作為輸入傳遞給該函式。

Get Model

當 worker 執行模型時,它會呼叫模型函式。模型函式接受 Jax Arrays 作為輸入。

Model fn

純函式

JAX 的轉換和編譯僅設計用於功能上純粹的 Python 函式:所有輸入資料都透過函式引數傳遞,所有結果都透過函式結果輸出。純函式在輸入相同時總是返回相同的結果。

如果在計算中使用了某些陣列,但它們不是函式輸入,它們將在計算圖中被內聯為常量。

PyTorch 的 forward 函式不將模型權重作為輸入引數 -> 需要使用 torch.func.functional_call

正如我們在上面的官方文件中看到的,functional_call 允許將權重作為輸入傳遞(而不是從模型物件的屬性中讀取權重)。

import torch
import torch.nn as nn
from torch.func import functional_call, grad

x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)

def compute_loss(params, x, t):  # params is the weights as a dict of Tensos
    y = functional_call(model, params, x)
    return nn.functional.mse_loss(y, t)

grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)

KV 快取

透過 functional call,我們可以將權重作為函式輸入傳遞,現在,我們仍然需要處理 KV 快取。KV 快取是我們的 model_fn 的輸入/輸出。然而,傳統上 vllm 上游將其作為模型屬性寫入。

我們可以在呼叫模型之前手動將 KV 快取放在那裡,並返回它們。

看看下面的說明性示例

def f(cache, x):
  cache += x

cache = jnp.zeros(3)
x = jnp.ones(3)

jax.jit(f)(cache, x)

# prints 0
print(cache)

上面的程式碼模擬了對 cache 變數的就地更新,我們可以看到這種更改並未傳播到 jax.jit 區域之外。

def functional_f(cache, x):
  cache += x
  return cache

cache = jnp.zeros(3)
x = jnp.ones(3)

updated = jax.jit(functional_f)(cache, x)
cache = updated #<-- write back the update

# prints 1
print(cache)

訣竅在於讓 jax.jit 中的函式返回更新的 KV 快取,然後我們重新分配修改。

使用上述技術的程式碼位於:tpu_inference/models/torchax/torchax_wrapper.py#L55-L83 如下所示

alt text

心智模型

torchax 就是 JAX

torchax 的工作原理是提供一個 PyTorch 前端;因此,每個 PyTorch 運算子最終都成為作用於 JAX 陣列的 JAX 函式。因此,我們這裡採取的方法是:1. 使用處理 JAX 模型相同的 JAX worker。2. 使用 torchax 使 torch.nn.Module 對 worker 看起來像一個 JAX 模型。3. 對 KV 快取和 Attention kernel(如 RaggedPagedAttention)使用基於 JAX 的方法。

alt text

模型執行虛擬碼

inputs : jax.Array = prepare_inputs()  # shared by jax model and torchax model
inputs_torch : torch.Tensor = torch_view(inputs)  # a torch.Tensor subclass that holds an jax.Array
outputs_torch : torch.Tensor = torch.func.functional_call(torch_model, weights, inputs_torch) # kv caches are handled in VllmModelWrapper
outputs = jax_view(outputs_torch)
# ...
# sampler logic implemented based on jax, shared by jax model and torchax model

Attention kernel 呼叫虛擬碼

# q, new_k, new_v : torch.Tensor
# kv_cache : jax.Array
q = jax_view(q)
new_k = jax_view(new_k)
new_v = jax_view(new_v)
output : jax.Array = attention_kernel(q, new_k, new_v, kv_cache, ...)
output : torch.Tensor = torch_view(output)