KV Cache
KV Cache は、LLM の autoregressive decoding で過去 token の Key / Value を保存し、次 token 生成時に再利用する推論最適化です。過去 token の Key / Value は同じ入力から何度計算しても同じなので、毎 step 再計算せず、GPU memory に置いておきます。
自作概念図。Prefill で prompt 全体の Key / Value を cache し、decode では新しい token の Query / Key / Value だけを計算して、過去の Key / Value は cache から読みます。
なぜ必要か
Decoder-only LLM は、次 token を一つずつ生成します。Token を出すとき、最新 token の Query は一つだけでよいですが、attention では過去すべての Key / Value が必要です。
過去 token の と はすでに計算済みです。KV cache は、それらを保存しておくことで、各 step で過去 token の projection をやり直す無駄をなくします。
Prefill と Decode
LLM inference は大きく二段階に分かれます。
| Phase | 内容 | ボトルネック |
|---|---|---|
| Prefill | prompt 全体を一度に処理し、KV cache を作る | compute-bound になりやすい |
| Decode | 生成 token を一つずつ追加し、cache を読みながら生成する | memory bandwidth / latency-bound になりやすい |
最初の token が出るまでの時間は Time To First Token (TTFT) と呼ばれます。Prompt が長いほど prefill が重くなり、TTFT が伸びます。一方、cache ができた後は、各 step で新しい token だけを処理するため、token が stream される速度は速くなります。
何を cache するのか
各 transformer layer で、各 token の Key と Value を保存します。
Query は最新 token のものだけで足ります。Key と Value は過去 token のものをすべて使います。
Memory cost
KV cache の memory は、概念的には次の量に比例します。
先頭の 2 は Key と Value の二つを保存するためです。Context length を 2 倍にすると、cache memory もほぼ 2 倍になります。Concurrent request が増えると、model weight より KV cache が memory bottleneck になる場合があります。
GQA と MQA
KV cache の memory を減らす代表的な設計が Grouped-Query Attention (GQA) と Multi-Query Attention (MQA) です。
| 手法 | 考え方 | 効果 |
|---|---|---|
| MHA | Query head と同数の Key / Value head を持つ | 品質は高いが cache が重い |
| GQA | 複数 Query head が同じ Key / Value head group を共有する | cache を減らしつつ品質を保ちやすい |
| MQA | すべての Query head が単一の Key / Value head を共有する | cache が非常に軽いが品質 trade-off がある |
PagedAttention との関係
KV cache は長さが request ごとに違い、生成中に伸び続けます。単純に連続 memory を確保すると fragmentation が起きます。PagedAttention は KV cache を固定サイズ block に分け、OS の virtual memory のように管理します。これにより、serving 時の memory utilization と batching 効率が改善します。
何が速くなり、何が残るのか
KV cache は過去 token の Key / Value projection の再計算をなくします。ただし、最新 Query が過去すべての Key に attention するため、attention score の計算自体は sequence length に比例して残ります。
つまり、KV cache は compute を memory に交換する 技術です。高速になりますが、長 context と高並列 serving では GPU memory が制約になります。
数式で見る KV cache のメモリ量
KV cache は、各 layer、各 token、各 KV head について key と value を保存します。Batch size を 、sequence length を 、layer 数を 、KV head 数を 、head dimension を 、1 要素あたりの byte 数を とすると、必要メモリは概念的に次のように書けます。
各項の意味は次の通りです。
- 先頭の は key と value の二つを保存するためです。
- が長くなると、KV cache は線形に増えます。
- は KV head 数です。GQA / MQA はこの値を小さくして memory を削減します。
- は fp16 / bf16 なら 2 bytes、fp8 なら 1 byte です。
この式の気持ちは、「decode では計算よりも過去 token の K/V を読み出す memory bandwidth が効きやすい」ということです。Long context serving で PagedAttention や KV quantization が重要になるのは、この線形に伸びる cache が GPU memory を圧迫するためです。
数式で見る GQA / MQA の削減効果
通常の multi-head attention では です。一方、GQA では複数の query head が同じ key / value head を共有します。
ここで、 は query heads per KV head です。KV cache の memory は に比例するため、理想的には約 倍削減できます。
この式の気持ちは、「query は細かく分けて表現力を保ちつつ、key / value は共有して cache を軽くする」というものです。
関連ページ
- Self-Attention and QKV
- Transformer Architecture
- LLM Inference Optimization
- Long Context and Position Encoding
- Speculative Decoding
主なソース
- Avi Chawla, KV Caching in LLMs, Clearly Explained: https://r.jina.ai/http://r.jina.ai/http://https://x.com/_avichawla/status/2034902650534187503
- Attention Is All You Need: https://arxiv.org/abs/1706.03762
- PagedAttention / vLLM: https://arxiv.org/abs/2309.06180
- FlashAttention: https://arxiv.org/abs/2205.14135