Large language models are rapidly gaining popularity, but their slow response times often frustrate users, driving them toward less capable alternatives.
In this post, we’ll look at the reasons behind these delays and walk through an innovative decoding technique, Assisted Generation, that can dramatically improve performance, cutting latency by up to 10 times on standard hardware.
Understanding Text Generation
Text generation starts from an input sequence of previously generated tokens and proceeds one forward pass at a time. Each pass feeds the input through the model’s layers and produces unnormalized log probabilities (logits) for the next token. A token can be a word, a sub-word, or even an individual character.
Once the logits are produced, a decoding strategy such as greedy decoding (selecting the highest-probability token) or multinomial sampling chooses the next token.
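If it helps to see this concretely, here is a minimal sketch of a single decoding step: one forward pass produces logits, and greedy decoding picks the argmax. GPT-2 is used only as a small stand-in checkpoint, not something this post prescribes.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")     # small stand-in model
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The cat sat on the", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits                   # shape: [1, seq_len, vocab_size]

next_token_logits = logits[0, -1]                     # logits for the next token
next_token_id = next_token_logits.argmax().item()     # greedy decoding: take the argmax
print(tokenizer.decode([next_token_id]))
```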
Simplified Explanation: When we use a language model, we input a sequence of words, and it predicts the next word. This is the basic forward pass of the model. But there’s an alternative approach: we can input a full sentence and ask, “Would the model have predicted this sentence?” The model checks whether each word in the sequence fits, verifying that the generated sequence aligns with its internal logic.
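Here is a rough sketch of that “verification” view, with the same stand-in model: run one forward pass over the whole sentence and check, position by position, whether the model’s greedy pick matches the token that actually follows.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")     # small stand-in model
model = AutoModelForCausalLM.from_pretrained("gpt2")

ids = tokenizer("The cat sat on the mat.", return_tensors="pt").input_ids
with torch.no_grad():
    logits = model(ids).logits                        # [1, seq_len, vocab_size]

greedy_picks = logits[0, :-1].argmax(dim=-1)          # model's pick after each prefix
actual_next = ids[0, 1:]                              # the token that actually follows
print((greedy_picks == actual_next).tolist())         # True where the model "agrees"
```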
Greedy Decoding with Assisted Generation
Imagine an assistant model that can instantly generate text similar to what the main model would produce, operating in the background. While you can’t directly use it for the final output, you can leverage it to suggest possible next words and let the main model confirm those guesses with a forward pass.
In this perfect world, you wouldn’t need to wait for your main model to predict each word one by one (which normally takes time and depends on how long the sentence is). Instead, with the help of the assistant, you’d only need a single check to confirm everything.
Normally, generating a sentence of n tokens takes O(n) time: you generate one word, then the next, and so on, so the cost grows with the length of the sentence. With this magical assistant, you could generate everything in effectively O(1) time, a single verification pass, no matter how long the sentence is. For really long sentences, this would save a lot of time.
In reality, this perfect latency-free model doesn’t exist, but the idea illustrates how much faster text generation could become with the right kind of helper!
Let’s use a simple example to explain this:
You’re trying to generate the sentence: “The cat sat on the mat.”
Without the Assistant (O(n)), as the small loop sketch after this list illustrates:
- Input: Start with the token “The.”
- Prediction: The model predicts the next token (“cat”) by generating logits.
- Chaining: The process repeats for each new token:
- For “The cat,” it predicts “sat.”
- For “The cat sat,” it predicts “on,” and so on.
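Here is the minimal loop the list above describes: one forward pass per generated token, so the work grows with the length of the output. Again, GPT-2 and the six-token budget are just illustrative choices.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")     # small stand-in model
model = AutoModelForCausalLM.from_pretrained("gpt2")

ids = tokenizer("The", return_tensors="pt").input_ids
for _ in range(6):                                    # one forward pass per new token
    with torch.no_grad():
        logits = model(ids).logits
    next_id = logits[0, -1].argmax().view(1, 1)       # greedy pick for the next token
    ids = torch.cat([ids, next_id], dim=-1)           # append it and repeat
print(tokenizer.decode(ids[0]))
```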
With the Assistant (O(1)):
- Assistant Suggests: The assistant provides a full sequence: “The cat sat on the mat.”
- Main Model Input: The main model takes this complete sequence as input in one go.
- Single Forward Pass: During this single forward pass, the model generates logits for each token position in the entire sequence simultaneously.
- Logits Generated (for all positions in one pass):
- For “The”: The model generates a logit vector for the first position.
- For “cat”: The model generates a logit vector for the second position.
- For “sat”: The model generates a logit vector for the third position.
- For “on”: The model generates a logit vector for the fourth position.
- For “the”: The model generates a logit vector for the fifth position.
- For “mat”: The model generates a logit vector for the sixth position.
Here’s an example of what the logits might look like after this single forward pass (a short snippet after the list shows how to read off the argmax at each position):
Position 0 (for “The”): [0.1, 2.0, 0.5, 0.3]
Position 1 (for “cat”): [0.05, 1.5, 1.0, 0.2]
Position 2 (for “sat”): [0.1, 0.8, 1.2, 0.4]
Position 3 (for “on”): [0.2, 0.4, 1.1, 0.3]
Position 4 (for “the”): [0.3, 0.6, 1.0, 0.7]
Position 5 (for “mat”): [0.2, 0.3, 0.4, 1.5]
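To make the table concrete, here is a tiny snippet that loads the same toy numbers and computes the argmax at each position. Note that the column indices are purely illustrative; they do not form a consistent shared vocabulary.

```python
import torch

# The toy logits from the table above, one row per position (toy vocab size 4).
logits = torch.tensor([
    [0.10, 2.0, 0.5, 0.3],   # position 0, after "The"
    [0.05, 1.5, 1.0, 0.2],   # position 1, after "The cat"
    [0.10, 0.8, 1.2, 0.4],   # position 2, after "The cat sat"
    [0.20, 0.4, 1.1, 0.3],   # position 3, after "The cat sat on"
    [0.30, 0.6, 1.0, 0.7],   # position 4, after "The cat sat on the"
    [0.20, 0.3, 0.4, 1.5],   # position 5, after "The cat sat on the mat"
])
print(logits.argmax(dim=-1).tolist())   # -> [1, 1, 2, 2, 2, 3]
```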
Making Predictions
Now, the model checks each of these logits against the assistant’s suggestions in a single go. Prediction Process:
For Position 0 (“The”): The model looks at logits [0.1, 2.0, 0.5, 0.3] and sees “cat” (index 1) has the highest probability.
For Position 1 (“cat”): The model looks at logits [0.05, 1.5, 1.0, 0.2] and sees “sat” (index 1) has the highest probability.
For Position 2 (“sat”): The model looks at logits [0.1, 0.8, 1.2, 0.4] and sees “on” (index 2) has the highest probability.
For Position 3 (“on”): The model looks at logits [0.2, 0.4, 1.1, 0.3] and sees “the” (index 2) has the highest probability.
For Position 4 (“the”): The model looks at logits [0.3, 0.6, 1.0, 0.7] and sees “mat” (index 2) has the highest probability.
So here we observe:
Single Forward Pass: The key takeaway here is that the main model runs one forward pass for the entire sequence instead of separate passes for each word.
Efficiency: Because it processes everything in one go, it reduces the number of main-model decoding steps from O(n) (one pass per token) to O(1) (a single pass over the entire drafted sequence), as the acceptance-check sketch below illustrates.
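Here is a small sketch of that acceptance check, with illustrative names (`accept_prefix`, `main_logits`, `draft_ids` are not from any particular library): compare the main model’s greedy pick at each drafted position with the assistant’s token and accept the longest matching prefix.

```python
import torch

def accept_prefix(main_logits: torch.Tensor, draft_ids: torch.Tensor) -> int:
    """Greedy acceptance check (illustrative helper, not a library function).

    main_logits[i] holds the main model's logits for the position where the
    assistant proposed draft_ids[i]; returns how many drafted tokens match
    the main model's own greedy choice, counted from the start of the draft.
    """
    main_picks = main_logits.argmax(dim=-1)          # main model's greedy picks
    for i, (mine, drafted) in enumerate(zip(main_picks.tolist(), draft_ids.tolist())):
        if mine != drafted:
            return i                                 # stop at the first mismatch
    return len(draft_ids)                            # the whole draft is accepted

# Toy check: the assistant drafted [1, 1, 2, 0]; the main model agrees on the
# first three tokens, so three are accepted and the fourth gets corrected.
toy_logits = torch.tensor([[0.1, 2.0, 0.5, 0.3],
                           [0.0, 1.5, 1.0, 0.2],
                           [0.1, 0.8, 1.2, 0.4],
                           [0.2, 0.4, 1.1, 0.3]])
print(accept_prefix(toy_logits, torch.tensor([1, 1, 2, 0])))   # -> 3
```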
When the Assistant Makes Mistakes
Autoregressive Nature:
Language models predict one word at a time, with each word depending on the previous ones. If the assistant suggests a wrong word, subsequent predictions may also be wrong since they rely on earlier mistakes.
Iterative Correction:
Despite potential errors, the main model can correct the assistant’s mistakes during verification. For example, if the assistant incorrectly predicts “cat” where the main model would have produced “dog,” the main model keeps its own token at that position, discards the rest of the draft, and the assistant continues drafting from “The dog.” This still saves time overall, because one verification pass can accept several tokens at once instead of generating them one by one, as the simplified loop below sketches.
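Putting the pieces together, here is a simplified greedy assisted-generation loop. It is only a sketch of the idea, not the optimized implementation that ships in any library, and the checkpoint names are placeholders: a larger main model paired with a much smaller assistant that shares its tokenizer.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")                  # shared tokenizer (assumption)
main = AutoModelForCausalLM.from_pretrained("gpt2-medium")   # "large" main model (placeholder)
assistant = AutoModelForCausalLM.from_pretrained("gpt2")     # small draft model (placeholder)

@torch.no_grad()
def assisted_generate(prompt: str, max_new_tokens: int = 32, draft_len: int = 5) -> str:
    ids = tok(prompt, return_tensors="pt").input_ids         # [1, context_len]
    generated = 0
    while generated < max_new_tokens:
        # 1) The assistant cheaply drafts `draft_len` tokens (greedy).
        draft = assistant.generate(
            ids, max_new_tokens=draft_len, do_sample=False,
            pad_token_id=tok.eos_token_id,
        )
        draft_ids = draft[0, ids.shape[1]:]                  # just the drafted tokens

        # 2) One main-model forward pass over context + draft.
        logits = main(draft).logits[0]                       # [total_len, vocab_size]
        # logits[i] predicts the token at position i + 1, so the predictions
        # aligned with the drafted tokens start at the last context index.
        start = ids.shape[1] - 1
        main_picks = logits[start:start + len(draft_ids)].argmax(dim=-1)

        # 3) Accept the longest prefix where the main model agrees, then take
        #    the main model's own token at the first disagreement (or a free
        #    "bonus" token if the whole draft was accepted).
        n_ok = 0
        while n_ok < len(draft_ids) and main_picks[n_ok] == draft_ids[n_ok]:
            n_ok += 1
        if n_ok < len(draft_ids):
            correction = main_picks[n_ok]
        else:
            correction = logits[start + n_ok].argmax()
        ids = torch.cat([ids, draft_ids[:n_ok].view(1, -1), correction.view(1, 1)], dim=-1)
        generated += n_ok + 1
    return tok.decode(ids[0])

print(assisted_generate("The cat sat on", max_new_tokens=16))
```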
Summary
In conclusion, Assisted Generation offers a powerful method to speed up text generation, even if the assistant model isn’t perfect. By leveraging quick suggestions and correcting errors iteratively, this approach significantly reduces latency compared to traditional word-by-word generation methods. The result is faster, more efficient text generation without sacrificing the integrity of the output.
How do we implement it?
You can find code for assisted decoding in the `huggingface-llama-recipes` repository.
https://github.com/huggingface/huggingface-llama-recipes
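For a quick start without leaving the transformers API, passing a small draft model via the `assistant_model` argument of `generate()` enables assisted decoding. The checkpoints below are placeholders chosen for illustration; in practice you would pair a large model with a much smaller assistant that shares its tokenizer, as the Llama recipes in the repository above do.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")            # shared tokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2-large")   # main model (placeholder)
assistant = AutoModelForCausalLM.from_pretrained("gpt2")     # small draft model (placeholder)

inputs = tokenizer("The cat sat on", return_tensors="pt")
outputs = model.generate(
    **inputs,
    assistant_model=assistant,    # switches on assisted decoding
    max_new_tokens=32,
    do_sample=False,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```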
Happy coding!