跳到內容

效能分析

目前有三種方法可以對您的工作負載進行效能分析:

使用 examples/tpu_profiling.py

vLLM TPU 效能分析指令碼

此指令碼是一個實用工具,用於分析 vLLM 引擎在 TPU VM 上的效能。它使用 JAX profiler 來捕獲詳細的效能跟蹤。

可以使用 TensorBoard(配合 tensorboard-plugin-profile 包)或 Perfetto UI 等工具來視覺化效能分析結果。

如何使用

先決條件

您必須安裝 TensorBoard profile 外掛才能視覺化結果。

pip install tensorboard-plugin-profile

基本命令

該指令碼從命令列執行,指定工作負載引數和任何必需的 vLLM 引擎引數。

python3 examples/tpu_profiling.py --model <your-model-name> [OPTIONS]

關鍵引數

  • --model: (必需) 要進行效能分析的模型名稱或路徑。
  • --input-len: 每個請求的輸入提示 token 長度。
  • --output-len: 每個請求要生成的 token 數量。
  • --batch-size: 請求的數量。
  • --profile-result-dir: JAX profiler 輸出將儲存的目錄。
  • 該指令碼還接受所有標準的 vLLM EngineArgs(例如,--tensor-parallel-size, --dtype)。

示例

1. 分析 Prefill 操作: 要分析具有長輸入提示(例如,1024 個 token)的單個請求,請將 --input-len 設定得高一些,並將 --batch-size 設定為 1。

python3 examples/tpu_profiling.py \
  --model google/gemma-2b \
  --input-len 1024 \
  --output-len 1 \
  --batch-size 1

2. 分析 Decoding 操作: 要分析大量單 token 解碼步驟的批處理,請將 --input-len--output-len 設定為 1,並使用較大的 --batch-size

python3 examples/tpu_profiling.py \
  --model google/gemma-2b \
  --input-len 1 \
  --output-len 1 \
  --batch-size 256

使用 PHASED_PROFILING_DIR

如果您設定了以下環境變數

PHASED_PROFILING_DIR=<DESIRED PROFILING OUTPUT DIR>

我們將自動在工作負載的三個階段捕獲效能分析(假設它們會出現):1. Prefill 密集型(給定批次的 prefill / 總計劃 token 的商值 => 0.9)2. Decode 密集型(給定批次的 prefill / 總計劃 token 的商值 <= 0.2)3. 混合型(給定批次的 prefill / 總計劃 token 的商值在 0.4 到 0.6 之間)。

為了便於您的分析,我們還將記錄被分析批次的批次構成。

使用 USE_JAX_PROFILER_SERVER

如果您設定了以下環境變數

USE_JAX_PROFILER_SERVER=True

您可以改為手動決定何時捕獲效能分析以及捕獲多長時間,這對於您的工作負載(例如,E2E 基準測試)可能很有幫助,因為它很大,並且對整個工作負載進行效能分析(即使用上述方法)會生成一個巨大的跟蹤檔案。

您還可以設定所需的效能分析埠(預設為 9999)。

JAX_PROFILER_SERVER_PORT=XXXX

要使用此方法,您可以執行以下操作:

  1. 執行您典型的 vllm serveoffline_inference 命令(確保設定 USE_JAX_PROFILER_SERVER=True)。
  2. 執行您的基準測試命令(python benchmark_serving.py...)。
  3. 預熱完成後且基準測試正在執行時,啟動一個新的 tensorboard 例項,並將您的 logdir 設定為您所需的效能分析輸出位置(例如,tensorboard --logdir=profiles/llama3-mmlu/)。
  4. 開啟 tensorboard 例項並導航到 profile 頁面(例如,https://:6006/#profile)。
  5. 點選 Capture Profile,然後在 Profile Service URL(s) or TPU name 框中,輸入 localhost:XXXX,其中 XXXX 是您的 JAX_PROFILER_SERVER_PORT(預設為 9999)。

  6. 輸入所需的時間(以毫秒為單位)。