Logits Processors¶
重要
一些 Logits Processors 的設計更改仍在進行中,API 在不久的將來可能會發生變化。我們希望儘快穩定 API 的這部分。
本文件描述了 vLLM 引擎如何與 logits processors 互動,以及 vLLM 支援的用於實現 logits processors 的程式設計模型。
Logits Processors 背景¶
Logits processor 會調整下一個 token 的機率分佈,通常是為了引導模型產生期望的行為。
在 vLLM 中,logits processors 以批次粒度執行。在給定的引擎步中,logits processor 接收模型輸出的 (num_requests) x (vocab_size) 大小的原始 logits 張量。對於所有啟用該 logits processor 的請求,logits processor 會對 logits 張量的相應行應用變換,而其他行則保持不變。然後將變換後的 logits 張量傳遞給 softmax。
vLLM 引擎中的 Logits Processors¶
vLLM 引擎的持久批次資料結構維護著一個已載入 logits processors 的列表。
為了同時處理整個批次,每個 logits processor 可能會維護關於批次中請求的元資料(即每個請求的特定於 logits processor 的配置設定)。因此,logits processors 是有狀態的。
在每個引擎步中,vLLM 引擎將(1)更新每個 logits processor 的內部狀態,以及(2)將 logits processors 應用於模型輸出的 logits。
更新 Logits Processor 內部狀態¶
在每個引擎步的開始,持久批次可能會根據排程器的輸出新增、丟棄和/或重新排序請求。在持久批次重新組織後,vLLM 引擎會呼叫每個 logits processor 的 update_state() 方法。這是為了確保 logits processor 的內部狀態能夠與引擎步開始時的新的持久批次狀態匹配。
以下虛擬碼展示了 vLLM 持久批次通知每個 logits processor 批次狀態變化的程序。
模型執行器更新 Logits Processor 狀態
# gpu_model_runner.py
class GPUModelRunner(...):
...
def execute_model(self, scheduler_output, ...):
self._update_states(scheduler_output)
...
def _update_states(...):
...
# ...update persistent batch to reflect new/finished requests & reordering
# of requests within batch...
...
self.input_batch.refresh_metadata()
# gpu_input_batch.py
class InputBatch:
...
def refresh_metadata(self):
...
# Update each logits processor's state to reflect persistent batch state
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update)
...
# vllm/v1/sample/logits_processor/interface.py
@dataclass(frozen=True)
class BatchUpdate:
# Batch state-change data structure which is passed to logits processors'
# update_state() methods
batch_size: int
removed: Sequence[RemovedRequest]
added: Sequence[AddedRequest]
moved: Sequence[MovedRequest]
將 Logits Processors 應用於模型輸出 Logits¶
在更新持久批次狀態後,vLLM 模型執行器會執行模型推理以獲得 logits。然後,模型執行器會對 logits 呼叫 sampler。sampler 的一部分操作是針對模型輸出 logits 呼叫 logits processors 的 apply() 方法,得到變換後的 logits(apply() 方法可以就地或非就地修改 logits,但就地修改更節省記憶體)。此過程如下圖所示。
請注意,sampler 將透過 SamplingMetadata.logitsprocs 訪問 logits processors。當 vLLM 引擎構建 SamplingMetadata 時(下圖未顯示),到 logits processors 列表的引用將從持久批次資料結構傳遞到 SamplingMetadata。
將 Logits Processors 應用於模型輸出 Logits
# gpu_model_runner.py
class GPUModelRunner(...):
...
def execute_model(self, scheduler_output, ...):
# (discussed in previous section)
self._update_states(scheduler_output)
...
# ...run model inference to obtain logits...
...
# Invoke sampler, which applies logits processors
sampler_output = self.sampler(logits=logits,
sampling_metadata=sampling_metadata)
...
# sampler.py
class Sampler(nn.Module):
...
def forward(self, logits, sampling_metadata):
...
# Apply non-argmax-invariant logits processors to model output logits
for processor in (sampling_metadata.logitsprocs.non_argmax_invariant):
logits = processor.apply(logits)
sampled = self.sample(logits, sampling_metadata)
...
# ...return sampler output data structure...
def sample(self, logits, sampling_metadta)
...
# ...exit early if all requests are greedy-sampling...
...
# Apply argmax-invariant logits processors
for processor in sampling_metadata.logitsprocs.argmax_invariant:
logits = processor.apply(logits)
...
# ...perform sampling and return sampling result...
在取樣時,sampler 會檢查持久批次中的所有請求是否都使用了貪婪取樣。如果是這種情況,sampler 會透過跳過“argmax 不變”的 logits processors 來節省計算。此處,“argmax”是給定 logits 張量行中具有最高 logit 值的 token ID 的簡稱(即模型為給定請求加權最高的 token)。
-
argmax 不變 Logits Processor 是一個不會改變 argmax 的 logits processor(例如 Min-P)。例如,一個遮蔽低機率 token 的 logits processor 不會改變具有最高 logit 的 token ID。貪婪取樣總是選擇具有最高 logit 值的 token ID,因此概念上,對於貪婪取樣請求,可以跳過 argmax 不變的 logits processor。
-
非 argmax 不變 Logits Processor 是一個可能改變 argmax 的 logits processor。例如,一個在一定步數後遮蔽除 EOS 之外所有 token 以強制解碼終止的 logits processor,可能會遮蔽最高 logit 值的 token,從而改變 argmax。概念上,這些 logits processors 不能為貪婪取樣請求而跳過。
vLLM logits processor 抽象要求引擎以批次粒度應用 logits processors;因此,實際上,只有當整個批次都使用貪婪取樣時,才能跳過 argmax 不變的 logits processors。
Logits Processor 程式設計模型¶
前面的章節暗示了 vLLM logits processors 必須支援的介面。本節將全面介紹用於實現與 vLLM 引擎相容的 logits processors 的程式設計模型,包括 LogitsProcessor 基類及其介面方法,以及用於表示持久批次狀態變化的 BatchUpdate 資料結構,這兩者都顯示在下面的程式碼中。
LogitsProcessor 基類和 BatchUpdate 資料結構
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING
import torch
from vllm import SamplingParams
if TYPE_CHECKING:
from vllm.config import VllmConfig
class MoveDirectionality(Enum):
# One-way i1->i2 req move within batch
UNIDIRECTIONAL = auto()
# Two-way i1<->i2 req swap within batch
SWAP = auto()
# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, SamplingParams, list[int], list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
# Batch indices of any removed requests.
RemovedRequest = int
@dataclass(frozen=True)
class BatchUpdate:
"""Persistent batch state change info for logitsprocs"""
batch_size: int # Current num reqs in batch
# Metadata for requests added to, removed from, and moved
# within the persistent batch.
#
# Key assumption: the `output_tok_ids` list (which is an element of each
# tuple in `added`) is a reference to the request's running output tokens
# list; via this reference, the logits processors always see the latest
# list of generated output tokens
removed: Sequence[RemovedRequest]
moved: Sequence[MovedRequest]
added: Sequence[AddedRequest]
class LogitsProcessor(ABC):
@abstractmethod
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool) -> None:
raise NotImplementedError
@abstractmethod
def apply(self, logits: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def is_argmax_invariant(self) -> bool:
"""True if logits processor has no impact on the
argmax computation in greedy sampling.
NOTE: may or may not have the same value for all
instances of a given LogitsProcessor subclass,
depending on subclass implementation.
"""
raise NotImplementedError
@abstractmethod
def update_state(
self,
batch_update: "BatchUpdate" | None,
) -> None:
"""Called when there are new output tokens, prior
to each forward pass.
Args:
batch_update is non-None iff there have been
changes to the batch makeup.
"""
raise NotImplementedError
@classmethod
def validate_params(cls, sampling_params: SamplingParams):
"""Validate sampling params for this logits processor.
Raise ValueError for invalid ones.
"""
return None
vLLM logits processor 必須繼承 LogitsProcessor 並定義(至少)以下方法:
-
__init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)vllm_config:引擎配置資料結構。device:硬體加速器裝置資訊。is_pin_memory:指示是否可用 pinned memory 來支援 logits processor 實現的標誌。
-
apply(self, logits: torch.Tensor) -> torch.Tensor:- 接收一個
(num_requests) x (vocab_size)大小的 logits 張量(logits)。 - 以批次粒度應用 logits processor 變換。
- 返回一個變換後的
(num_requests) x (vocab_size)大小的 logits 張量。 - 您可以就地或非就地修改輸入的 logits processor;就地修改更節省記憶體。
- 接收一個
-
is_argmax_invariant(self) -> bool:- 如果 logits processor 是 argmax 不變的(從不改變給定請求的最高 logit 值的 token ID),則返回
True,如果 logits processor 可能會修改 argmax,則返回False。 is_argmax_invariant()在啟動時評估一次;如果為True,vLLM 將在所有請求使用貪婪取樣時跳過應用此 logits processor。
- 如果 logits processor 是 argmax 不變的(從不改變給定請求的最高 logit 值的 token ID),則返回
-
update_state(self, batch_update: "BatchUpdate" | None) -> None:- 接收一個
BatchUpdate資料結構,表示當前引擎步開始時的持久批次狀態變化。 - 使用
BatchUpdate成員來更新 logits processor 內部狀態。 - 注意: batch update 資料結構可能為
None,表示批次構成沒有變化。在這種情況下,LogitsProcessor 可能仍然希望根據它在新增時可能保留的更新的output_token_ids列表來更新其狀態。
- 接收一個
-
validate_params(cls, sampling_params: SamplingParams):- 如果
SamplingParams包含 logits processor 使用的無效引數(尤其是自定義引數),則引發ValueError。 - 當請求傳送到入口點時,
validate_params()將驗證SamplingParams並拒絕帶有無效引數的請求。
- 如果
BatchUpdate 資料結構¶
BatchUpdate 抽象模型將持久批次表示為請求列表,支援以下操作來改變批次狀態(請注意,下面操作的順序反映了它們在 update_state() 中應處理的順序)。
-
移除:在索引
i處移除請求(不替換)。-
Batchupdate.removed中的移除操作由一個int表示(代表i)。 -
remove-at-index 對批次的影響。
-
-
新增:在索引
i處新增(或替換現有請求為)一個新請求。如果替換了請求,其關聯的狀態應被丟棄。-
Batchupdate.added中的新增操作表示為包含以下元素的元組: -
prompt token ids和output token ids分別是對請求的 prompt token ids 和 output token ids 列表的引用。請注意,output token ids 列表會隨著每個引擎步的進行而增長,並且 Logits Processor 可以看到這種增長,因為 output token ids 是按引用傳遞的。這對於那些考慮了迄今為止生成 token 的 LogitsProcessors 很重要。 -
特定 logits processor 子類的實現決定了如何或是否將新增的請求元組中的欄位解析為其內部表示。例如,一個不使用 prompt 或 output token ids 的 logits processor 可能只需要使用
index和SamplingParams,而丟棄其他元組欄位。 -
如果索引
i當前包含一個請求,則會發生替換。 -
如果索引
i當前不包含請求(因為i超出了當前批次大小的範圍)。
-
-
移動:將索引
s的請求移動到索引d,或者交換索引s和d的請求。-
Batchupdate.moved中的移動操作表示為包含以下元素的元組: -
如果移動指定了
UNIDRECTIONAL。-
索引
s的請求被移動到索引d;索引s變成一個空槽。 -
如果索引
d已經存在一個請求,它將被替換並丟棄。
-
-
如果移動指定了
SWAP,則索引s和d的請求交換索引。
-
此外,BatchUpdate 資料結構還包括引擎步開始時持久批次大小的表示(batch_size)。
vLLM 引擎如何構建 BatchUpdate 資料結構¶
Logits processor update_state() 的實現應假定模型執行器更新持久批次狀態的模型如下(此處以 BatchUpdate 抽象來表示):
-
識別當前引擎步中完成的請求的索引。
-
識別當前步引入的新請求。
-
使用 Add 操作,按被替換請求的升序索引(從最小索引開始)將盡可能多的已完成請求替換為新請求。
-
基於新請求和已完成請求的數量。
-
如果新請求和已完成請求的數量相同,則繼續下一步。
-
如果新請求多於已完成請求:應用 Add 操作,用剩餘未替換已完成請求的新請求擴充套件批次。為這些新請求分配連續索引,從
current_max_batch_index + 1開始。 -
如果新請求少於已完成請求。
-
對未被新請求替換的已完成請求應用 Remove 操作。這些移除請求的索引必然大於上一步被替換的已完成請求的最大索引。移除操作可能會使批次處於非連續狀態。
-
“壓縮”批次使其連續:從最低索引的空槽(由 Remove 操作引起)開始,應用一個單向移動(Unidirectional Move),從當前批次中最高非空槽填充空槽。按空槽目標索引的遞增順序和非空槽源索引的遞減順序進行其他單向移動操作,直到批次連續。
-
縮小批次:壓縮批次的一個副作用是,由 Remove 操作產生的空槽會聚集在批次陣列的末尾形成一個連續塊。因此,壓縮後,更新
BatchUpdate.batch_size以反映非空槽的數量。
-
-
-
重新排序批次以提高效率。根據注意力後端實現和當前批次的特性,可能會應用零個或多個 Swap Move 操作來重新排序批次。
注意事項
-
Logits processor
update_state()方法必須按以下順序處理批次更新操作:移除、新增、移動。 -
Add 操作的索引引數指的是 Add 操作發生時的索引,即在任何 Move 操作之前。
- 示例:如果一個請求在索引 5 處被新增,然後與索引 3 交換,那麼
BatchUpdate.added中的 Add 操作將與索引 5 相關聯,而不是 3。 - 換句話說,可以假定 Move 操作是在 Add 和 Remove 操作之後應用的。
- 示例:如果一個請求在索引 5 處被新增,然後與索引 3 交換,那麼
-
可以假定 Move 操作是按照它們在
BatchUpdate.moved中出現的順序應用的。 -
如果沒有新/已完成請求,也沒有批次重新排序,那麼 logits processors 的批次更新將是
None。
示例:新請求少於完成請求的批次更新¶
以下示例模擬了一個引擎步,其中引入了 1 個新請求,並移除了 2 個已完成請求,此外,注意力後端執行了交換以最佳化批次排序。
Batch state (beginning of engine step): [A,B,C,D]
Batch size: 4
New requests: E
Finished requests: A, C
Processing steps (using BatchUpdate abstraction):
1. Add E at index 0
[E,B,C,D] # Discard A
Batch size: 4
2. Remove at index 2
[E,B,x,D] # Discard C, empty slot at index 2
Batch size: 4
3. Condense batch with a Unidirectional Move 3 -> 2 operation and shrink batch
[E,B,D] x # Empty slot is now outside batch
Batch size: 3
4. Attention backend optimization: reorder batch with Swap 0 <-> 1
[B,E,D]
Batch size: 3
生成的 BatchUpdate 資料結構將如下所示:
BatchUpdate instance
* added: [(0,E's SamplingParams,E's prompt tokens ref,E's output tokens ref)]
* removed: [2] # request C was removed without replacement
* moved: [(3,2,UNIDIRECTIONAL),(0,1,SWAP)]
示例:新請求多於完成請求的批次更新¶
以下示例模擬了一個引擎步,其中引入了 2 個新請求,並移除了 1 個已完成請求,此外,注意力後端執行了交換以最佳化批次排序。
Batch state (beginning of engine step): [A,B,C,D]
Batch size: 4
New requests: E,F
Finished requests: C
Processing steps (using BatchUpdate abstraction):
1. Add E at index 2
[A,B,E,D] # Discard C
Batch size: 4
2. Add F at index 4 (current max batch index + 1)
[A,B,E,D,F] # Extend batch by 1
Batch size: 5
4. Attention backend optimization: reorder batch with Swap 0 <-> 1
[B,A,E,D,F]
Batch size: 5
請注意,由於 Remove 操作沒有留下空槽,因此跳過了批次壓縮。
生成的 BatchUpdate 資料結構將如下所示:
BatchUpdate instance
* added: [(2,E's SamplingParams,E's prompt tokens ref,E's output tokens ref),(4,F's SamplingParams,F's prompt tokens ref,F's output tokens ref)]
* removed: [] # no requests were removed without replacement
* moved: [(0,1,SWAP)]
如何向 vLLM 引入新的 Logits Processor¶
編寫內建 Logits Processors 的最佳實踐¶
-
考慮到 logits processors 以批次粒度執行,請編寫高效的
apply()和update_state()實現。- 例如,您可能可以使用高效的向量化操作來實現
apply()或在update_state()中更新內部狀態向量。 - 但是,如果您認為某個 logits processor 可能不常使用,則可以使用“稀疏”的請求狀態表示,即該類可以使用字典來表示請求配置,該字典僅儲存啟用 logits processor 的請求的元資料。
- 例如,您可能可以使用高效的向量化操作來實現
-
這取決於 logits processor 的作者來決定:
-
配置 logits processor 對該請求行為的每個請求屬性。 例如,如果您正在為 vLLM 編寫一個新的內建 logits processor,您可能需要向
SamplingParams和 vLLM REST API 新增額外的欄位,也可能不需要。 -
Logits processor 在每個請求的基礎上啟用或不啟用的條件。 除非您的目的是讓內建 logits processor 始終對所有請求起作用,否則您應該編寫您的 logits processor,使其能夠為特定請求停用 logits processor,例如透過將引數預設設定為
None或傳遞特定的無操作引數值,即0.0。請嘗試為停用 logits processor 的請求節省計算和記憶體。 -
Logits processor 在批次級別短路的條件。 即使您已定義了在請求級別停用內建 logits processor 的方法,也很難將其轉化為計算節省,例如,如果您的
update_state()和apply()實現使用了在整個持久批次上執行一次的向量化實現。例如,即使一個請求停用了 logits processor,您也不能僅憑此跳過apply()中的整個向量化操作。為了在沒有執行請求使用內建 logits processor 的邊緣情況下節省計算,我們建議將apply()設計為在所有請求都停用 logits processor 時返回未修改的輸入張量。同樣,考慮在沒有請求啟用 logits processor 的情況下是否可以跳過update_state()中的步驟。- 此外,在
update_state()中節省計算的一種簡單方法是在batch_update為None時提前退出。
- 此外,在
-
-
確保 logits processor
update_state方法丟棄已完成請求(即被 Add 操作替換或受 Remove 操作影響的請求)的資訊。 -
如果 logits processor 具有一致的行為,
is_argmax_invariant()可以硬編碼為True或False。但是,argmax 的不變性也可以透過程式設計方式確定(例如,如果您的 logits processor 是使用者可自定義的,並且以某種方式影響了 logits processor 是否是 argmax 不變的)。因此,is_argmax_invariant()不是一個類方法。
內建 Logits Processors¶
內建 logits processors 在 vLLM 引擎啟動時始終會載入。請參閱 vllm/v1/sample/logits_processor/builtin.py 中的現有 vLLM 內建 logits processors,以獲取關於如何編寫新的內建 vLLM logits processor 的示例。如果某個 logits processor 可能對廣大使用者有用,那麼將其作為內建處理器引入是合理的。vLLM 目前根據上述程式設計模型使用以下內建 logits processors:
-
Min-P
-
Logit 偏差
-
Min-tokens
請參考這些 logits processor 的實現,以獲得編寫內建 logits processors 的指導。
此外,以下類似 logits processor 的功能已被硬編碼到 sampler 中,尚未利用上述程式設計模型。其中大部分將被重構為使用上述 logits processor 程式設計模型。
-
允許的 token IDs
-
不良詞彙
-
重複懲罰
-
頻率懲罰
-
存在懲罰
-
溫度
-
Top-K
-
Top-P
自定義 Logits Processors¶
vLLM 可以透過使用者提供的自定義 logits processors 進行擴充套件。