跳到內容

自定義 Logits Processors

重要

某些 logits processor 的設計更改仍在進行中,API 可能在不久的將來發生變化。我們希望儘快穩定此 API 部分。

“自定義”logits processor 由 vLLM 使用者編寫,並在初始化時載入到 vLLM 中,而無需修改或重新編譯 vLLM 原始碼。它與內建 logits processor 相反。

本文件演示瞭如何編寫、載入和使用自定義 logits processor。

Logits Processors 背景

logits processor 調整下一個 token 的機率分佈,通常旨在將模型引導到所需的行為型別。

在 vLLM 中,logits processor 以 batch 粒度執行。在給定的引擎步中,logits processor 消耗模型輸出的 (num_requests) x (vocab_size) 的原始 logits 張量。對於所有啟用了 logits processor 的請求,logits processor 會對 logits 張量的相應行應用轉換,而其他行保持不變。然後將轉換後的 logits 張量傳遞給 softmax。

建立自定義 Logits Processor

自定義 logits processor 必須繼承自 vllm.v1.sample.logits_processor.LogitsProcessor 並定義(至少)以下方法:

  • validate_params(cls, sampling_params: SamplingParams):

    • 如果 SamplingParams 包含 logits processor 使用的無效引數(尤其是自定義引數),則引發 ValueError
    • 當請求傳送到入口點時,validate_params() 將驗證 SamplingParams 並拒絕帶有無效引數的請求。
    • 注意: 實現 validate_params() 以防止自定義 logits processor 的引數無效非常重要。否則,帶有無效引數的請求可能導致自定義 logits processor 出現意外行為。
  • __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)

    • vllm_config:引擎配置資料結構。
    • device:硬體加速器裝置資訊。
    • is_pin_memory:指示是否可用 pin memory 來支援 logits processor 實現的標誌。
  • apply(self, logits: torch.Tensor) -> torch.Tensor:

    • 消耗一個 (num_requests) x (vocab_size) 的 logits 張量 (logits)。
    • 以 batch 粒度應用 logits processor 轉換。
    • 返回一個轉換後的 (num_requests) x (vocab_size) logits 張量。
    • 您可以原地修改或非原地修改輸入的 logits processor;原地修改更節省記憶體。
  • is_argmax_invariant(self) -> bool:

    • 如果 logits processor 是 argmax 不變的(對於給定的請求,從不改變最高 logits 值 token ID),則返回 True;如果 logits processor 可能修改 argmax,則返回 False
    • is_argmax_invariant() 在啟動時評估一次;如果為 True,當所有請求都使用貪婪取樣時,vLLM 將跳過應用此 logits processor。
  • update_state(self, batch_update: Optional["BatchUpdate"]) -> None:

    • 消耗一個 BatchUpdate 資料結構,該結構表示當前引擎步開始時的持久 batch 狀態更改。
    • 使用 BatchUpdate 的成員來更新 logits processor 的內部狀態。
    • 注意: batch update 資料結構可能為 None,表示 batch 構成沒有變化。在這種情況下,LogitsProcessor 可能仍希望根據它在新增時保留的更新後的 output_token_ids 列表來更新其狀態。

vLLM 引擎如何構建 BatchUpdate 資料結構

重要

某些 logits processor 的設計更改仍在進行中。我們預計將來您在實現 logits processor 時不需要考慮 batch 狀態更改,本節的資訊將變得無關緊要。

logits processor 的 update_state() 實現應假定模型執行程式更新持久 batch 狀態的模型(在此處以 BatchUpdate 抽象的形式表示):

  1. 識別當前引擎步中完成的請求的索引。

  2. 識別當前步中引入的新請求。

  3. 使用 Add 操作,按照被替換請求的索引從小到大的順序,用新請求替換儘可能多的已完成請求。

  4. 基於新請求和已完成請求的相對數量。

    1. 如果新請求的數量與已完成請求的數量相同,則繼續下一步。

    2. 如果新請求多於已完成請求: 應用 Add 操作,用剩餘未替換已完成請求的新請求擴充套件 batch。為這些新請求分配連續的索引,從 current_max_batch_index + 1 開始。

    3. 如果新請求少於已完成請求。

      • 對未被新請求替換的已完成請求應用 Remove 操作。這些移除的請求索引必然大於上一步中被替換的已完成請求的最大索引。Remove 操作可能會使 batch 處於非連續狀態。

      • “壓縮”batch 使其連續: 從最低索引的空槽(由 Remove 操作引起)開始,應用一個從當前 batch 中最高的非空槽到該空槽的單向移動。按照空槽目標索引遞增和非空槽源索引遞減的順序繼續進行其他單向移動操作,直到 batch 連續。

      • 收縮 batch: 壓縮 batch 的一個副作用是將 Remove 操作產生的空槽分組在一個連續塊中,位於 batch 陣列的末尾。因此,壓縮後,更新 BatchUpdate.batch_size 以反映非空槽的數量。

  5. 重新排序 batch 以提高效率。根據 attention 後端實現和 batch 的當前特性,可以應用零個或多個 Swap Move 操作來重新排序 batch。

