The prefill and decode phases have pretty different performance characteristics - prefill can really max out GPU memory. For long contexts, it can be nigh impossible to do it all in a single pass, so frameworks like vLLM use a technique called "chunked prefill".
The decode phase is compute intensive, but tends not to maximize GPU memory.
If you are serving these models, you really want larger batch sizes during inference, which only really come with scale - for a smaller app, you won't want to make users wait that long for a batch to fill.
So, long contexts only have to be processed _once_ per inference, which is basically a scheduling problem.
But the number of decode passes scales linearly with the output length. If it were unlimited, some requests could end up _always_ present in an inference batch, reducing throughput for everyone.
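To make the once-per-request versus per-token asymmetry concrete, here is a minimal sketch that counts how many scheduler steps a single request keeps a batch slot busy. The `forward_passes` helper is hypothetical and assumes one prefill chunk per step, one token per decode step (no speculative decoding), and a hard output cap `max_output`; real schedulers differ in the details.

```python
import math


def forward_passes(prompt_tokens: int, output_tokens: int,
                   prefill_chunk: int, max_output: int) -> dict:
    """Count the scheduler steps one request keeps a batch slot busy.

    Hypothetical model: one prefill chunk of at most `prefill_chunk` tokens
    per step, one decode token per step, generation cut off at `max_output`.
    """
    emitted = min(output_tokens, max_output)
    # Prefill cost depends only on the prompt and is paid once.
    prefill_steps = math.ceil(prompt_tokens / prefill_chunk)
    # The last prefill chunk already yields the first output token; every
    # further token needs its own decode step, so this grows linearly.
    decode_steps = max(emitted - 1, 0)
    return {"prefill_steps": prefill_steps,
            "decode_steps": decode_steps,
            "total_steps": prefill_steps + decode_steps}


# A 128k-token prompt costs a fixed, modest number of prefill steps...
print(forward_passes(131_072, 500, prefill_chunk=8_192, max_output=4_096))
# ...while a runaway generation is bounded only by the output cap.
print(forward_passes(1_000, 1_000_000, prefill_chunk=8_192, max_output=4_096))
```

Under these assumptions the long prompt needs 16 prefill steps, while the runaway request is held to 4,096 total steps by the cap instead of roughly a million without it.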
Decode speed is generally memory-bandwidth bound. Prefill is typically arithmetic bound. This is the reason for mixed batches (both decode and prefill): they let you saturate both memory bandwidth and arithmetic.
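A back-of-envelope roofline shows why. This sketch treats one forward pass of a dense model as "stream every weight once" versus "about 2 FLOPs per parameter per token" and takes whichever is slower. The default hardware numbers are assumed H100 SXM figures (about 3.35 TB/s HBM bandwidth, about 989 TFLOP/s dense BF16), and attention/KV-cache traffic is ignored, so treat it as illustrative only.

```python
def step_time(params: float, batch_tokens: int, bytes_per_param: int = 2,
              mem_bw: float = 3.35e12, peak_flops: float = 989e12):
    """Roofline-style lower bound for one forward pass of a dense model.

    Ignores attention and KV-cache traffic. Default hardware numbers are
    assumed H100 SXM figures (~3.35 TB/s HBM, ~989 TFLOP/s dense BF16).
    """
    # Every step has to stream all the weights from HBM at least once...
    memory_s = params * bytes_per_param / mem_bw
    # ...and the matmuls cost about 2 FLOPs per parameter per token.
    compute_s = 2 * params * batch_tokens / peak_flops
    bound = "memory" if memory_s >= compute_s else "compute"
    return max(memory_s, compute_s), bound


for tokens in (1, 64, 4_096):
    seconds, bound = step_time(8e9, tokens)
    print(f"{tokens:>5} tokens/step: {seconds * 1e3:6.2f} ms ({bound}-bound)")
```

With these numbers, steps stay memory-bound until a few hundred tokens per step, so a batch of 64 decode tokens costs the same step time as a single one, while a 4,096-token prefill chunk is compute-bound. That gap is what mixing decode tokens into prefill steps exploits.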
Chunked prefill is for minimizing latency for the decode entries in the same batch. It's not needed if you have only one request; in that case it's fastest to just prefill in one chunk.
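A toy version of that idea: each forward pass gets a token budget, decode tokens are scheduled first, and whatever budget is left goes to (a chunk of) pending prompts. The `Request` type, `build_mixed_batch`, and the decode-first policy are assumptions in the spirit of chunked prefill, not vLLM's actual scheduler code.

```python
from dataclasses import dataclass


@dataclass
class Request:
    rid: str
    prompt_remaining: int = 0   # prompt tokens not yet in the KV cache
    decoding: bool = False      # True once prefill has finished


def build_mixed_batch(running: list[Request], token_budget: int) -> list[tuple[str, int]]:
    """Plan one forward pass as (request id, tokens) pairs."""
    plan = []
    budget = token_budget
    # Decode tokens go first: each in-flight generation gets its one token,
    # so it is never stalled behind somebody else's long prompt.
    for req in running:
        if req.decoding and budget > 0:
            plan.append((req.rid, 1))
            budget -= 1
    # Leftover budget goes to prompts; a long prompt is cut into chunks
    # that are finished over the following steps.
    for req in running:
        if not req.decoding and req.prompt_remaining > 0 and budget > 0:
            chunk = min(req.prompt_remaining, budget)
            plan.append((req.rid, chunk))
            budget -= chunk
    return plan


running = [Request("a", decoding=True), Request("b", decoding=True),
           Request("long", prompt_remaining=100_000)]
print(build_mixed_batch(running, token_budget=2_048))
# [('a', 1), ('b', 1), ('long', 2046)]

# With a single request and a big enough budget, the whole prompt is one chunk.
print(build_mixed_batch([Request("solo", prompt_remaining=3_000)], token_budget=8_192))
# [('solo', 3000)]
```

The budget caps how long any single step takes, which is what bounds the latency seen by the decoding requests; with only one request there is nothing to protect, so the prompt simply goes through in one piece.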
I'm pretty sure the sibling comment is right about the different length limits: it's because of training, and because the model starts talking nonsense if you let it run too long.
Chunked prefill or some similar technique is also necessary for serving long context requests where there is not enough GPU memory available, regardless of concerns about latency.
For example, consider a prompt sent to Llama 3.1 405B that uses 128k input tokens.
The KV cache will be 123GB. No matter how many GPUs you shard the model across, you are not fitting that KV cache in GPU memory (an H100 has 80GB).
You can do tensor parallelism 8 ways (there are 8 KV heads). You can also do pipeline parallelism (there are 126 layers). Either way would work. A million tokens is possible, just very slow.
Also, 405B has 8 KV heads of size 128 (hidden_size/num_attention_heads), times 126 layers [0], times 2 (K and V), times 2 bytes (bf16), which is 504 KiB per token. At FP8 it's 252 KiB.
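The arithmetic in the last two comments can be checked with a few lines of Python. `kv_cache_bytes` is a hypothetical helper whose defaults are the model shape given above; the sharding logic assumes KV heads split evenly under tensor parallelism and layers split into contiguous pipeline stages.

```python
import math

GIB = 1024 ** 3


def kv_cache_bytes(tokens: int, num_layers: int = 126, num_kv_heads: int = 8,
                   head_dim: int = 128, bytes_per_elem: int = 2,
                   tp_degree: int = 1, pp_degree: int = 1) -> int:
    """KV cache for one sequence on the most-loaded GPU.

    Defaults are the 405B shape from the comment above (126 layers, 8 KV
    heads of size 128, bf16). Tensor parallelism splits the KV heads across
    `tp_degree` GPUs; pipeline parallelism splits the layers across
    `pp_degree` stages (the largest stage gets ceil(layers / stages)).
    """
    if num_kv_heads % tp_degree:
        raise ValueError("tp_degree must divide num_kv_heads")
    heads_per_gpu = num_kv_heads // tp_degree
    layers_per_gpu = math.ceil(num_layers / pp_degree)
    # Factor 2: one K vector and one V vector per head, per layer, per token.
    return 2 * layers_per_gpu * heads_per_gpu * head_dim * bytes_per_elem * tokens


print(kv_cache_bytes(1) / 1024)                    # KiB per token, bf16
print(kv_cache_bytes(1, bytes_per_elem=1) / 1024)  # KiB per token, FP8
print(kv_cache_bytes(131_072) / GIB)               # GiB for a 128k context
print(kv_cache_bytes(131_072, tp_degree=8) / GIB)  # per GPU with 8-way TP
```

Under these assumptions that is 504 KiB and 252 KiB per token, about 63 GiB for a 131,072-token context in bf16, and 7.875 GiB per GPU with 8-way tensor parallelism. These figures cover the KV cache of a single sequence only; the model weights come on top.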
It is also a training issue. The model has to be trained to reinforce longer outputs, which has a quadratic train-time cost and requires suitable long-context response training data.
They definitely have to be trained to reinforce longer outputs, but I do not believe this adequately explains the low-ish generation limits.
We are starting to see models with longer and longer generation limits (gpt-4o-mini at 16k, the o1 models going up to 64k), as well as longer and longer context limits (often 128k, with Google offering a million).
I find it very unlikely they are actually training with inputs or outputs near these maximums.
If you want to convince yourself, do the attention calculation math for these sequence lengths.
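As a rough sketch of that math, assuming a naive attention that materializes the full score matrix and a head dimension of 128 (both assumptions; optimized kernels avoid storing the matrix, but not the quadratic compute):

```python
GIB = 1024 ** 3


def naive_attention_cost(seq_len: int, head_dim: int = 128, bytes_per_elem: int = 2):
    """Score-matrix memory and FLOPs for one attention head in one layer.

    Assumes the full seq_len x seq_len score matrix is materialized (no
    FlashAttention-style tiling) and ignores causal-masking savings.
    """
    score_bytes = seq_len * seq_len * bytes_per_elem
    # Q @ K^T and weights @ V are each about 2 * L * L * d FLOPs.
    flops = 2 * (2 * seq_len * seq_len * head_dim)
    return score_bytes, flops


for length in (4_096, 16_384, 65_536, 131_072, 1_048_576):
    score_bytes, flops = naive_attention_cost(length)
    print(f"{length:>9,} tokens: {score_bytes / GIB:8.1f} GiB of scores, "
          f"{flops:.2e} FLOPs (one head, one layer)")
```

Every doubling of the sequence length quadruples both numbers. At 131,072 tokens a single head in a single layer would need 32 GiB just for bf16 scores, before multiplying by the number of heads and layers (and, during training, keeping activations around for the backward pass).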
You can also see how OpenAI restricts the sequence length for fine-tuning to 64k, almost certainly bound by available GPU memory.
I suspect the 4096 limits have been set as a "reasonable" limit for a myriad of reasons.
The first phase is referred to as "prefill", where the input is processed to create the KV cache.
After that phase, the "decode" phase runs autoregressively: each decode step yields one new token.
This post on [Inference Memory Requirements](https://huggingface.co/blog/llama31#inference-memory-require...) is quite good.
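A toy illustration of the two phases, under heavy simplifying assumptions: one layer, one attention head, scalar keys and values, and made-up `project` and "LM head" functions standing in for real weights. It only shows the bookkeeping: prefill fills the KV cache for the whole prompt and yields the first token, and each decode step adds one cache entry and yields one more token.

```python
import math
from dataclasses import dataclass, field

VOCAB = 50  # hypothetical tiny vocabulary


@dataclass
class KVCache:
    keys: list[float] = field(default_factory=list)
    values: list[float] = field(default_factory=list)


def project(token: int) -> tuple[float, float]:
    """Hypothetical stand-in for a layer's K/V projections."""
    return (token % 7) / 7.0, float(token)


def attend(query: float, cache: KVCache) -> int:
    """Softmax attention over every cached position, then a toy LM head."""
    scores = [query * k for k in cache.keys]
    top = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - top) for s in scores]
    mixed = sum(w * v for w, v in zip(weights, cache.values)) / sum(weights)
    return int(mixed * 3 + len(cache.keys)) % VOCAB  # hypothetical LM head


def prefill(prompt: list[int], cache: KVCache) -> int:
    # Every prompt token is already known, so a real model processes them all
    # in one parallel (compute-heavy) pass; this loop is just for readability.
    for tok in prompt:
        k, v = project(tok)
        cache.keys.append(k)
        cache.values.append(v)
    # Reusing the last key as the query keeps the toy small.
    return attend(cache.keys[-1], cache)  # first generated token


def decode_step(last_token: int, cache: KVCache) -> int:
    # Only the newest token is processed; all earlier K/V come from the cache.
    k, v = project(last_token)
    cache.keys.append(k)
    cache.values.append(v)
    return attend(k, cache)


def generate(prompt: list[int], max_new_tokens: int) -> tuple[list[int], KVCache]:
    cache = KVCache()
    out = [prefill(prompt, cache)]
    while len(out) < max_new_tokens:
        out.append(decode_step(out[-1], cache))
    return out, cache


tokens, cache = generate([3, 14, 15, 9, 26], max_new_tokens=4)
print(tokens, len(cache.keys))  # cache holds 5 prompt + 3 decoded positions = 8
```

The cache ends up with one entry per prompt token plus one per decode step, which is the per-token growth that the KV cache memory figures in this thread are measuring.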