vLLM Paged Attention¶
目前,vLLM 使用其自己實現的多頭查詢注意力核心(`csrc/attention/attention_kernels.cu`)。此核心設計用於相容 vLLM 的分頁 KV 快取,其中鍵(key)和值(value)快取儲存在單獨的塊中(請注意,此“塊”概念不同於 GPU 執行緒塊。因此,在後續文件中,我將 vLLM 分頁注意力塊稱為“塊”,而將 GPU 執行緒塊稱為“執行緒塊”)。
為了實現高效能,此核心依賴於專門設計的記憶體佈局和訪問方法,尤其是線上程將資料從全域性記憶體讀取到共享記憶體時。本文件旨在逐步提供對核心實現的高階解釋,以幫助那些希望瞭解 vLLM 多頭查詢注意力核心的人。閱讀本文件後,使用者可能會更好地理解並更容易地跟隨實際的實現。
請注意,本文件可能不會涵蓋所有細節,例如如何計算相應資料的正確索引或點乘實現。但是,在閱讀本文件並熟悉高階邏輯流程後,您應該更容易閱讀實際程式碼並理解其細節。
輸入¶
核心函式接受當前執行緒執行其分配工作的引數列表。其中最重要的三個引數是輸入指標 `q`、`k_cache` 和 `v_cache`,它們指向需要從全域性記憶體讀取和處理的查詢(query)、鍵(key)和值(value)資料。輸出指標 `out` 指向應寫入結果的全域性記憶體。這四個指標實際上引用多維陣列,但每個執行緒只訪問分配給它的那部分資料。為了簡化,我在此省略了所有其他執行時引數。
template<typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, int PARTITION_SIZE = 0>
__device__ void paged_attention_kernel(
... // Other side args.
const scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
... // Other side args.
)
函式簽名上方還有一系列模板引數,它們在編譯時確定。`scalar_t` 表示查詢、鍵和值資料元素的資料型別,例如 FP16。`HEAD_SIZE` 表示每個頭中的元素數量。`BLOCK_SIZE` 指每個塊中的 token 數量。`NUM_THREADS` 表示每個執行緒塊中的執行緒數量。`PARTITION_SIZE` 表示張量並行 GPU 的數量(為簡化起見,我們假設此值為 0 且張量並行已停用)。
有了這些引數,我們需要執行一系列準備工作。這包括計算當前頭索引、塊索引和其他必要的變數。但是,目前我們可以忽略這些準備工作,直接進行實際計算。一旦我們掌握了整個流程,再理解它們會更容易。
概念¶
在我們深入計算流程之前,我想先描述一些後續部分需要的概念。但是,如果您遇到任何令人困惑的術語,可以跳過此部分並稍後返回。
- 序列(Sequence):序列代表一個客戶端請求。例如,`q` 指向的資料形狀為 `[num_seqs, num_heads, head_size]`。這表示 `q` 共指向 `num_seqs` 個查詢序列資料。由於此核心是單查詢注意力核心,每個序列只有一個查詢 token。因此,`num_seqs` 等於批處理中處理的 token 總數。
- 上下文(Context):上下文由序列中已生成的 token 組成。例如,`["What", "is", "your"]` 是上下文 token,而輸入的查詢 token 是 `"name"`。模型可能會生成 token `"?"`。
- 向量(Vec):向量是一組同時獲取和計算的元素列表。對於查詢和鍵資料,向量大小(`VEC_SIZE`)的確定是為了使每個執行緒組可以一次獲取和計算 16 位元組的資料。對於值資料,向量大小(`V_VEC_SIZE`)的確定是為了使每個執行緒可以一次獲取和計算 16 位元組的資料。例如,如果 `scalar_t` 是 FP16(2 位元組),`THREAD_GROUP_SIZE` 為 2,則 `VEC_SIZE` 將為 4,而 `V_VEC_SIZE` 將為 8。
- 執行緒組(Thread group):執行緒組是一小群執行緒(`THREAD_GROUP_SIZE`),它們一次獲取並計算一個查詢 token 和一個鍵 token。每個執行緒只處理 token 資料的一部分。一個執行緒組處理的元素總數稱為 `x`。例如,如果執行緒組包含 2 個執行緒,頭大小為 8,那麼執行緒 0 處理索引為 0, 2, 4, 6 的查詢和鍵元素,而執行緒 1 處理索引為 1, 3, 5, 7 的元素。
- 塊(Block):vLLM 中的鍵和值快取資料被分成塊。每個塊儲存一個頭中固定數量(`BLOCK_SIZE`)token 的資料。每個塊可能只包含整個上下文 token 的一部分。例如,如果塊大小為 16 且頭大小為 128,那麼對於一個頭,一個塊可以儲存 16 * 128 = 2048 個元素。
- Warp:Warp 是 32 個執行緒(`WARP_SIZE`)的組,它們在流多處理器(SM)上同時執行。在此核心中,每個 warp 一次處理一個查詢 token 與一個完整塊的鍵 token 之間的計算(它可能在多次迭代中處理多個塊)。例如,如果一個上下文有 4 個 warp 和 6 個塊,則分配方式將是:warp 0 處理第 0、第 4 個塊,warp 1 處理第 1、第 5 個塊,warp 2 處理第 2 個塊,warp 3 處理第 3 個塊。
- 執行緒塊(Thread block):執行緒塊是一組執行緒(`NUM_THREADS`),它們可以訪問相同的共享記憶體。每個執行緒塊包含多個 warp(`NUM_WARPS`),在此核心中,每個執行緒塊處理一個查詢 token 與整個上下文的鍵 token 之間的計算。
- 網格(Grid):網格是執行緒塊的集合,並定義了該集合的形狀。在此核心中,形狀為 `(num_heads, num_seqs, max_num_partitions)`。因此,每個執行緒塊只處理一個頭、一個序列和一個分割槽的計算。
查詢(Query)¶
本節將介紹查詢資料在記憶體中如何儲存以及如何由每個執行緒獲取。如上所述,每個執行緒組獲取一個查詢 token 資料,而每個執行緒本身只處理一個查詢 token 資料的一部分。在每個 warp 內,每個執行緒組將獲取相同的查詢 token 資料,但會將其與不同的鍵 token 資料相乘。