注意事項

  • logits processor 的 update_state() 方法必須按以下順序處理 batch 更新操作:移除 (removes)、新增 (adds)、移動 (moves)。

  • Add 操作的索引引數指的是 Add 操作發生時的索引,即在任何 Move 操作之前。

    • 例如:如果一個請求在索引 5 處被新增,然後與索引 3 交換,那麼 BatchUpdate.added 中的 Add 操作將與索引 5 相關聯,而不是 3。
    • 換句話說,可以假定 Move 操作是在 Adds 和 Removes 之後應用的。
  • 可以假定 Move 操作是按照它們在 BatchUpdate.moved 中出現的順序應用的。

  • 如果沒有新/已完成請求,也沒有 batch 重排,那麼 logits processor 的 batch 更新將為 None

將自定義引數傳遞給自定義 Logits Processor

與內建 logits processor 不同,自定義 logits processor 可能需要配置引數,這些引數並未硬編碼到 SamplingParams 或 vLLM 伺服器 REST API 中。為了解決這個問題,自定義 logits processor 可以利用 vLLM 的 自定義引數 支援,從使用者那裡接收配置設定(儘管您也可以自由設計一個利用 SamplingParams 中現有欄位的自定義 logits processor)。

示例自定義 Logits Processor 實現

下面的示例實現了一個自定義 logits processor,它消耗一個 (num_requests) x (vocab_size) 的 logits 張量,並將除了一個 token (target_token) 之外的所有 token 都掩蔽為 float(-inf)。對於任何未指定 target_token 的請求,該 logits processor 將被停用。為了確定 logits processor 是否啟用以及哪個 token 將保持不被掩蔽,logits processor 會檢查 SamplingParams.extra_args 中與每個請求關聯的 target_token 自定義引數。

示例自定義 logits processor 定義
import torch
from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (BatchUpdate,
                                            LogitsProcessor,
                                            MoveDirectionality)

class DummyLogitsProcessor(LogitsProcessor):
    """Fake logit processor to support unit testing and examples"""

    @classmethod
    def validate_params(cls, params: SamplingParams):
        target_token: int | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is not None and not isinstance(target_token, int):
            raise ValueError(f"target_token value {target_token} is not int")

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                is_pin_memory: bool):
        self.req_info: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        """Never impacts greedy sampling"""
        return False

    def update_state(self, batch_update: BatchUpdate | None):
        if not batch_update:
            return

        # Process added requests.
        for index, params, _, _ in batch_update.added:
            assert params is not None
            self.validate_params(params)
            if params.extra_args and (target_token :=
                                    params.extra_args.get("target_token")):
                self.req_info[index] = target_token
            else: 
                self.req_info.pop(index, None)

        if self.req_info:
            # Process removed requests.
            for index in batch_update.removed:
                self.req_info.pop(index, None)

            # Process moved requests, unidirectional move (a->b) and swap
            # (a<->b)
            for adx, bdx, direct in batch_update.moved:
                a_val = self.req_info.pop(adx, None)
                b_val = self.req_info.pop(bdx, None)
                if a_val is not None:
                    self.req_info[bdx] = a_val
                if direct == MoveDirectionality.SWAP and b_val is not None:
                    self.req_info[adx] = b_val

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.req_info:
            return logits

        # Save target values before modification
        cols = torch.tensor(
            list(self.req_info.values()), dtype=torch.long, device=logits.device
        )
        rows = torch.tensor(
            list(self.req_info.keys()), dtype=torch.long, device=logits.device
        )
        values_to_keep = logits[rows, cols].clone()

        # Mask all but target tokens
        logits[rows] = float('-inf')
        logits[rows, cols] = values_to_keep

        return logits

在本文件的其餘部分,我們將使用 DummyLogitsProcessor 作為自定義 logits processor 的示例。

