Skip to content

KV Cache 与批处理 — 代码走读

src/llama-memory.cpp — KV Cache 管理

核心数据结构

cpp
struct llama_kv_cache {
    // Cache 存储
    struct ggml_tensor * k_l;  // [n_kv_max, n_embd_k_gqa, n_layer]
    struct ggml_tensor * v_l;  // [n_kv_max, n_embd_v_gqa, n_layer]

    // 单元管理
    std::vector<struct llama_kv_cell> cells;
    uint32_t head;    // 最旧未使用位置
    uint32_t size;    // 已使用位置数
    uint32_t used;    // 活跃位置数

    // 序列追踪
    // 每个 cell 记录属于哪个序列、哪个位置
};

Cache 写入

cpp
// 在 decode 过程中,将新的 K/V 存入 cache
void llama_kv_cache_update(llama_context & lctx) {
    auto & kv = lctx.kv_self;
    for (int il = 0; il < n_layer; il++) {
        // 将新计算的 K 复制到 cache 对应位置
        ggml_backend_tensor_set(kv.k_l[il], k_data,
            offset, k_size);
        // 将新计算的 V 复制到 cache 对应位置
        ggml_backend_tensor_set(kv.v_l[il], v_data,
            offset, v_size);
    }
}

Cache 查找与复用

cpp
// 查找可以复用的 cache 前缀
// 例如 prompt "Hello world" 的 cache 可以在 "Hello world!" 中复用
int32_t llama_kv_cache_find_prefix(
    const llama_kv_cache & kv,
    const llama_pos * pos,
    int32_t n_tokens);

src/llama-batch.cpp — 批处理编码

cpp
struct llama_batch llama_batch_get_one(llama_token * tokens, int32_t n_tokens) {
    struct llama_batch batch = {
        .token    = tokens,
        .pos      = positions,   // 连续位置
        .n_seq_id = seq_count,   // 每个token的序列数
        .seq_id   = seq_ids,     // 序列ID
        .n_tokens = n_tokens,
    };
    return batch;
}

// 批量构建:同时处理多个序列
struct llama_batch llama_batch_get_token(llama_token token) {
    // 单 token batch,用于 decode
}

src/llama-context.cpp — 推理上下文

cpp
// 核心推理函数
int32_t llama_decode(struct llama_context * ctx, struct llama_batch batch) {
    // 1. 构建计算图
    auto * graph = llm_build_graph(*ctx, batch);

    // 2. 分配后端资源
    ggml_backend_alloc_graph(backend, graph);

    // 3. 执行计算
    ggml_backend_graph_compute(backend, graph);

    // 4. 获取输出 logits
    // logits 现在可用于采样
}

关键函数索引

函数文件说明
llama_decodellama-context.cpp执行一次 batch decode
llama_kv_cache_updatellama-memory.cpp更新 KV Cache
llama_kv_cache_clearllama-memory.cpp清空 cache
llama_batch_get_onellama-batch.cpp构建单序列 batch
llama_kv_cache_seq_rmllama-memory.cpp删除指定序列的 cache