每個執行緒定義自己的 `q_ptr`,它指向全域性記憶體上分配的查詢 token 資料。例如,如果 `VEC_SIZE` 為 4 且 `HEAD_SIZE` 為 128,則 `q_ptr` 指向總共包含 128 個元素(分為 128 / 4 = 32 個向量)的資料。

接下來,我們需要將 `q_ptr` 指向的全域性記憶體資料讀取到共享記憶體中,作為 `q_vecs`。需要注意的是,每個向量都被分配到不同的行。例如,如果 `THREAD_GROUP_SIZE` 為 2,則執行緒 0 將處理第 0 行的向量,而執行緒 1 將處理第 1 行的向量。透過這種方式讀取查詢資料,執行緒 0 和執行緒 1 等相鄰執行緒可以讀取相鄰記憶體,從而實現記憶體合併以提高效能。
鍵(Key)¶
與“查詢”部分類似,本節介紹鍵的記憶體佈局和分配。雖然每個執行緒組在一次核心執行中只處理一個查詢 token,但它可能在多次迭代中處理多個鍵 token。同時,每個 warp 將在多次迭代中處理多個鍵 token 塊,確保在核心執行後,所有上下文 token 都由整個執行緒組處理。在此上下文中,“處理”指的是執行查詢資料和鍵資料之間的點乘。
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride
+ physical_block_offset * x;
與 `q_ptr` 不同,每個執行緒中的 `k_ptr` 將在不同的迭代中指向不同的鍵 token。如上所示,`k_ptr` 根據 `k_cache` 在分配的塊、分配的頭和分配的 token 處指向鍵 token 資料。

上圖說明了鍵資料的記憶體佈局。它假設 `BLOCK_SIZE` 為 16,`HEAD_SIZE` 為 128,`x` 為 8,`THREAD_GROUP_SIZE` 為 2,並且總共有 4 個 warp。每個矩形表示一個頭中一個鍵 token 的所有元素,這些元素將由一個執行緒組處理。左半部分顯示了 warp 0 的總共 16 個鍵 token 資料塊,而右半部分表示其他 warp 或迭代的剩餘鍵 token 資料。每個矩形內部總共有 32 個向量(一個 token 包含 128 個元素),將由 2 個執行緒(一個執行緒組)單獨處理。