DummyLogitsProcessor.update_state() 實現使用 `self.req_info` 字典來維護 batch 請求的“稀疏”表示:只有指定了 target_token 值的請求才會在字典中有一個鍵。update_state() 根據 Add、Remove 和 Move 操作對持久 batch 的響應,調整儲存的請求索引和 target_token 值(分別是 `self.req_info` 中的鍵和值)。

封裝現有的請求級別 Logits Processor

儘管 vLLM 引擎以 batch 粒度應用 logits processor,但有些使用者可能希望使用“請求級別”的 logits processor 實現與 vLLM 結合使用——即,一個作用於單個請求的實現。如果您的 logits processor 是為 vLLM 版本 0 開發的,這尤其會如此,當時它需要是 Callable(如 此處 所述),並符合以下型別註解:

RequestLogitsProcessor = Union[

    # (output token ids, logits tensor) -> logits tensor
    Callable[[list[int], Tensor], Tensor],

    # (prompt token ids, output token ids, logits tensor) -> logits tensor
    Callable[[list[int], list[int], Tensor], Tensor],
]

雖然請求級別的 logits processor 在 vLLM 引擎中明確**不支援**,但 vLLM **提供**了一個便捷的方法來封裝現有的 Callable 請求級別 logits processor,並建立一個與 vLLM 相容的 batch 級別 logits processor。Callable 必須符合上述型別註解;如果您的請求級別 logits processor 具有不同的介面,那麼為了封裝它,您可能需要修改它或實現一個額外的封裝層以符合上述介面規範。

您可以透過繼承 AdapterLogitsProcessor 來封裝請求級別的 logits processor,如下面的示例所示(在此示例中,DummyPerReqLogitsProcessor 是您需要封裝的請求級別 logits processor 的一個佔位符)。

  • 重寫 AdapterLogitsProcessor.validate_params(cls,params) 來驗證請求的取樣引數。

  • 重寫 AdapterLogitsProcessor.is_argmax_invariant(self) 來準確反映您的請求級別 logits processor 是否會影響哪個 token 具有最高值的 logit。

  • 重寫 AdapterLogitsProcessor.new_req_logits_processor(self,params) 來從 SamplingParams 例項建立一個新的請求級別 logits processor 例項。

封裝請求級別 Logits Processor 的示例
...

from vllm.v1.sample.logits_processor import (
    AdapterLogitsProcessor, # Wrapper base-class
    RequestLogitsProcessor, # Request-level logitsproc type annotation
)

...

