Skip to main content

KV Cache

KV Cache は、LLM の autoregressive decoding で過去 token の Key / Value を保存し、次 token 生成時に再利用する推論最適化です。過去 token の Key / Value は同じ入力から何度計算しても同じなので、毎 step 再計算せず、GPU memory に置いておきます。

KV cache flow

自作概念図。Prefill で prompt 全体の Key / Value を cache し、decode では新しい token の Query / Key / Value だけを計算して、過去の Key / Value は cache から読みます。

なぜ必要か

Decoder-only LLM は、次 token を一つずつ生成します。Token t+1t+1 を出すとき、最新 token の Query は一つだけでよいですが、attention では過去すべての Key / Value が必要です。

Attention(qt+1,K1:t+1,V1:t+1)=softmax(qt+1K1:t+1dk)V1:t+1\mathrm{Attention}(q_{t+1}, K_{1:t+1}, V_{1:t+1}) = \mathrm{softmax}\left(\frac{q_{t+1}K_{1:t+1}^\top}{\sqrt{d_k}}\right)V_{1:t+1}

過去 token の K1:tK_{1:t}V1:tV_{1:t} はすでに計算済みです。KV cache は、それらを保存しておくことで、各 step で過去 token の projection をやり直す無駄をなくします。

Prefill と Decode

LLM inference は大きく二段階に分かれます。

Phase内容ボトルネック
Prefillprompt 全体を一度に処理し、KV cache を作るcompute-bound になりやすい
Decode生成 token を一つずつ追加し、cache を読みながら生成するmemory bandwidth / latency-bound になりやすい

最初の token が出るまでの時間は Time To First Token (TTFT) と呼ばれます。Prompt が長いほど prefill が重くなり、TTFT が伸びます。一方、cache ができた後は、各 step で新しい token だけを処理するため、token が stream される速度は速くなります。

何を cache するのか

各 transformer layer で、各 token の Key と Value を保存します。

Query は最新 token のものだけで足ります。Key と Value は過去 token のものをすべて使います。

Memory cost

KV cache の memory は、概念的には次の量に比例します。

KV cache2×layers×seq length×nkv heads×dhead×batch\mathrm{KV\ cache} \propto 2 \times \mathrm{layers} \times \mathrm{seq\ length} \times n_{kv\ heads} \times d_{head} \times \mathrm{batch}

先頭の 2 は Key と Value の二つを保存するためです。Context length を 2 倍にすると、cache memory もほぼ 2 倍になります。Concurrent request が増えると、model weight より KV cache が memory bottleneck になる場合があります。

GQA と MQA

KV cache の memory を減らす代表的な設計が Grouped-Query Attention (GQA) と Multi-Query Attention (MQA) です。

手法考え方効果
MHAQuery head と同数の Key / Value head を持つ品質は高いが cache が重い
GQA複数 Query head が同じ Key / Value head group を共有するcache を減らしつつ品質を保ちやすい
MQAすべての Query head が単一の Key / Value head を共有するcache が非常に軽いが品質 trade-off がある

PagedAttention との関係

KV cache は長さが request ごとに違い、生成中に伸び続けます。単純に連続 memory を確保すると fragmentation が起きます。PagedAttention は KV cache を固定サイズ block に分け、OS の virtual memory のように管理します。これにより、serving 時の memory utilization と batching 効率が改善します。

何が速くなり、何が残るのか

KV cache は過去 token の Key / Value projection の再計算をなくします。ただし、最新 Query が過去すべての Key に attention するため、attention score の計算自体は sequence length に比例して残ります。

つまり、KV cache は compute を memory に交換する 技術です。高速になりますが、長 context と高並列 serving では GPU memory が制約になります。

数式で見る KV cache のメモリ量

KV cache は、各 layer、各 token、各 KV head について key と value を保存します。Batch size を BB、sequence length を TT、layer 数を LL、KV head 数を HkvH_{kv}、head dimension を dhd_h、1 要素あたりの byte 数を bb とすると、必要メモリは概念的に次のように書けます。

MemoryKV=2BTLHkvdhb\mathrm{Memory}_{KV}=2\cdot B\cdot T\cdot L\cdot H_{kv}\cdot d_h\cdot b

各項の意味は次の通りです。

  • 先頭の 22 は key と value の二つを保存するためです。
  • TT が長くなると、KV cache は線形に増えます。
  • HkvH_{kv} は KV head 数です。GQA / MQA はこの値を小さくして memory を削減します。
  • bb は fp16 / bf16 なら 2 bytes、fp8 なら 1 byte です。

この式の気持ちは、「decode では計算よりも過去 token の K/V を読み出す memory bandwidth が効きやすい」ということです。Long context serving で PagedAttention や KV quantization が重要になるのは、この線形に伸びる cache が GPU memory を圧迫するためです。

数式で見る GQA / MQA の削減効果

通常の multi-head attention では Hkv=HqH_{kv}=H_q です。一方、GQA では複数の query head が同じ key / value head を共有します。

Hkv=HqgH_{kv}=\frac{H_q}{g}

ここで、gg は query heads per KV head です。KV cache の memory は HkvH_{kv} に比例するため、理想的には約 gg 倍削減できます。

MemoryGQAMemoryMHA=HkvHq=1g\frac{\mathrm{Memory}_{GQA}}{\mathrm{Memory}_{MHA}}=\frac{H_{kv}}{H_q}=\frac{1}{g}

この式の気持ちは、「query は細かく分けて表現力を保ちつつ、key / value は共有して cache を軽くする」というものです。

関連ページ

主なソース