接下來,我們需要從 `k_ptr` 讀取鍵 token 資料並將其儲存在暫存器記憶體中作為 `k_vecs`。我們使用暫存器記憶體儲存 `k_vecs`,因為它只會被一個執行緒訪問一次,而 `q_vecs` 會被多個執行緒訪問多次。每個 `k_vecs` 將包含多個向量,用於後續計算。每個向量將在每次內部迭代中設定。向量的分配允許 warp 中的相鄰執行緒一起讀取相鄰記憶體,這再次促進了記憶體合併。例如,執行緒 0 將讀取向量 0,而執行緒 1 將讀取向量 1。在下一個內迴圈中,執行緒 0 將讀取向量 2,而執行緒 1 將讀取向量 3,依此類推。
您可能仍然對整個流程感到有些困惑。不用擔心,請繼續閱讀下一節“QK”。它將以更清晰、更高階的方式說明查詢和鍵的計算流程。
QK¶
如下圖虛擬碼所示,在整個 for 迴圈塊之前,我們獲取一個 token 的查詢資料並將其儲存在 `q_vecs` 中。然後,在外層 for 迴圈中,我們迭代指向不同 token 的不同 `k_ptrs`,並在內層 for 迴圈中準備 `k_vecs`。最後,我們對 `q_vecs` 和每個 `k_vecs` 執行點乘。
q_vecs = ...
for ... {
k_ptr = ...
for ... {
k_vecs[i] = ...
}
...
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
}
如前所述,對於每個執行緒,它一次只獲取部分查詢和鍵 token 資料。然而,在 `Qk_dot<>::dot` 中會發生跨執行緒組的規約(reduction)。因此,這裡返回的 `qk` 不僅僅是部分查詢和鍵 token 的點乘結果,而實際上是整個查詢和鍵 token 資料之間的完整結果。
例如,如果 `HEAD_SIZE` 的值為 128 且 `THREAD_GROUP_SIZE` 為 2,則每個執行緒的 `k_vecs` 將總共包含 64 個元素。然而,返回的 `qk` 實際上是 128 個查詢元素和 128 個鍵元素之間點乘的結果。如果您想了解更多關於點乘和規約的細節,可以參考 `Qk_dot<>::dot` 的實現。但是,為簡化起見,本文件中將不予介紹。
Softmax¶
接下來,我們需要計算所有 `qk` 的歸一化 Softmax,如上所示,其中每個代表一個 `qk`。為此,我們必須獲得 `qk_max` 的規約值()和 `exp_sum`()的所有 `qk` 值。規約應在整個執行緒塊中執行,涵蓋查詢 token 和所有上下文鍵 token 之間的結果。
`qk_max` 和 `logits`¶
在我們獲得 `qk` 結果後,我們可以立即用 `qk` 設定臨時的 `logits` 結果(最終,`logits` 應該儲存歸一化 Softmax 結果)。同時,我們還可以比較並收集當前執行緒組計算出的所有 `qk` 中的 `qk_max`。
if (thread_group_offset == 0) {
const bool mask = token_idx >= context_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
請注意,這裡的 `logits` 位於共享記憶體中,因此每個執行緒組將為其自己分配的上下文 token 設定欄位。總的來說,logits 的大小應為上下文 token 的數量。
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
然後我們需要獲取每個 warp 的規約 `qk_max`。主要思想是讓 warp 中的執行緒相互通訊並獲得最終的 `qk` 最大值。
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
最後,透過比較此執行緒塊中所有 warp 的 `qk_max`,我們可以從整個執行緒塊中獲得規約後的 `qk_max`。然後我們需要將最終結果廣播給每個執行緒。
`exp_sum`¶
與 `qk_max` 類似,我們還需要從整個執行緒塊中獲取規約後的求和值。
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
...
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
首先,對每個執行緒組的所有 exp 值求和,同時將 `logits` 的每個條目從 `qk` 轉換為 `exp(qk - qk_max)`。請注意,這裡的 `qk_max` 已經是整個執行緒塊中的最大 `qk`。然後,我們可以像處理 `qk_max` 一樣,對整個執行緒塊進行 `exp_sum` 的規約。
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
最後,透過規約後的 `qk_max` 和 `exp_sum`,我們可以得到最終的歸一化 Softmax 結果作為 `logits`。這個 `logits` 變數將在後續步驟中用於與值資料進行點乘。現在,它應該儲存所有分配的上下文 token 的 `qk` 的歸一化 Softmax 結果。
值(Value)¶



現在我們需要檢索值資料並與 `logits` 進行點乘。與查詢和鍵不同,值資料沒有執行緒組的概念。如圖所示,與鍵 token 的記憶體佈局不同,來自同一列的元素對應於相同的值 token。對於一個值資料塊,有 `HEAD_SIZE` 行和 `BLOCK_SIZE` 列,它們被分成多個 `v_vecs`。
每個執行緒總是從相同數量的 `V_VEC_SIZE` 個 token 中一次獲取 `V_VEC_SIZE` 個元素。因此,單個執行緒透過多次內部迭代從不同行和相同列中檢索多個 `v_vec`。對於每個 `v_vec`,它需要與相應的 `logits_vec` 進行點乘,後者也是 `logits` 中的 `V_VEC_SIZE` 個元素。總的來說,透過多次內部迭代,每個 warp 將處理一個值 token 塊。透過多次外部迭代,將處理整個上下文的值 token。
float accs[NUM_ROWS_PER_THREAD];
for ... { // Iteration over different blocks.
logits_vec = ...
for ... { // Iteration over different rows.
v_vec = ...
...
accs[i] += dot(logits_vec, v_vec);
}
}
如以上虛擬碼所示,在外層迴圈中,與 `k_ptr` 類似,`logits_vec` 遍歷不同的塊並從 `logits` 中讀取 `V_VEC_SIZE` 個元素。在內層迴圈中,每個執行緒從相同的 token 中讀取 `V_VEC_SIZE` 個元素作為 `v_vec` 並執行點乘。重要的是要注意,在每次內部迭代中,執行緒為相同的 token 獲取不同的頭位置元素。點乘結果隨後累積在 `accs` 中。因此,`accs` 的每個條目都對映到分配給當前執行緒的頭位置。
例如,如果 `BLOCK_SIZE` 為 16 且 `V_VEC_SIZE` 為 8,則每個執行緒一次為 8 個 token 獲取 8 個值元素。每個元素都來自不同 token 的相同頭位置。如果 `HEAD_SIZE` 為 128 且 `WARP_SIZE` 為 32,則在每個內迴圈中,一個 warp 需要獲取 `WARP_SIZE * V_VEC_SIZE = 256` 個元素。這意味著一個 warp 總共需要進行 128 * 16 / 256 = 8 次內部迭代來處理一個完整的值 token 塊。每個執行緒中的 `accs` 包含 8 個元素,這些元素累積在 8 個不同的頭位置。對於執行緒 0,`accs` 變數將包含 8 個元素,它們是值頭中的第 0、第 32 … 第 224 個元素,這些元素從所有分配的 8 個 token 中累積而來。
LV¶
現在,我們需要在每個 warp 內對 `accs` 執行規約。此過程允許每個執行緒累積一個塊中所有 token 的分配頭位置的 `accs`。
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[i] = acc;
}
接下來,我們對所有 warp 的 `accs` 執行規約,從而使每個執行緒都能累積所有上下文 token 的分配頭位置的 `accs`。請注意,每個執行緒中的每個 `accs` 只儲存整個頭在所有上下文 token 中的一部分元素的累積。但是,總的來說,所有輸出結果都已計算出來,只是儲存在不同的執行緒暫存器記憶體中。
程式碼
float* out_smem = reinterpret_cast<float*>(shared_mem);
for (int i = NUM_WARPS; i > 1; i /= 2) {
// Upper warps write to shared memory.
...
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
...
dst[row_idx] = accs[i];
}
// Lower warps update the output.
const float* src = &out_smem[warp_idx * HEAD_SIZE];
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
...
accs[i] += src[row_idx];
}
// Write out the accs.
}
輸出¶
現在我們可以將所有計算結果從區域性暫存器記憶體寫入最終的輸出全域性記憶體。
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+ head_idx * max_num_partitions * HEAD_SIZE
+ partition_idx * HEAD_SIZE;
首先,我們需要定義 `out_ptr` 變數,它指向分配的序列和分配的頭的起始地址。
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]);
}
}
最後,我們需要遍歷不同的分配頭位置,並根據 `out_ptr` 寫入相應的累積結果。