The chain rule behind autoregressive models
Autoregressive (AR) models look mysterious until you notice they are built on a single, very old identity: the probability chain rule.
The probability chain rule (the whole trick)
For any sequence of random variables $x_1, \dots, x_n$, the joint distribution can always be written as:

$$p(x_1, \dots, x_n) = \prod_{t=1}^{n} p(x_t \mid x_1, \dots, x_{t-1}).$$
This is not an approximation. It is a re-expression of the joint probability using conditional probabilities.
Two immediate consequences:
- If you can model the conditionals $p(x_t \mid x_{<t})$, you can model the full joint $p(x_1, \dots, x_n)$.
- You get a natural generative procedure: sample $x_1$, then sample $x_2$ conditioned on $x_1$, and so on.
That’s the definition of “autoregressive” in this context: the model predicts the next element conditioned on the previous ones.
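To make the "not an approximation" point concrete, here is a small sketch that checks the factorization numerically on an invented joint distribution over three binary variables (the table values are made up for illustration):

```python
from itertools import product

# Invented joint distribution over three binary variables (values sum to 1).
joint = {
    (0, 0, 0): 0.10, (0, 0, 1): 0.05, (0, 1, 0): 0.20, (0, 1, 1): 0.05,
    (1, 0, 0): 0.15, (1, 0, 1): 0.10, (1, 1, 0): 0.05, (1, 1, 1): 0.30,
}

def marginal(prefix):
    """p(x_1..x_k = prefix): sum the joint over the remaining variables."""
    return sum(p for seq, p in joint.items() if seq[:len(prefix)] == prefix)

for seq in product((0, 1), repeat=3):
    # p(x1) * p(x2|x1) * p(x3|x1,x2), every factor computed from the same table.
    factored = (
        marginal(seq[:1])
        * marginal(seq[:2]) / marginal(seq[:1])
        * marginal(seq[:3]) / marginal(seq[:2])
    )
    assert abs(factored - joint[seq]) < 1e-12

print("chain-rule factorization matches the joint exactly")
```

Every conditional is read off the same table, so the product of conditionals reproduces each joint entry exactly.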
Why the factorization matters for language models
For text, we typically define $x_t$ as a token (word piece / subword) and train a model to output:

$$p_\theta(x_t \mid x_1, \dots, x_{t-1}).$$
A transformer language model is essentially a big conditional probability estimator that maps a prefix to a distribution over the next token.
The chain rule turns “model a complicated joint distribution over strings” into “repeat a simpler prediction task many times.”
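The "distribution over the next token" is just a softmax over the vocabulary. A tiny sketch with made-up logits (a real transformer would compute the logits from the prefix):

```python
import math

# Made-up logits for a 4-word toy vocabulary; a real model produces these
# from the prefix. The softmax turns them into p(next token | prefix).
vocab = ["the", "cat", "sat", "."]
logits = [2.0, 0.5, -1.0, 0.1]

m = max(logits)                               # subtract the max for numerical stability
exps = [math.exp(l - m) for l in logits]
Z = sum(exps)
next_token_probs = {tok: e / Z for tok, e in zip(vocab, exps)}

print(next_token_probs)                       # positive values that sum to 1
```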
A tiny concrete example
Consider a three-token sequence $(x_1, x_2, x_3)$. The chain rule gives:

$$p(x_1, x_2, x_3) = p(x_1)\, p(x_2 \mid x_1)\, p(x_3 \mid x_1, x_2).$$
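With the made-up numbers used in the code example further down, $p(x_1) = 0.2$, $p(x_2 \mid x_1) = 0.5$, $p(x_3 \mid x_1, x_2) = 0.1$, so:

$$p(x_1, x_2, x_3) = 0.2 \times 0.5 \times 0.1 = 0.01.$$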
A slightly more formal derivation
The chain rule follows by repeatedly applying the definition of conditional probability:

$$p(a, b) = p(a)\, p(b \mid a).$$

For three variables:

$$p(x_1, x_2, x_3) = p(x_1, x_2)\, p(x_3 \mid x_1, x_2) = p(x_1)\, p(x_2 \mid x_1)\, p(x_3 \mid x_1, x_2).$$

Generalizing gives:

$$p(x_1, \dots, x_n) = \prod_{t=1}^{n} p(x_t \mid x_{<t}).$$
This is the identity autoregressive models exploit.
The model never has to output $p(x_1, x_2, x_3)$ directly. It only needs to output three smaller distributions.
Training: maximum likelihood becomes “sum of next-token losses”
If the model defines the joint via the chain rule, then the log-likelihood of a sequence decomposes nicely:

$$\log p_\theta(x_1, \dots, x_n) = \sum_{t=1}^{n} \log p_\theta(x_t \mid x_{<t}).$$
So maximum likelihood training turns into maximizing the sum of the conditional log-probabilities across positions.
In practice we minimize the negative log-likelihood (NLL), which is exactly cross-entropy for a one-hot next-token target.
This is why a “language modeling loss” is typically implemented as “shift inputs right, predict the next token, compute cross entropy, average.”
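A minimal sketch of that recipe, assuming PyTorch and a hypothetical `model` that maps token ids to next-token logits:

```python
import torch
import torch.nn.functional as F

def lm_loss(model, tokens):
    """Standard next-token loss. `tokens` is a (batch, seq_len) tensor of ids;
    `model` is any module mapping ids to logits of shape (batch, len, vocab)."""
    inputs = tokens[:, :-1]          # the model sees x_1 .. x_{n-1}
    targets = tokens[:, 1:]          # and is scored on predicting x_2 .. x_n
    logits = model(inputs)
    # Cross-entropy over the vocabulary at every position, averaged over batch and time.
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
    )
```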
Teacher forcing: why it’s so efficient
During training we usually feed the model the true prefix $x_{<t}$ (from the dataset) when predicting $x_t$. This is known as teacher forcing.
Benefits:
- You can compute losses for all time steps in parallel (important for transformers).
- The gradient signal is stable: you’re always conditioning on real context, not the model’s own mistakes.
The trade-off is a mismatch at generation time: at inference the model conditions on its own samples, which can compound errors. That mismatch is often discussed under names like exposure bias.
Sampling: the chain rule becomes an algorithm
Once you have $p_\theta(x_t \mid x_{<t})$, generation is just:
- Start with a prompt (maybe empty).
- Compute the next-token distribution.
- Sample (or take argmax).
- Append the token and repeat.
Different decoding methods (greedy, beam search, top-$k$, nucleus/top-$p$, temperature) are just different ways to turn that conditional distribution into an actual token choice.
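A minimal sampling loop, written against a hypothetical `next_token_distribution` function standing in for whatever model you have:

```python
import random

def generate(next_token_distribution, prompt, max_new_tokens, eos=None):
    """Ancestral sampling. `next_token_distribution` is a hypothetical function
    mapping a prefix (list of tokens) to a dict {token: probability}."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        dist = next_token_distribution(tokens)
        choices, weights = zip(*dist.items())
        token = random.choices(choices, weights=weights, k=1)[0]  # argmax here would be greedy
        tokens.append(token)
        if token == eos:
            break
    return tokens
```

Swapping the `random.choices` line for an argmax, a truncated distribution, or a temperature-scaled softmax gives the usual decoding variants.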
A practical view: log-probs add, probabilities multiply
Because of the product, probabilities can get tiny fast. In code, you almost always work with log-probabilities:
```python
import math

# Example: p(x1) = 0.2, p(x2|x1) = 0.5, p(x3|x1,x2) = 0.1
probs = [0.2, 0.5, 0.1]

# Add log-probabilities instead of multiplying tiny probabilities.
logp = sum(math.log(p) for p in probs)
p_joint = math.exp(logp)

print("log p(x1:x3):", logp)
print("p(x1:x3):", p_joint)
```

This mirrors what frameworks compute: sum of token-level log-probs (or mean loss), not a direct joint probability.
“Chain rule” also shows up in backprop (but it’s a different one)
People sometimes conflate two “chain rules”:
- Probability chain rule: factorizes a joint distribution into conditionals.
- Calculus chain rule: propagates gradients through composed functions.
Autoregressive modeling relies on the probability chain rule. Autoregressive training (like most deep learning) relies on the calculus chain rule during backpropagation.
They are conceptually distinct, but each is part of why the whole pipeline is tractable:
- the probability chain rule gives you a learnable, decomposable objective;
- the calculus chain rule lets you optimize it with gradient descent.
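A tiny illustration of the calculus version, the rule backprop applies layer by layer:

```python
import math

# d/dx sin(x**2) = cos(x**2) * 2x: multiply the local derivatives of the
# composed functions. Backprop applies this same rule through every layer.
x = 1.3
g = x ** 2               # inner function g(x)
f = math.sin(g)          # outer function f(g)

df_dg = math.cos(g)      # derivative of the outer function at g
dg_dx = 2 * x            # derivative of the inner function at x
print("df/dx =", df_dg * dg_dx)
```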
The mental model I keep
An autoregressive model is:
- a choice of ordering (left-to-right for text);
- the probability chain rule;
- a conditional model class (transformer, RNN, etc.);
- maximum likelihood training (cross-entropy over next-token predictions).
Everything else—prompting, decoding tricks, RLHF-style fine-tuning—sits on top of that foundation.
Perplexity: the common metric for AR language models
Because log-likelihood decomposes into token-level terms, we can define the average negative log-likelihood per token:

$$\mathrm{NLL} = -\frac{1}{n} \sum_{t=1}^{n} \log p_\theta(x_t \mid x_{<t}).$$

Perplexity is just the exponentiated average NLL (with the same log base convention):

$$\mathrm{PPL} = \exp(\mathrm{NLL}).$$
Intuition:
- Lower perplexity means the model assigns higher probability to the observed next tokens.
- Perplexity is essentially “effective branching factor”: how many plausible next tokens the model is, on average, spreading probability mass over.
(When people report PPL, details matter: tokenization, log base, and whether the evaluation uses the same preprocessing as training.)
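A quick sketch with made-up per-token log-probabilities (natural log throughout):

```python
import math

# Made-up per-token log-probabilities a model assigned to the observed
# next tokens of some evaluation text.
token_logprobs = [-2.1, -0.3, -4.0, -1.2]

nll = -sum(token_logprobs) / len(token_logprobs)   # average NLL per token
ppl = math.exp(nll)                                # exponentiate with the same base (e)
print(f"avg NLL per token: {nll:.3f}   perplexity: {ppl:.2f}")
```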
Tokens, not words: what is $x_t$ in practice?
In modern LMs, $x_t$ is almost never a whole word. It is typically a subword token from a vocabulary learned by BPE/Unigram.
That changes how you should read the chain rule:
- The model factorizes probability over token sequences, not word sequences.
- A single “word” may be 1 token or many tokens.
- Reported metrics (loss/perplexity) are therefore tokenization-dependent.
Concretely, the same text can correspond to a different $n$ (sequence length) under different tokenizers, which affects average loss and PPL.
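For instance, a quick check using two common tokenizers (this assumes the Hugging Face transformers library is installed and can download them):

```python
from transformers import AutoTokenizer

text = "The chain rule behind autoregressive models"

# Two tokenizers chosen only as common examples; any pair will do.
for name in ["gpt2", "bert-base-uncased"]:
    tokenizer = AutoTokenizer.from_pretrained(name)
    ids = tokenizer.encode(text, add_special_tokens=False)
    print(f"{name}: n = {len(ids)} tokens")
```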
A diagram of the AR factorization + generation loop
```mermaid
graph TD;
    A["Training text: x1..xn"] --> B["Shifted inputs: x1..x{n-1}"]
    B --> C["Model outputs: p(x_t | x_<t)"]
    C --> D["Cross-entropy vs target x_t"]
    D --> E["Sum/mean over t => loss"]
    F["Prompt: x1..xk"] --> G["p(x_{k+1} | x_<=k)"]
    G --> H["Decode: greedy / top-k / top-p / temp"]
    H --> I["Sample token x_{k+1}"]
    I --> F
```