# Stand-in for your request-level logits processor:
class DummyPerReqLogitsProcessor:
    """The request-level logits processor masks out all logits except the
    token id identified by `target_token`"""

    def __init__(self, target_token: int) -> None:
        """Specify `target_token`"""
        self.target_token = target_token

    def __call__(
        self,
        output_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        val_to_keep = logits[self.target_token].item()
        logits[:] = float("-inf")
        logits[self.target_token] = val_to_keep
        return logits

...

# Example of wrapping the request-level logits processor:
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Example of wrapping a fake request-level logit processor to create a
    batch-level logits processor"""

    @classmethod
    def validate_params(cls, params: SamplingParams):
        target_token: Any | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is not None and not isinstance(target_token, int):
            raise ValueError(
                f"target_token value {target_token} is not int"
            )

    def is_argmax_invariant(self) -> bool:
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> Optional[RequestLogitsProcessor]:
        """This method returns a new request-level logits processor, customized
        to the `target_token` value associated with a particular request.

        Returns None if the logits processor should not be applied to the
        particular request. To use the logits processor the request must have
        a "target_token" custom argument with an integer value.

        Args:
        params: per-request sampling params

        Returns:
        `Callable` request logits processor, or None
        """
        target_token: Any | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is None:
            return None
        return DummyPerReqLogitsProcessor(target_token)

注意

您的 new_req_logits_processor() 重寫可以返回 None,以指示不應將封裝的 logits processor 應用於當前請求。

一旦您建立了一個自定義子類(例如 WrappedPerReqLogitsProcessor)來封裝您的請求級別 logits processor,您就可以透過下一節中描述的任何方法將其傳遞給 vLLM。

在 vLLM 中載入自定義 Logits Processor 的方法

Logits processor 在初始化時載入。重要的是,載入的 logits processor 集合在 vLLM 引擎完成載入後不能被修改,也不能為單個請求按需載入新的 logits processor。

本節詳細介紹了讓您的 logits processor 對 vLLM 可見並觸發 vLLM 載入您的 logits processor 的各種方法。

方法 1:在初始化時將自定義 Logits Processor 的完全限定類名 (FQCN) 傳遞給 vLLM

此方法在離線和線上 vLLM 使用場景中都受支援。自定義 logits processor 的 FQCN(形式為 dotted.path.to.module:ClassName)可以作為引數傳遞給 LLMAsyncLLM Python 建構函式,或者作為 CLI 引數傳遞給 vllm serve,語法如下:

vllm serve ... --logits_processors <logits processor 1> <logits processor 2> ...

FQCN 的唯一要求是:

  1. Python 的 importlib.import_module() 必須能夠解析 FQCN 的點狀路徑並將其作為模組載入。

  2. FQCN 的類名部分必須能夠從載入的模組中匯入。

  3. FQCN 指向的物件必須是 LogitsProcessor 的子類。

請參閱下面的示例。

在 Python 中將自定義 logits processor FQCN 傳遞給 LLM
# Pass in FQCN
llm = LLM(
    model="facebook/opt-125m",
    logits_processors=["your.module.path:DummyLogitsProcessor"],
)
在 Python 中將自定義 logits processor FQCN 傳遞給 AsyncLLM
# Pass in FQCN
engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                              logits_processors=["your.module.path:DummyLogitsProcessor"])
async_llm = AsyncLLM.from_engine_args(engine_args)
透過 CLI 將自定義 logits processor FQCN 傳遞給 vLLM 伺服器
vllm serve facebook/opt-125m --logits_processors your.module.path:DummyLogitsProcessor

方法 2:自動檢測安裝在您的 Python 環境中作為入口點的自定義 Logits Processors

setuptools 可以使已安裝的包成為其他 Python 程式的外掛,透過稱為“入口點”的元資料片段。

在初始化期間,vLLM 會自動掃描 vllm.logits_processors 入口點組,並載入它找到的所有已安裝的 logits processor。

假設您開發了一個包含自定義 logits processor 的 Python 包。您可以透過為每個 logits processor 在您的 logits processor Python 包中新增一個唯一的入口點來將其暴露給 vLLM。下面的示例展示瞭如何向專案的 pyproject.toml 檔案新增一個入口點:

將自定義 logits processor 作為 Python 入口點公開
[project.entry-points."vllm.logits_processors"]
dummy_logits_processor = "your.module.path:DummyLogitsProcessor"

一旦您的包被安裝,每當 vLLM 初始化時,您的自定義 logits processor 都將被自動載入。如果您的 logits processor 是透過入口點公開的,您**不需要**在初始化時顯式地將自定義 logits processor 傳遞給 LLMAsyncLLM 建構函式或 vLLM 伺服器。

注意

vLLM 將**始終**載入**所有**透過 vllm.logits_processors 分組公開的 logits processor。

方法 3 (僅限離線):將 Python 類物件傳遞給 vLLM 建構函式

您可以將一個或多個自定義 logits processor 類物件傳遞給 LLMAsyncLLM 建構函式。此選項非常靈活,因為 logits processor 類可以是 (1) 在與 LLMAsyncLLM 例項化的相同 Python 原始檔中本地定義的,或者 (2) 從 Python 包匯入的。

在 Python 中將自定義 logits processor 類物件傳遞給 LLMAsyncLLM
# Import custom logits processor
from some.module import DummyLogitsProcessor

# ...or...

# Define custom logits processor locally
from vllm.v1.sample.logits_processor import LogitsProcessor

class DummyLogitsProcessor(LogitsProcessor):
    # See DummyLogitsProcessor implementation above
    ...

# Pass class object to LLM constructor
llm = LLM(
    model="facebook/opt-125m",
    logits_processors=[DummyLogitsProcessor],
)

# Pass class object to AsyncLLM constructor
engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                              logits_processors=[DummyLogitsProcessor])
async_llm = AsyncLLM.from_engine_args(engine_args)

針對請求呼叫自定義 Logits Processor

自定義 logits processor 的設計決定了是否必須為給定請求啟用/停用 logits processor,以及必須提供哪些引數來配置 logits processor。

下面的示例展示了使用者如何向 DummyLogitsProcessor 傳遞自定義引數 (target_token) 以 (1) 為特定請求啟用 logits processor 並 (2) 控制 logits processor 的行為。

vLLM REST API:配置請求的自定義 logits processor
curl https://:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "Qwen/Qwen2.5-1.5B-Instruct",
        ...
        "vllm_xargs": {"target_token": 67}
    }'
OpenAI SDK:配置請求的自定義 logits processor
batch = await client.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    ...,
    extra_body={
        "vllm_xargs": {
            "target_token": 67
        }
    }
)
離線:為 LLM 請求配置自定義 logits processor
outputs_logitproc = llm.generate("your prompt", 
                                 SamplingParams(...,
                                    extra_args={"target_token": 67}))
離線:為 AsyncLLM 請求配置自定義 logits processor
async for out in engine.generate(request_id="your request id",
                                 prompt="your prompt",
                                 sampling_params=SamplingParams(...,
                                    extra_args={"target_token": 67})):

    # Process async request outputs
    ...

編寫自定義 Logits Processors 的最佳實踐

一旦 vLLM 在初始化期間載入了 logits processor,vLLM 將在每個引擎步中對該 logits processor 呼叫 update_state()apply()。這兩個方法都作用於當前位於 vLLM 持久 batch 中的所有請求。因此,高效實現這些方法非常重要。

  • 考慮到 logits processor 以 batch 粒度執行,請編寫高效的 apply()update_state() 實現。

    • 例如,您可能能夠使用高效的向量化操作來實現 apply() 或在 update_state() 中更新內部狀態向量。
    • 但是,如果您認為某個 logits processor 可能不經常使用,那麼使用“稀疏”表示請求狀態可能是合適的,即該類可以使用一個字典來表示請求配置,該字典僅儲存啟用 logits processor 的請求的元資料。
    • 注意: 封裝的請求級別 logits processor 不需要實現 apply()update_state();預設的 AdapterLogitsProcessor.update_state() 實現維護請求狀態的稀疏表示,其中 new_req_logits_processor() 返回 None 的請求在基類狀態字典中不被表示。AdapterLogitsProcessor.apply() 的預設實現將請求級別 logits processor 順序應用於輸入 logits 的每一行,並組裝輸出 logits 張量。如果此 AdapterLogitsProcessor 預設實現的效能不足,則避免封裝您的請求級別 logits processor,而是將其重新實現為具有最佳化後的 apply()update_state() 實現的 LogitsProcessor 子類,這些實現以 batch 粒度執行。
  • 由 logits processor 作者決定:

    1. 配置 logits processor 對該請求行為的每個請求屬性。 您自定義 logits processor 的 update_state() 重寫決定了如何將 SamplingParams 欄位對映到 logits processor 狀態。

      • 注意: 對於封裝的請求級別 logits processor,new_req_logits_processor() 決定了如何使用 SamplingParams 欄位來初始化請求級別 logits processor 例項。
    2. logits processor 在每個請求基礎上啟用或不啟用的條件。 除非您的目的是讓自定義 logits processor 始終作用於所有請求,否則您應該以這樣一種方式編寫您的 logits processor,即有可能為特定請求停用 logits processor,例如透過將引數預設為 None 或傳入特定的無操作引數值(例如 0.0)。對於停用 logits processor 的請求,儘量節省計算和記憶體。

      • 注意: 對於封裝的請求級別 logits processor,預設的 AdapterLogitsProcessor.update_state() 實現確保在 new_req_logits_processor() 為該請求返回 None 時停用請求級別 logits processor。
    3. logits processor 在 batch 級別被短路的條件。 即使您已經定義了在請求級別停用自定義 logits processor 的方法,也很難將其轉化為計算節省,例如,如果您的 update_state()apply() 實現使用了在單個命令中作用於整個持久 batch 的高效向量化實現。例如,您不能僅僅因為一個請求停用了 logits processor 就跳過 apply() 中的整個向量化操作。為了在沒有執行請求使用自定義 logits processor 的邊緣情況下節省計算,我們建議將 apply() 設計為在所有請求都停用 logits processor 時返回未修改的輸入張量。同樣,請考慮在沒有請求啟用 logits processor 的情況下是否可以跳過 update_state() 中的步驟。

      • 此外,一個在 update_state() 中節省計算的簡單方法是,當 batch_updateNone 時提前退出。

      • 注意: 對於封裝的請求級別 logits processor,AdapterLogitsProcessor 基類預設實現了上述最佳化。

  • 確保 logits processor 的 update_state 方法丟棄已完成請求的資訊(例如,被 Add 操作替換或被 Remove 操作處理的請求)。

    • 注意: 對於封裝的請求級別 logits processor,AdapterLogitsProcessor 基類預設處理此問題。
  • 如果 logits processor 具有一致的行為,is_argmax_invariant() 可以硬編碼為 TrueFalse。但是,argmax 不變性也可以透過程式設計方式確定(例如,如果您的 logits processor 是使用者可定製的,以某種方式影響 logits processor 是否是 argmax 不變的)。因此,is_argmax_invariant() 不是一個類方法。