    KV Caching in LLMs Explained: Faster Inference, Lower Cost, and How It Actually Works

    KV caching in LLMs is one of the most important (and most misunderstood) reasons chatbots can stream tokens quickly. If you’ve ever wondered why the first response takes longer than the next tokens, or why long chats get slower and more expensive over time, KV cache is a big part of the answer.

    In this guide, I’ll explain KV caching from first principles, connect it to real serving behavior (prefill vs decode, batching, tail latency), show the memory math so you can estimate cost, and then give practical tactics to reduce latency and GPU spend in production.

    TL;DR

    • KV cache stores attention keys and values for previously processed tokens so you don’t recompute them on every generation step.
    • It dramatically speeds up decode (token-by-token generation). It does not remove the cost of prefill (prompt processing).
    • The tradeoff: KV cache uses a lot of GPU memory and grows linearly with context length (prompt + generated tokens).
    • At scale, long-context workloads are often memory-bandwidth bound, not compute-bound.
    • Serving stacks use continuous batching, prefix caching, and paged KV cache to keep throughput high without memory fragmentation.

    What is KV cache in LLMs?

    Most modern LLMs are Transformer decoder models. They generate text autoregressively: one token at a time. At generation step t, the model needs to condition on all tokens 1..t-1 that came before. That conditioning happens through self-attention.

    In self-attention, each layer projects the hidden states into queries (Q), keys (K), and values (V). The current token’s query compares against the past keys to compute attention weights, and those weights are used to combine the past values. That gives the current token a context-aware representation, which eventually produces the next-token logits.

    KV caching is the idea of storing the keys and values for past tokens so you can reuse them in later steps, instead of recomputing them from scratch at every generation step.

    This is why it’s called “KV” cache (not “QKV” cache): queries depend on the current token, so they must be recomputed every step; keys and values for prior tokens can be reused.

    Quick refresher: Q, K, V in attention (no fluff)

    If you haven’t looked at attention math in a while, here’s the short version. Given hidden states X, attention computes:

    Q = X Wq
    K = X Wk
    V = X Wv
    Attention(X) = softmax(Q K^T / sqrt(d)) V

    During decoding, you only need the new token’s query (Q_new), but you need the keys/values for all previous tokens. Without caching, you would repeatedly compute K and V for the whole history, which is pure redundancy.
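
    To make this concrete, here is a minimal single-head attention sketch in NumPy with an explicit KV cache. The shapes, weights, and random inputs are purely illustrative; real models run this per head and per layer, with masking and positional encodings on top.

    import numpy as np

    d = 64                                            # head dimension (illustrative)
    rng = np.random.default_rng(0)
    Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

    def softmax(x):
        x = x - x.max(axis=-1, keepdims=True)
        e = np.exp(x)
        return e / e.sum(axis=-1, keepdims=True)

    def prefill(prompt_hidden):
        """Process all prompt tokens once; cache their K and V."""
        return prompt_hidden @ Wk, prompt_hidden @ Wv             # (N, d), (N, d)

    def decode_step(x_new, K_cache, V_cache):
        """One generation step: compute Q/K/V only for the new token, reuse the rest."""
        q, k, v = x_new @ Wq, x_new @ Wk, x_new @ Wv
        K_cache = np.vstack([K_cache, k])                         # append new K/V
        V_cache = np.vstack([V_cache, v])
        attn = softmax(q @ K_cache.T / np.sqrt(d))                # attend over all cached keys
        return attn @ V_cache, K_cache, V_cache

    K, V = prefill(rng.standard_normal((10, d)))                  # 10-token "prompt"
    for _ in range(3):                                            # generate 3 tokens
        out, K, V = decode_step(rng.standard_normal(d), K, V)
    print(K.shape)                                                # (13, 64): cache grew by 3

    Without the cache, the prefill-style recomputation of K and V for the whole history would run again at every single decode step.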

    Prefill vs decode: where KV cache helps

    In serving, you’ll often hear two phases:

    • Prefill (prompt processing): run the model over the entire prompt once. This creates the initial KV cache for all prompt tokens.
    • Decode (generation): generate output tokens one-by-one, reusing KV cache and appending new K/V each step.

    KV caching helps decode massively, because decode would otherwise redo the same work at every token step. But prefill is still expensive: you must process every prompt token at least once. That’s why “long prompt” apps often feel slow even on strong GPUs.

    In real systems, prefill latency often correlates strongly with prompt length, while decode latency correlates strongly with (a) output tokens and (b) memory bandwidth pressure from reading the growing KV cache.

    How KV caching works (step-by-step)

    Let’s walk through a single request. Assume the user prompt has N tokens and you will generate M tokens.

    Step 1: Prefill builds the initial cache

    The model processes all N prompt tokens. For each layer ℓ, it computes K_ℓ and V_ℓ for tokens 1..N and stores them in GPU memory. After prefill, you have a KV cache that represents the entire prompt, at every layer.

    Step 2: Decode uses the cache and appends to it

    To generate token N+1:

    • Compute hidden state for the new token.
    • Compute Q_new, K_new, V_new for the new token at each layer.
    • Compute attention using Q_new over all cached K (prompt tokens) and produce a weighted sum over cached V.
    • Append K_new and V_new to the cache.

    Then repeat for token N+2, N+3, … until you generate M tokens (or hit a stop condition). The cache grows from N tokens to N+M tokens.
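
    If you want to see the two phases in code, here is a hedged sketch using Hugging Face Transformers: one forward pass with use_cache=True acts as prefill, and each subsequent call feeds a single token plus past_key_values. The gpt2 checkpoint and greedy decoding are just for illustration, and the exact cache object returned varies by transformers version.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
    prompt_ids = tok("KV caching makes decoding", return_tensors="pt").input_ids

    with torch.no_grad():
        # Prefill: one pass over all N prompt tokens builds the per-layer KV cache.
        out = model(input_ids=prompt_ids, use_cache=True)
        past = out.past_key_values
        next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = [next_id]

        # Decode: each step feeds ONE new token plus the cache, and the cache grows by one.
        for _ in range(8):
            out = model(input_ids=next_id, past_key_values=past, use_cache=True)
            past = out.past_key_values
            next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated.append(next_id)

    print(tok.decode(torch.cat(generated, dim=-1)[0]))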

    Why KV caching is faster (intuition + complexity)

    KV caching saves you from recomputing K/V projections for old tokens at each decode step. That doesn't just reduce FLOPs; it also improves practical throughput, because those projections (and the memory traffic they generate) would otherwise be repeated at every single step.

    However, KV caching does not make attention free. At step t, the model must still read the cached K/V for tokens 1..t-1. So decode time grows with context length: longer context = more KV reads per token. That’s the core reason long chats slow down.

    This is why you’ll sometimes hear: “prefill is compute-heavy; decode becomes memory-bound.” It’s not universally true, but it’s a good rule of thumb for long-context workloads.

    KV cache memory math (with a worked example)

    If you’re trying to understand cost, you need a back-of-the-envelope estimate. A simplified KV cache size per token is:

    KV bytes per token ≈ 2 * layers * kv_heads * head_dim * bytes_per_element

    Total KV bytes for a sequence of T tokens is that number multiplied by T (and then multiplied by batch size if you have multiple concurrent sequences).

    Worked example (order-of-magnitude, not exact): Suppose a model has:

    • 32 layers
    • 8 KV heads (e.g., with GQA)
    • head_dim = 128
    • dtype = FP16 or BF16 (2 bytes per element)

    Then KV bytes per token ≈ 2 * 32 * 8 * 128 * 2 = 131,072 bytes ≈ 128 KB per token.

    If your context length is 4,096 tokens, that’s ~512 MB of KV cache for one sequence. If you serve 10 such sequences concurrently on the same GPU, you’re already at ~5 GB of KV cache just for attention memory, before model weights, activations, fragmentation overhead, and runtime buffers.
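
    A quick sanity check in Python, mirroring the formula and numbers above (the model shape is the illustrative one from this example, not a specific real model):

    def kv_bytes_per_token(layers, kv_heads, head_dim, bytes_per_element=2):
        # Factor of 2 = one K tensor plus one V tensor per layer.
        return 2 * layers * kv_heads * head_dim * bytes_per_element

    per_token = kv_bytes_per_token(layers=32, kv_heads=8, head_dim=128)
    print(per_token)                                   # 131072 bytes ~ 128 KB per token
    print(per_token * 4096 / 2**20)                    # 512 MiB for one 4k-token sequence
    print(per_token * 4096 * 10 / 2**30)               # 5 GiB for 10 concurrent sequences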

    Again, this is simplified (actual implementations pack tensors, use different layouts, sometimes store additional buffers). But the point stands: long context is expensive primarily because KV cache is expensive.

    Why GQA/MQA reduces KV cache size

    Classic multi-head attention has many heads and each head has its own K/V. Newer architectures often use GQA (grouped-query attention) or MQA (multi-query attention), where you have many query heads but fewer KV heads shared across them.

    KV cache size scales with the number of KV heads, not the number of query heads. So moving from (say) 32 KV heads to 8 KV heads can reduce KV cache memory by 4× for the same sequence length. That’s a massive win for long-context serving.
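
    Plugging different KV-head counts into the same back-of-the-envelope formula shows the effect (again, illustrative numbers rather than a specific model):

    def kv_bytes_per_token(layers, kv_heads, head_dim, bytes_per_element=2):
        return 2 * layers * kv_heads * head_dim * bytes_per_element

    mha = kv_bytes_per_token(layers=32, kv_heads=32, head_dim=128)   # 32 KV heads
    gqa = kv_bytes_per_token(layers=32, kv_heads=8, head_dim=128)    # 8 KV heads (GQA)
    print(mha // gqa)   # 4 -> same context length, 4x less KV cache memory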

    KV cache in production serving: batching, throughput, tail latency

    In production, you rarely serve a single request at a time. You’re doing scheduling across many users, each with different prompt lengths, output lengths, and response-time expectations.

    Continuous batching: why it exists

    Traditional (static) batching assumes all sequences are the same length, or pads them to the same length. But generation is dynamic: some users stop early, others generate long outputs, and new requests arrive continuously. Continuous batching lets the server add and remove sequences from the running batch at each decode iteration, so finished requests free their slots immediately and new requests don't wait for the whole batch to finish. This keeps GPU utilization and throughput high.

    The challenge: KV cache allocation for many sequences is messy. That’s where paged KV cache becomes critical, because it avoids allocating huge contiguous buffers per request and reduces fragmentation.

    Why tail latency spikes

    When the server approaches KV memory limits, it must reduce batch size, evict caches, or reject/queue requests. This is often visible as tail latency spikes: p95/p99 get worse before average latency looks terrible. Long-context users can also create “noisy neighbor” effects where they consume disproportionate KV capacity.

    That’s why many production stacks implement context-length tiering, separate pools for long/short requests, and per-tenant limits.

    Modern KV cache techniques (prefix, paged, sliding window, quantized)

    1) Prefix caching (a.k.a. prompt caching)

    If many requests share the same prefix (system prompt, policy text, tool schemas, few-shot examples), you can cache the KV for that prefix and reuse it. This turns repeated prefill into a one-time cost, and it can significantly reduce latency and GPU time for agent-style applications.

    The biggest gotcha is that small differences in the prefix (dynamic timestamps, per-user IDs, slightly different tool schemas) can break reuse. The practical solution is to version your system prompt and keep it stable for long periods.
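
    A toy illustration of why prefix stability matters: if the cache key is an exact hash of the prefix, any dynamic field changes the key and silently defeats reuse. The build_kv callback below is hypothetical; it stands in for whatever your runtime does to prefill and persist the prefix KV.

    import hashlib

    prefix_kv_cache = {}   # hash of (version, prefix) -> opaque KV blob

    def prefix_key(system_prompt: str, version: str) -> str:
        return hashlib.sha256(f"{version}\n{system_prompt}".encode()).hexdigest()

    def get_or_build_prefix_kv(system_prompt, version, build_kv):
        key = prefix_key(system_prompt, version)
        if key not in prefix_kv_cache:
            prefix_kv_cache[key] = build_kv(system_prompt)   # one-time prefill cost
        return prefix_kv_cache[key]

    # A prefix containing something like "Today is 2024-07-01" (hypothetical example)
    # hashes differently every day and never hits the cache again.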

    2) Paged KV cache (block-based allocation)

    Paged KV cache allocates KV memory in fixed-size blocks/pages. When a sequence grows, you allocate more blocks. When a sequence ends, you return blocks to a free list. This is a big deal for high-throughput serving because it reduces fragmentation and makes continuous batching stable under mixed workloads.
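
    Here is a toy block allocator in that spirit (a sketch only; real paged-KV implementations such as vLLM's PagedAttention also manage block tables, copy-on-write, and the GPU memory layout):

    class PagedKVAllocator:
        """Toy paged KV allocator: sequences draw fixed-size blocks from a shared pool."""
        def __init__(self, total_blocks, block_tokens=16):
            self.block_tokens = block_tokens
            self.free = list(range(total_blocks))   # physical block ids available
            self.blocks = {}                        # seq_id -> list of block ids
            self.tokens = {}                        # seq_id -> tokens stored so far

        def append_token(self, seq_id):
            """Reserve KV space for one new token; allocate a block only when needed."""
            n = self.tokens.get(seq_id, 0)
            if n % self.block_tokens == 0:          # last block is full (or first token)
                if not self.free:
                    raise MemoryError("KV pool exhausted: shrink batch, evict, or queue")
                self.blocks.setdefault(seq_id, []).append(self.free.pop())
            self.tokens[seq_id] = n + 1

        def free_sequence(self, seq_id):
            """Return a finished sequence's blocks to the pool, avoiding fragmentation."""
            self.free.extend(self.blocks.pop(seq_id, []))
            self.tokens.pop(seq_id, None)

    alloc = PagedKVAllocator(total_blocks=4, block_tokens=16)
    for _ in range(40):                             # a 40-token sequence needs 3 blocks
        alloc.append_token("req-1")
    print(len(alloc.blocks["req-1"]), len(alloc.free))   # 3 used, 1 still free
    alloc.free_sequence("req-1")
    print(len(alloc.free))                          # 4: everything back in the pool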

    3) Sliding window attention / context truncation

    Some models support sliding window attention where the model attends only to the last N tokens. Serving systems can also implement truncation policies (keep the last N tokens of chat history). This caps KV cache growth. The tradeoff is losing direct attention to earlier context, so you may need summarization or retrieval to preserve important information.
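
    A minimal sketch of the serving-side version, ignoring positional-encoding details: drop cached K/V for tokens older than the window so the cache size stays bounded.

    import numpy as np

    def truncate_kv(K_cache, V_cache, window=4096):
        """Keep cached K/V only for the last `window` tokens."""
        if K_cache.shape[0] > window:
            K_cache, V_cache = K_cache[-window:], V_cache[-window:]
        return K_cache, V_cache

    K = np.zeros((5000, 128)); V = np.zeros((5000, 128))
    K, V = truncate_kv(K, V, window=4096)
    print(K.shape[0])   # 4096: KV memory is now bounded regardless of chat length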

    4) Quantized KV cache

    KV cache can be quantized to reduce memory and bandwidth. This is especially attractive for long-context workloads where KV dominates runtime. The tradeoff is potential quality loss (especially for tasks sensitive to long-range dependencies). Many stacks treat KV quantization as an opt-in knob you enable when you hit memory ceilings.
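
    A minimal int8 sketch of the idea: store K/V as int8 plus a per-row scale and dequantize on read. Real stacks quantize per head or per channel and fuse this into the attention kernel.

    import numpy as np

    def quantize_kv(x_fp16):
        scale = np.abs(x_fp16).max(axis=-1, keepdims=True) / 127.0 + 1e-8
        q = np.clip(np.round(x_fp16 / scale), -127, 127).astype(np.int8)
        return q, scale

    def dequantize_kv(q, scale):
        return q.astype(np.float16) * scale

    K = np.random.randn(1024, 128).astype(np.float16)
    q, s = quantize_kv(K)
    print(q.nbytes / K.nbytes)                       # 0.5: half the bytes to store and read (plus small scales)
    print(np.abs(dequantize_kv(q, s) - K).max())     # small quantization error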

    Common anti-patterns that explode KV cost

    If your LLM bill feels unreasonable, one of these is usually happening:

    • Stuffing the entire chat transcript every turn (unbounded history) instead of summarizing or retrieving.
    • Huge system prompts that change slightly every request (breaking prefix caching).
    • Overusing few-shot examples in production paths where they don’t materially change quality.
    • Returning extremely long outputs by default (high max_tokens with no guardrails).
    • One-size-fits-all context limits (everyone gets the maximum context) instead of tiering.

    The blunt truth: the cheapest optimization is usually to reduce tokens, not to optimize kernels. Kernel optimizations matter, but product-level token discipline often dominates.

    How to monitor KV pressure in production

    To manage KV caching, you need the right metrics. At minimum, you want to track:

    • Prompt tokens (prefill tokens) per request
    • Generated tokens (decode tokens) per request
    • Prefill latency vs decode latency (time-to-first-token vs tokens/sec)
    • Active sequences and average context length on each GPU
    • KV cache utilization (if your serving stack exposes it)
    • p95/p99 latency to catch memory pressure early

    If you already instrument your agent pipeline, make sure you attribute tokens to their sources: system prompt, chat history, retrieval chunks, tool outputs. Otherwise, you can’t fix the real cause.
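
    A minimal sketch of what a per-request record could look like (the field names are illustrative, not taken from any particular serving stack):

    from dataclasses import dataclass, field

    @dataclass
    class RequestMetrics:
        prompt_tokens: int            # prefill size -> drives TTFT and initial KV footprint
        decode_tokens: int            # output length -> drives decode time and KV growth
        ttft_ms: float                # time to first token (prefill latency proxy)
        decode_tps: float             # tokens/sec while streaming (decode throughput)
        token_sources: dict = field(default_factory=dict)  # e.g. {"system": 800, "history": 2500, "rag": 1200}

        @property
        def context_tokens(self) -> int:
            return self.prompt_tokens + self.decode_tokens   # final KV cache length

    m = RequestMetrics(prompt_tokens=3300, decode_tokens=400, ttft_ms=850.0,
                       decode_tps=42.0, token_sources={"system": 800, "history": 2500})
    print(m.context_tokens)   # 3700 tokens of KV held by this request at the end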

    Practical checklist to cut latency and cost

    Here’s a high-ROI checklist that works for most LLM products:

    • Make time-to-first-token (TTFT) a first-class SLO. TTFT is dominated by prefill. If TTFT is high, your prompt is probably too big.
    • Summarize aggressively. Replace older chat turns with a rolling summary + a small window of recent turns.
    • Use retrieval, not transcript stuffing. Bring only relevant documents into context.
    • Stabilize your system prompt. Version it. Don’t inject dynamic data into it if you want prefix caching benefits.
    • Cap outputs with intent-based limits. Don’t let every request generate thousands of tokens unless it’s a “long-form” mode.
    • Tier context length. Default to smaller contexts; allow larger contexts for premium workflows.
    • Pick architectures that are KV-efficient. Prefer models with GQA/MQA when long context is a core feature.
    • Separate long-context traffic. If you can, route long-context requests to dedicated GPUs/pools so they don’t degrade the experience for short requests.

    FAQ

    Does KV cache help training?

    KV caching is primarily an inference optimization. Training typically processes many tokens in parallel and uses different memory strategies (activations, gradients), so KV cache is mainly discussed for serving and generation.

    Why does the first token take longer?

    Because the model must run prefill over the full prompt to build the initial KV cache. That initial pass dominates “time-to-first-token”. After that, decode can reuse KV cache and generate tokens faster.

    Why do long chats get slower over time?

    Because the KV cache grows with every token. Each decode step must read more cached keys/values, increasing memory traffic. At scale, this reduces throughput and increases tail latency.

    Is KV cache the same as “prompt caching”?

    Prompt/prefix caching is a strategy built on top of KV cache: you persist the KV cache for a common prefix and reuse it across requests. KV cache itself exists within a single request as generation proceeds.

    Tools & platforms (official + GitHub links)

    If you’re implementing or relying on KV caching in real serving systems, these projects are worth knowing:

    • vLLM (popular high-throughput serving; paged attention/KV ideas): https://github.com/vllm-project/vllm
    • Hugging Face TGI (Text Generation Inference): https://github.com/huggingface/text-generation-inference
    • NVIDIA TensorRT-LLM (optimized inference stack): https://github.com/NVIDIA/TensorRT-LLM
    • SGLang (serving/runtime for LLM apps): https://github.com/sgl-project/sglang

    Extra depth: what “memory-bound” really means for KV cache

    People often say “decode is memory-bound.” Concretely, that means the GPU spends more time waiting on memory reads/writes than performing arithmetic. With KV caching, each generated token requires reading a growing amount of cached K/V. As the sequence length increases, the ratio of memory traffic to compute increases. Eventually, the bottleneck is not how fast your GPU can multiply matrices—it’s how fast it can move KV data.
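
    You can put rough numbers on this. Using the ~128 KB-per-token figure from the earlier example and an illustrative ~1 TB/s of usable HBM bandwidth, here is the decode-speed ceiling implied by KV traffic alone (it ignores weight reads, activations, and everything else, so treat it strictly as an upper bound):

    kv_bytes_per_token = 128 * 1024        # bytes of cached K/V per token of context
    bandwidth = 1e12                       # bytes/sec (illustrative GPU HBM figure)

    for context_len in (1_000, 8_000, 32_000):
        bytes_read_per_step = context_len * kv_bytes_per_token   # read the whole cache each step
        max_tps = bandwidth / bytes_read_per_step
        print(context_len, round(max_tps), "tokens/sec ceiling from KV traffic alone")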

    This is also why improvements like FlashAttention (and related attention kernels) matter: they reduce memory traffic by fusing operations and avoiding writing large intermediate matrices. Even with a KV cache, attention still involves substantial memory movement; kernel-level optimizations can help, but they can’t change the fundamental scaling: longer context means more KV reads.

    Designing product features around KV cache economics

    KV cache is one of those infrastructure details that should influence product decisions. A few examples of “infra-aware product design”:

    • “Long-form mode”: only allow very high max_tokens when the user explicitly opts in, so your default mode stays efficient.
    • “Memory” as structured state: store stable user preferences or facts in a database and inject only the relevant pieces, rather than replaying the full conversation forever.
    • Conversation summarization cadence: summarize every K turns and replace older turns with the summary, keeping the active context bounded.
    • Context tiering by plan: if you’re commercial, selling bigger context windows is effectively selling more GPU memory time—price it accordingly.

    These decisions reduce average context length, which reduces KV cache footprint, which increases throughput, which lowers cost. It’s a straight line from product UX to GPU economics.

    Final takeaway

    KV caching in LLMs is simple in concept—store K and V once, reuse them—but its implications ripple through everything: latency profiles, throughput, memory fragmentation, and even business pricing. If you’re serious about serving LLMs at scale, understanding KV cache is non-negotiable.

    Debugging real-world performance: a simple prefill/decode checklist

    When users complain “the model is slow,” it’s helpful to separate the complaint into two measurable symptoms:

    • Slow time-to-first-token (TTFT) → usually a prefill problem (too many prompt tokens, cold start, too much retrieval, too big system prompt).
    • Slow tokens-per-second (TPS) during streaming → often a decode problem (KV cache is large, server is overloaded, memory bandwidth limits, batch scheduling).

    Once you measure TTFT and TPS separately, you can make the correct fix instead of guessing. For example, reducing prompt size will improve TTFT a lot, but it might not change TPS much if the decode bottleneck is memory bandwidth. Conversely, paged KV cache and better batching can improve TPS under load but won’t make a huge difference if your prompt is 12k tokens long for every request.
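
    A generic way to separate the two measurements, wrapped around any token stream (stream_tokens is a hypothetical callable that yields tokens as they arrive; the fake stream exists only to make the sketch runnable):

    import time

    def measure_stream(stream_tokens):
        start = time.perf_counter()
        first_token_at = None
        count = 0
        for _ in stream_tokens():
            if first_token_at is None:
                first_token_at = time.perf_counter()   # TTFT ~ queueing + prefill
            count += 1
        end = time.perf_counter()
        ttft = first_token_at - start
        tps = (count - 1) / (end - first_token_at) if count > 1 else 0.0   # decode-only rate
        return ttft, tps

    def fake_stream():
        time.sleep(0.5)                 # pretend prefill
        for _ in range(20):
            time.sleep(0.02)            # pretend decode steps
            yield "tok"

    print(measure_stream(fake_stream))  # roughly (0.5s TTFT, ~50 tokens/sec)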

    Another worked example: why “just increase context” has a hidden cost

    Imagine your product currently uses a 4k context and feels fine. A stakeholder asks: “Can we ship 16k context? Competitors have it.” From a user perspective, more context sounds like a pure win. From a KV cache perspective, it’s a 4× increase in the worst-case KV footprint per sequence.

    Even if most users don’t hit 16k, the long-tail users who do will:

    • increase TTFT (more prefill tokens)
    • reduce throughput for everyone sharing the GPU (more KV reads per decode step)
    • increase OOM risk and fragmentation pressure (bigger KV allocations)

    A pragmatic approach is to ship long context as a bounded, managed feature: enable it only on certain workflows, isolate it to specific GPU pools, and combine it with summarization/retrieval so you aren’t paying full KV cost every turn.

    KV cache and multi-turn chat: what to do instead of infinite history

    For chat products, the naive implementation is: every time the user sends a message, you send the entire conversation transcript back to the model. That works… until it doesn’t. KV cache grows, latency drifts upward, and your cost per conversation rises over time.

    Common alternatives that keep quality high while controlling KV size:

    • Rolling summary: maintain a summary of older turns (updated every few turns) and include only the last few raw turns.
    • Semantic memory: store extracted facts/preferences as structured data; inject only what’s relevant to the current query.
    • RAG for long history: retrieve only the most relevant past snippets rather than replaying everything.
    • Hard window: keep only the last N tokens of conversation (simple, but can lose important context unless paired with summarization).

    All of these approaches share a theme: they bound prompt size and therefore bound KV cache growth. In most production settings, that’s the difference between “works in a demo” and “works at scale.”
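
    To make the rolling-summary-plus-window idea concrete, here is a minimal sketch; summarize is a hypothetical callable (an LLM call or a heuristic) that folds older turns into a running summary, and a real system would also track which turns it has already summarized.

    def build_context(summary, turns, keep_last=6, summarize=None):
        """Return (new_summary, messages): a rolling summary plus the last few raw turns."""
        older, recent = turns[:-keep_last], turns[-keep_last:]
        if older and summarize is not None:
            summary = summarize(summary, older)    # fold older turns into the running summary
        messages = []
        if summary:
            messages.append({"role": "system", "content": f"Summary of earlier turns: {summary}"})
        messages.extend(recent)                    # only the last few raw turns stay verbatim
        return summary, messages

    Because the prompt assembled this way has a roughly constant size, the KV cache per request stays bounded no matter how long the conversation runs.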