The chain rule behind autoregressive models

You’ve heard “autoregressive models factorize the joint distribution” and want a compact, practical explanation of what that means, why it works, and how it connects to training with cross-entropy.

Autoregressive (AR) models look mysterious until you notice they are built on a single, very old identity: the probability chain rule.

The probability chain rule (the whole trick)

For any sequence of random variables $x_{1:n} = (x_1, x_2, \dots, x_n)$, the joint distribution can always be written as:

$$p(x_{1:n}) = \prod_{t=1}^{n} p(x_t \mid x_{1:t-1})$$

This is not an approximation. It is an exact re-expression of the joint probability using conditional probabilities. For $t = 1$ the prefix $x_{1:0}$ is empty, so the first factor is simply $p(x_1)$.

Two immediate consequences:

  • If you can model the conditionals $p(x_t \mid x_{<t})$, you can model the full joint $p(x_{1:n})$.
  • You get a natural generative procedure: sample $x_1$, then sample $x_2$ conditioned on $x_1$, and so on.

That’s the definition of “autoregressive” in this context: the model predicts the next element conditioned on the previous ones.

Why the factorization matters for language models

For text, we typically define $x_t$ as a token (word piece / subword) and train a model to output:

$$p_\theta(x_t \mid x_{<t})$$

A transformer language model is essentially a big conditional probability estimator that maps a prefix to a distribution over the next token.

The chain rule turns “model a complicated joint distribution over strings” into “repeat a simpler prediction task many times.”

A tiny concrete example

Consider a three-token sequence: $(x_1, x_2, x_3)$. The chain rule gives:

$$p(x_1, x_2, x_3) = p(x_1)\,p(x_2 \mid x_1)\,p(x_3 \mid x_1, x_2)$$
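The identity can be checked numerically on a small example. The sketch below builds a made-up joint distribution over three binary variables (the weights are arbitrary, and `marginal`, `conditional`, and `chain_prob` are hypothetical helper names), derives each conditional by marginalizing, and confirms that the product of conditionals reproduces the joint.

```python
import itertools
import random

# Hypothetical toy joint distribution over three binary variables.
# The weights are arbitrary; any positive weights would work.
rng = random.Random(0)
outcomes = list(itertools.product([0, 1], repeat=3))
weights = [rng.random() + 0.1 for _ in outcomes]
total = sum(weights)
joint = {o: w / total for o, w in zip(outcomes, weights)}


def marginal(prefix):
    """p(x_1..x_k = prefix): sum the joint over the remaining variables."""
    k = len(prefix)
    return sum(p for o, p in joint.items() if o[:k] == prefix)


def conditional(x_t, prefix):
    """p(x_t | prefix), straight from the definition of conditional probability."""
    return marginal(prefix + (x_t,)) / marginal(prefix)


def chain_prob(outcome):
    """Product of the three conditionals p(x1) p(x2|x1) p(x3|x1,x2)."""
    prob = 1.0
    for t in range(len(outcome)):
        prob *= conditional(outcome[t], outcome[:t])
    return prob


for o in outcomes:
    # The chain-rule product matches the joint up to floating-point error.
    assert abs(chain_prob(o) - joint[o]) < 1e-12
```

Nothing here is learned; the point is only that the factorization holds for an arbitrary joint.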

A slightly more formal derivation

The chain rule follows by repeatedly applying the definition of conditional probability:

$$p(a \mid b) = \frac{p(a, b)}{p(b)} \quad\Rightarrow\quad p(a, b) = p(a \mid b)\,p(b)$$

For three variables:

$$\begin{aligned} p(x_1, x_2, x_3) &= p(x_3 \mid x_1, x_2)\,p(x_1, x_2) \\ &= p(x_3 \mid x_1, x_2)\,p(x_2 \mid x_1)\,p(x_1) \end{aligned}$$

Generalizing gives:

$$p(x_{1:n}) = p(x_1)\,\prod_{t=2}^{n} p(x_t \mid x_{1:t-1})$$

This is the identity autoregressive models exploit.
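For readers who want the generalization step spelled out, here is a sketch of the peeling argument in the same notation. It assumes every conditioning prefix has nonzero probability, so that each conditional is well defined.

```latex
\begin{aligned}
p(x_{1:n}) &= p(x_n \mid x_{1:n-1})\,p(x_{1:n-1}) \\
           &= p(x_n \mid x_{1:n-1})\,p(x_{n-1} \mid x_{1:n-2})\,p(x_{1:n-2}) \\
           &\;\;\vdots \\
           &= p(x_1)\,\prod_{t=2}^{n} p(x_t \mid x_{1:t-1})
\end{aligned}
```

Each line applies $p(a, b) = p(a \mid b)\,p(b)$ once, with $a$ the last remaining variable and $b$ everything before it.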

In the three-token example, the model never has to output $p(x_1, x_2, x_3)$ directly. It only needs to output three smaller distributions, one per position.

Training: maximum likelihood becomes “sum of next-token losses”

If the model defines the joint via the chain rule, then the log-likelihood of a sequence decomposes nicely:

$$\log p_\theta(x_{1:n}) = \sum_{t=1}^{n} \log p_\theta(x_t \mid x_{<t})$$

So maximum likelihood training turns into maximizing the sum of the conditional log-probabilities across positions.

In practice we minimize the negative log-likelihood (NLL), which is exactly cross-entropy for a one-hot next-token target.

This is why a “language modeling loss” is typically implemented as “shift inputs right, predict the next token, compute cross-entropy, average.”
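A minimal pure-Python sketch of that recipe follows. The function names and the `logits_per_position` layout are assumptions made for illustration: position `t` is taken to hold the model's scores after reading `tokens[: t + 1]`. Real frameworks do the same thing on tensors.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def next_token_loss(tokens, logits_per_position):
    """Mean cross-entropy of predicting tokens[t + 1] from position t.

    logits_per_position[t] is assumed to be the model's vocabulary scores
    after reading tokens[: t + 1] (a hypothetical layout for this sketch).
    """
    predictions = logits_per_position[:-1]  # the last position has no target
    targets = tokens[1:]                    # targets are the inputs shifted by one
    losses = [
        -log_softmax(logits)[target]        # NLL of a one-hot target
        for logits, target in zip(predictions, targets)
    ]
    return sum(losses) / len(losses)


# Toy usage: vocabulary of size 3 and a "model" that outputs all-zero logits,
# i.e. a uniform distribution at every position.
tokens = [0, 2, 1]
uniform_logits = [[0.0, 0.0, 0.0] for _ in tokens]
loss = next_token_loss(tokens, uniform_logits)
```

With uniform predictions the loss is the log of the vocabulary size, which is a handy sanity check for a freshly initialized model. This sketch does not score the first token; in practice a beginning-of-sequence token usually supplies the context for it.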

Teacher forcing: why it’s so efficient

During training we usually feed the model the true prefix $x_{<t}$ (from the dataset) when predicting $x_t$. This is known as teacher forcing.

Benefits:

  • You can compute losses for all time steps in parallel (important for transformers).
  • The gradient signal is stable: you’re always conditioning on real context, not the model’s own mistakes.

The trade-off is a mismatch at generation time: at inference the model conditions on its own samples, which can compound errors. That mismatch is often discussed under names like exposure bias.
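The difference between the two regimes can be seen with a deliberately tiny stand-in model. Everything below is hypothetical: `predict_next` is a hand-written lookup table that only looks at the previous token, chosen so that its first prediction is wrong.

```python
# Hypothetical deterministic "model": maps the previous token to a predicted next token.
predict_next = {
    "the": "cat",
    "cat": "ran",
    "ran": "away",
    "dog": "sat",
    "sat": "down",
}

reference = ["the", "dog", "sat", "down"]


def teacher_forced(sequence):
    """Predict each next token from the TRUE previous token."""
    return [predict_next[prev] for prev in sequence[:-1]]


def free_running(first_token, steps):
    """Predict each next token from the model's OWN previous output."""
    out = [first_token]
    for _ in range(steps):
        out.append(predict_next[out[-1]])
    return out[1:]


targets = reference[1:]
forced = teacher_forced(reference)
rolled = free_running(reference[0], len(targets))

forced_errors = sum(p != t for p, t in zip(forced, targets))
rolled_errors = sum(p != t for p, t in zip(rolled, targets))
```

Under teacher forcing the single early mistake stays isolated, because the next step is conditioned on the dataset token again. In the free-running rollout the same mistake changes the context for every later step, which is the compounding effect described above.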

Sampling: the chain rule becomes an algorithm

Once you have $p_\theta(x_t \mid x_{<t})$, generation is just:

  1. Start with a prompt (maybe empty).
  2. Compute the next-token distribution.
  3. Sample (or take argmax).
  4. Append the token and repeat.
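The four steps above can be sketched as a short loop. The "model" here is a made-up bigram table (`BIGRAM`) that only looks at the last token, and `"<s>"` / `"</s>"` are assumed start and end markers; a real model would condition on the whole prefix.

```python
import random

# Hypothetical next-token model: a fixed table keyed on the last token only.
BIGRAM = {
    "<s>": {"a": 0.6, "b": 0.4},
    "a": {"a": 0.1, "b": 0.6, "</s>": 0.3},
    "b": {"a": 0.5, "b": 0.2, "</s>": 0.3},
}


def next_token_distribution(prefix):
    """Step 2: map a prefix to a distribution over the next token."""
    return BIGRAM[prefix[-1]]


def generate(prompt, max_new_tokens, rng, greedy=False):
    tokens = list(prompt)                          # step 1: start from the prompt
    for _ in range(max_new_tokens):
        dist = next_token_distribution(tokens)
        if greedy:
            token = max(dist, key=dist.get)        # step 3: argmax
        else:
            token = rng.choices(list(dist), weights=list(dist.values()))[0]  # step 3: sample
        tokens.append(token)                       # step 4: append and repeat
        if token == "</s>":
            break
    return tokens


greedy_out = generate(["<s>"], max_new_tokens=4, rng=random.Random(0), greedy=True)
sampled_out = generate(["<s>"], max_new_tokens=10, rng=random.Random(0))
```

In this sketch an "empty" prompt is represented by the start marker alone, so that the first lookup has something to condition on.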

Different decoding methods (greedy, beam search, top-$k$, nucleus/top-$p$, temperature) are just different ways to turn that conditional distribution into an actual token choice.
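As a rough illustration, three of these methods can be written as small transformations of the next-token distribution. The function names are hypothetical and the code works on plain lists rather than tensors; it is a sketch of the common definitions, not any particular library's implementation.

```python
import math


def apply_temperature(logits, temperature):
    """Softmax of logits / temperature; lower temperature sharpens the distribution."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    z = sum(exps)
    return [e / z for e in exps]


def _renormalize(probs, kept):
    """Zero out everything outside `kept` and rescale the rest to sum to 1."""
    z = sum(probs[i] for i in kept)
    return [probs[i] / z if i in kept else 0.0 for i in range(len(probs))]


def top_k_filter(probs, k):
    """Keep only the k most probable tokens."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return _renormalize(probs, set(order[:k]))


def top_p_filter(probs, p):
    """Keep the smallest high-probability set whose total mass reaches p (nucleus)."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = set(), 0.0
    for i in order:
        kept.add(i)
        mass += probs[i]
        if mass >= p:
            break
    return _renormalize(probs, kept)


probs = [0.5, 0.3, 0.15, 0.05]
after_top_k = top_k_filter(probs, k=2)
after_top_p = top_p_filter(probs, p=0.9)
sharper = apply_temperature([2.0, 1.0, 0.0], temperature=0.5)
```

After filtering, a token is drawn from the renormalized distribution; greedy decoding is the special case of keeping only the single most probable token.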

A practical view: log-probs add, probabilities multiply

Because of the product, probabilities can get tiny fast. In code, you almost always work with log-probabilities:

import math

# Example: p(x1) = 0.2, p(x2|x1) = 0.5, p(x3|x1,x2) = 0.1
probs = [0.2, 0.5, 0.1]

logp = sum(math.log(p) for p in probs)
p_joint = math.exp(logp)

print("log p(x1:x3):", logp)
print("p(x1:x3):", p_joint)

This mirrors what frameworks compute: sum of token-level log-probs (or mean loss), not a direct joint probability.
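To see why the log domain is more than a convenience, consider a longer sequence. The numbers below are made up (400 tokens, each with probability 0.01), but the effect is general: the direct product underflows double precision while the sum of logs stays perfectly representable.

```python
import math

# Hypothetical sequence: 400 tokens, each assigned probability 0.01.
probs = [0.01] * 400

# The true product is 1e-800, far below the smallest positive double,
# so multiplying probabilities directly underflows to zero.
p_product = math.prod(probs)

# The same quantity in the log domain is an ordinary float.
log_p = sum(math.log(p) for p in probs)

print("direct product:", p_product)
print("sum of log-probs:", log_p)
```

Once the product has collapsed to zero, its logarithm is undefined, so the information is lost; summing logs avoids the problem entirely.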

“Chain rule” also shows up in backprop (but it’s a different one)

People sometimes conflate two “chain rules”:

  • Probability chain rule: factorizes a joint distribution into conditionals.
  • Calculus chain rule: propagates gradients through composed functions.

Autoregressive modeling relies on the probability chain rule. Autoregressive training (like most deep learning) relies on the calculus chain rule during backpropagation.

They are conceptually distinct, but together they make the whole pipeline tractable:

  • the probability chain rule gives you a learnable, decomposable objective;
  • the calculus chain rule lets you optimize it with gradient descent.
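The two rules meet at the loss for a single position. For a softmax over logits followed by cross-entropy against a one-hot target, the calculus chain rule yields the well-known gradient "softmax minus one-hot". The sketch below (helper names are hypothetical) checks that closed form against a finite-difference estimate.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [e / z for e in exps]


def cross_entropy(logits, target):
    """NLL of the target token under softmax(logits)."""
    return -math.log(softmax(logits)[target])


def analytic_grad(logits, target):
    """d loss / d logits = softmax(logits) - one_hot(target)."""
    probs = softmax(logits)
    return [p - (1.0 if i == target else 0.0) for i, p in enumerate(probs)]


def numeric_grad(logits, target, eps=1e-6):
    """Central finite differences, one logit at a time."""
    grad = []
    for i in range(len(logits)):
        up = list(logits)
        down = list(logits)
        up[i] += eps
        down[i] -= eps
        grad.append((cross_entropy(up, target) - cross_entropy(down, target)) / (2 * eps))
    return grad


logits = [1.5, -0.3, 0.8, 0.0]
target = 2
exact = analytic_grad(logits, target)
approx = numeric_grad(logits, target)
max_gap = max(abs(a - b) for a, b in zip(exact, approx))
```

Because the sequence loss is a sum of such per-position terms, the gradient of the whole objective is just the sum of these per-position gradients, propagated back through the network.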

The mental model I keep

An autoregressive model is:

  • a choice of ordering (left-to-right for text);
  • the probability chain rule;
  • a conditional model class (transformer, RNN, etc.);
  • maximum likelihood training (cross-entropy over next-token predictions).

Everything else—prompting, decoding tricks, RLHF-style fine-tuning—sits on top of that foundation.

Perplexity: the common metric for AR language models

Because log-likelihood decomposes into token-level terms, we can define the average negative log-likelihood per token:

$$\text{NLL} = -\frac{1}{n}\sum_{t=1}^{n} \log p_\theta(x_t \mid x_{<t})$$

Perplexity is just the exponentiated average NLL (with the same log base convention):

$$\text{PPL} = \exp(\text{NLL})$$

Intuition:

  • Lower perplexity means the model assigns higher probability to the observed next tokens.
  • Perplexity can be read as an “effective branching factor”: a perplexity of $k$ means the model is, on average, as uncertain as if it were choosing uniformly among $k$ next tokens.
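A small sketch makes the branching-factor reading concrete. The function takes the probabilities a model assigned to the observed tokens (the values below are invented) and uses natural logs throughout.

```python
import math


def perplexity(token_probs):
    """exp of the average negative log-probability of the observed tokens."""
    nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(nll)


# A model that spreads mass uniformly over 50 tokens at every step
# has perplexity (approximately, up to rounding) 50.
uniform_ppl = perplexity([1 / 50] * 10)

# A model that is more confident about the observed tokens scores lower.
confident_ppl = perplexity([0.9, 0.8, 0.95, 0.7])
```

Note that the average is taken in log space, so perplexity is the inverse of the geometric mean of the token probabilities, not of their arithmetic mean.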

(When people report PPL, details matter: tokenization, log base, and whether the evaluation uses the same preprocessing as training.)

Tokens, not words: what is $x_t$ in practice?

In modern LMs, $x_t$ is almost never a whole word. It is typically a subword token from a vocabulary learned by BPE/Unigram.

That changes how you should read the chain rule:

  • The model factorizes probability over token sequences, not word sequences.
  • A single “word” may be 1 token or many tokens.
  • Reported metrics (loss/perplexity) are therefore tokenization-dependent.

Concretely, the same text can correspond to different $n$ (sequence length) under different tokenizers, which affects average loss and PPL.
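Here is a sketch of the effect with invented numbers. It makes a simplifying assumption that rarely holds exactly: two models assign the same total log-probability to a piece of text but use tokenizers that split it into different numbers of tokens.

```python
import math

# Assumption for illustration: both models assign the same total NLL (in nats)
# to the same text; only the token counts differ.
total_nll = 12.0
n_tokens_a = 6   # hypothetical tokenizer A: finer-grained, more tokens
n_tokens_b = 4   # hypothetical tokenizer B: coarser, fewer tokens


def per_token_ppl(total_nll, n_tokens):
    """Perplexity from a total NLL and the number of tokens it is averaged over."""
    return math.exp(total_nll / n_tokens)


ppl_a = per_token_ppl(total_nll, n_tokens_a)
ppl_b = per_token_ppl(total_nll, n_tokens_b)

print("tokenizer A per-token PPL:", ppl_a)
print("tokenizer B per-token PPL:", ppl_b)
```

The two models are equally good at predicting the text as a whole, yet the one with more tokens reports the lower per-token perplexity, which is why PPL comparisons across tokenizers need care.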

A diagram of the AR factorization + generation loop

  graph TD;
  A["Training text: x1..xn"] --> B["Shifted inputs: x1..x{n-1}"]
  B --> C["Model outputs: p(x_t | x_1..x_{t-1})"]
  C --> D["Cross-entropy vs target x_t"]
  D --> E["Sum/mean over t => loss"]

  F["Prompt: x1..xk"] --> G["p(x_{k+1} | x_1..x_k)"]
  G --> H["Decode: greedy / top-k / top-p / temp"]
  H --> I["Sample token x_{k+1}"]
  I --> F