Skip to content

KV Cache 与批处理 — 练习

练习 1:计算 KV Cache 的内存占用

对于一个 LLaMA-7B 模型(32 层, 32 heads, 4096 hidden, F16)和上下文长度 4096,计算 KV Cache 需要多少内存。

参考答案

每层 KV Cache 大小:

  • K: 4096 positions × (4096 hidden / 32 heads × 32 heads) × 2 bytes (F16) = 4096 × 4096 × 2 = 33.5 MB
  • V: 同样 33.5 MB
  • 每层合计: 67 MB
  • 32 层合计: 32 × 67 = 2,144 MB ≈ 2.1 GB

使用 GQA (如 8 KV heads):33.5 × 2 × (8/32) × 32 = 536 MB

练习 2:理解 Batch 的序列管理

阅读 llama-batch.cpp,理解如何构建一个包含多个序列的 batch(parallel decoding 场景)。

参考答案

多序列 batch 的构建:

c
// 序列 0: 生成 token A
// 序列 1: 生成 token B
struct llama_batch batch = llama_batch_init(2);
batch.token[0] = token_A;
batch.pos[0] = seq0_pos;
batch.seq_id[0] = {0};  // 属于序列 0
batch.token[1] = token_B;
batch.pos[1] = seq1_pos;
batch.seq_id[1] = {1};  // 属于序列 1
batch.n_tokens = 2;

两个 token 会在一次 llama_decode() 中并行计算,各自更新对应序列的 KV Cache。

练习 3:Cache 前缀复用

分析 llama.cpp 如何在连续对话中复用 KV Cache。当一个已有 100 个 token cache 的对话继续生成时,cache 如何更新?

参考答案

连续对话的 cache 复用:

  1. 第一轮对话后,KV Cache 包含 100 个位置的数据
  2. 第二轮输入新的 prompt + 生成的回复(假设 50 个 token)
  3. 首先检查前缀匹配:前 100 个 token 可能已经 cache
  4. 新 token 从 position 100 开始追加
  5. 只需要对新 token 执行 decode,已有的 K/V 从 cache 读取
  6. 这就是 "prefix caching" 优化

相关 API:

  • llama_kv_cache_seq_keep() — 保留指定序列的 cache
  • llama_kv_cache_seq_shift() — 移动 cache 位置

拓展挑战

  • 实现一个简单的多轮对话程序,验证 KV Cache 复用
  • 对比不同上下文长度下的内存占用
  • 阅读 Flash Attention 在 CUDA 后端中的实现