On-Policy Distillation is Reinforcement Learning in Disguise

Jun 08, 2026

A mathematical deep dive into why knowledge distillation, done right, is secretly policy gradient training


There is a common assumption that knowledge distillation and reinforcement learning are fundamentally different training regimes. Distillation is "supervised" - you have a teacher, you minimize a divergence, you move on. RL is "exploratory" - you sample trajectories, collect rewards, update your policy. Clean separation. Different toolboxes.

This assumption is wrong, or at least much blurrier than it looks.

When you do knowledge distillation on-policy - meaning you sample outputs from the student's own current distribution rather than evaluating over a fixed dataset - you are, with one small architectural choice, running a policy gradient algorithm. The KL divergence between teacher and student stops being a supervised loss and starts being a dense, token-level reward signal. The student is not being taught. It is exploring, getting evaluated, and updating.

Recently I was reading an excellent blog post by Zhoutong and worked through the derivation myself. Afterwards I thought it would be useful to document all the steps, along with the doubts I had along the way. This post walks through the math carefully, building intuition at each step, so that by the end you can see the REINFORCE algorithm hiding inside the distillation objective.

1. The two flavours of distillation

Start with the simplest possible distillation setup. You have a large teacher model $p_T$ and a smaller student model $p_S^\theta$ that you want to train. The standard approach is to minimize the KL divergence between teacher and student over some fixed dataset of prompts $X$:

$$L_{\text{offline}}(\theta) = \mathbb{E}_{x \sim X}\Big[\mathcal{D}_{KL}\big(p_T(\cdot|x) \;\|\; p_S^\theta(\cdot|x)\big)\Big]$$

You collect prompts, run the teacher once to get soft targets, and train the student to match them. Static dataset. The student's output doesn't influence what gets trained on. This is off-policy distillation and it is entirely supervised - structurally no different from cross-entropy training with soft labels.
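The "no different from cross-entropy with soft labels" claim can be checked numerically at a single position. The sketch below uses two made-up distributions (`teacher` and `student` are illustrative values, not taken from any real model) and confirms that the KL equals the soft-label cross-entropy minus the teacher's entropy. Since the entropy term does not involve the student, both losses give the student the same gradient.

```python
import math

def kl(p, q):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def cross_entropy(p, q):
    """Cross-entropy of q under soft labels p."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

def entropy(p):
    """Shannon entropy of p (natural log)."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

# Hypothetical next-token distributions over a 3-token vocabulary.
teacher = [0.7, 0.2, 0.1]  # soft targets produced by the teacher
student = [0.5, 0.3, 0.2]  # current student prediction

kl_value = kl(teacher, student)
ce_minus_h = cross_entropy(teacher, student) - entropy(teacher)

# The two quantities agree up to floating-point error.
print(kl_value, ce_minus_h)
```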

The on-policy version makes one change that looks minor but is anything but:

$$L_{OD}(\theta) = \mathbb{E}_{x \sim X}\Big[\mathbb{E}_{y \sim p_S(\cdot|x)}\big[\mathcal{D}_{KL}(p_T \| p_S^\theta)(y|x)\big]\Big]$$

Instead of evaluating the KL over a fixed set of outputs, you sample outputs from the student's current policy and evaluate the KL along those sampled trajectories. The student generates text from itself, and those self-generated sequences form the training signal.

This is a small change in notation. It is a large change in what is actually happening during training.

2. Unpacking what the KL actually is here

Before going further, the notation $\mathcal{D}_{KL}(p_T \| p_S^\theta)(y|x)$ needs unpacking because it looks strange - KL divergence is usually between two distributions, not evaluated at a point $y$.

Language models are autoregressive: they generate a sequence one token at a time. At each position $t$ in the sequence, both the teacher and student produce a full probability distribution over the vocabulary $\mathcal{V}$ (all ~50,000 possible next tokens), conditioned on everything that came before. So the natural per-token KL at position $t$ is:

$$D_t(\theta) = \sum_{v \in \mathcal{V}} p_T(v \mid y_{<t}, x) \log \frac{p_T(v \mid y_{<t}, x)}{p_S^\theta(v \mid y_{<t}, x)}$$

This is a proper KL divergence - a number measuring how different the teacher and student distributions are at this specific point in the sequence.

The full sequence-level "KL" is just the sum of these per-token KLs along the trajectory:

$$\mathcal{D}_{KL}(p_T \| p_S^\theta)(y|x) = \sum_{t=1}^{T} D_t(\theta)$$

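This decomposition is not an approximation. It follows the chain rule for KL divergence: since $\log p(y) = \sum_t \log p(y_t \mid y_{<t})$, the KL inherits additive structure along the token positions. The one difference from the textbook sequence-level KL is that the prefixes $y_{<t}$ here come from a student sample rather than being averaged under the teacher, which is why the left-hand side is written as a function of $y$.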

So the full on-policy objective is:

$$L_{OD}(\theta) = \mathbb{E}_{x \sim X}\,\mathbb{E}_{y \sim p_S(\cdot|x)}\!\left[\sum_{t=1}^{T} D_t(\theta)\right]$$

So in short, we sample a trajectory from the student, and for each token position in that trajectory, we compute how far the student's next-token distribution is from the teacher's. We sum those distances and minimize. On-policy. Dense signal. Every position in every sampled trajectory contributes.
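To make "sample a trajectory, then sum per-token KLs along it" concrete, here is a minimal sketch on a toy problem. Everything in it is an assumption made for illustration: a 3-token vocabulary, a sequence length of 4, and hypothetical teacher and student "models" whose next-token logits depend only on the previous token (`TEACHER_LOGITS`, `STUDENT_LOGITS`). A real language model would condition on the whole prefix and the prompt.

```python
import math
import random

VOCAB = 3  # toy vocabulary size (assumption)
T = 4      # toy sequence length (assumption)

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def kl(p, q):
    """KL(p || q) for two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical toy models: logits keyed by the previous token (-1 = start).
TEACHER_LOGITS = {-1: [2.0, 0.0, -1.0], 0: [0.0, 2.0, 0.0],
                  1: [0.0, 0.0, 2.0], 2: [1.0, 1.0, 0.0]}
STUDENT_LOGITS = {-1: [1.0, 0.5, 0.0], 0: [0.5, 1.0, 0.0],
                  1: [0.0, 0.5, 1.0], 2: [0.0, 0.0, 0.0]}

def next_dist(logits_table, prefix):
    """Next-token distribution given the tokens generated so far."""
    prev = prefix[-1] if prefix else -1
    return softmax(logits_table[prev])

def sample_trajectory(rng, length):
    """Draw y ~ p_S token by token (the on-policy step)."""
    y = []
    for _ in range(length):
        probs = next_dist(STUDENT_LOGITS, y)
        y.append(rng.choices(range(VOCAB), weights=probs)[0])
    return y

def per_token_kl(y):
    """D_t for t = 1..T, each evaluated at the prefix y_{<t} the student visited."""
    return [kl(next_dist(TEACHER_LOGITS, y[:t]), next_dist(STUDENT_LOGITS, y[:t]))
            for t in range(len(y))]

rng = random.Random(0)
y = sample_trajectory(rng, T)
d = per_token_kl(y)
total = sum(d)  # the trajectory-level cost, sum_t D_t
print(y, d, total)
```

Averaging `total` over many sampled trajectories (and prompts) would give a Monte Carlo estimate of $L_{OD}$ for this toy setup.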

3. The stop-gradient - the architectural choice that changes everything

Here is the critical decision: do not backpropagate through the sampling distribution $p_S(\cdot|x)$.

The student model appears in two places in the objective:

  • As the sampler - $p_S(\cdot|x)$ generates the trajectory $y$
  • As the trainable distribution - $p_S^\theta$ appears inside the KL

The stop-gradient says: treat the first role as a black box. Samples are drawn from the student, but no gradient flows back through the act of sampling. Only the second role - the $p_S^\theta$ inside the KL - receives gradients.

In PyTorch terms: y = student.sample(x).detach(), then compute the KL using $\theta$ as normal.

This looks like a minor implementation detail. It is actually what makes the method mathematically equivalent to a policy gradient algorithm. To see why, we need to look at what happens with and without this choice.

4. The gradient, computed explicitly

Write the objective in its expanded form, treating $p_S(\cdot|x)$ as the sampling distribution and $D(y;\theta) = \sum_t D_t(\theta)$ as the loss evaluated at trajectory $y$:

$$L_{OD}(\theta) = \sum_x p(x) \sum_y p_S(y|x)\, D(y;\theta)$$

Now take the gradient. If we did not apply stop-gradient - treating $p_S(\cdot|x)$ as $p_S^\theta(\cdot|x)$ throughout - the product rule gives two terms:

$$\nabla_\theta L = \sum_x p(x)\sum_y \Big[\nabla_\theta p_S^\theta(y|x) \cdot D(y;\theta) \;+\; p_S^\theta(y|x) \cdot \nabla_\theta D(y;\theta)\Big]$$

Apply the log-derivative trick $\nabla_\theta p_S^\theta = p_S^\theta \nabla_\theta \log p_S^\theta$ to the first term and fold both back into expectations:

$$\nabla_\theta L = \underbrace{\mathbb{E}_{y \sim p_S^\theta}\!\Big[D(y;\theta)\,\nabla_\theta \log p_S^\theta(y|x)\Big]}_{\text{(A) score-function / REINFORCE term}} + \underbrace{\mathbb{E}_{y \sim p_S^\theta}\!\Big[\nabla_\theta D(y;\theta)\Big]}_{\text{(B) pathwise / direct term}}$$

Look at term (A) very carefully. It has exactly the form of the REINFORCE policy gradient:

$$\mathbb{E}_{y \sim \pi_\theta}\!\big[R(y) \cdot \nabla_\theta \log \pi_\theta(y)\big]$$

with the identification $R(y) = D(y;\theta)$. The KL cost of the trajectory is acting as the reward. Term (A) says: make trajectories with high KL cost less likely, and trajectories with low KL cost more likely.

Term (B) is the "supervised" piece - direct gradient through the KL formula with respect to $\theta$.

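This decomposition makes the student's dual role explicit: term (A) comes from the student as the sampler, and term (B) comes from the student as the trainable distribution inside the KL.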

Now apply stop-gradient. Treating $p_S(\cdot|x)$ as having no $\theta$-dependence kills term (A). Only term (B) survives:

$$\nabla_\theta L_{OD}(\theta) = \mathbb{E}_{x}\,\mathbb{E}_{y \sim p_S(\cdot|x)}\!\big[\nabla_\theta D(y;\theta)\big]$$
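The split into terms (A) and (B) can be sanity-checked numerically on a tiny problem. The sketch below is an illustration under stated assumptions, not the setup of any real system: trajectories have two tokens over a 3-token vocabulary, the only parameter being differentiated is the first start-of-sequence logit `a[0]`, and the second-step student logits (`B_LOGITS`) and teacher tables (`P_T1`, `P_T2`) are made-up constants. It compares finite-difference gradients of the objective with and without stop-gradient against the closed-form terms.

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical teacher: first-token distribution and second-token rows.
P_T1 = [0.6, 0.3, 0.1]
P_T2 = {0: [0.8, 0.1, 0.1], 1: [0.1, 0.8, 0.1], 2: [0.1, 0.1, 0.8]}
# Hypothetical student second-step logits, held fixed in this sketch.
B_LOGITS = {0: [2.0, 0.0, 0.0], 1: [0.0, 0.0, 0.0], 2: [2.0, 0.0, 0.0]}

def cost(y1, a):
    """D(y; theta) = D_1 + D_2, where D_2 depends on the sampled first token."""
    d1 = kl(P_T1, softmax(a))
    d2 = kl(P_T2[y1], softmax(B_LOGITS[y1]))
    return d1 + d2

def objective(a_sampler, a_loss):
    """Exact expectation over y1; separate args let us freeze the sampler."""
    p = softmax(a_sampler)
    return sum(p[y1] * cost(y1, a_loss) for y1 in range(3))

def d_da0(f, a, eps=1e-6):
    """Central finite difference with respect to the first logit."""
    up = [a[0] + eps, a[1], a[2]]
    dn = [a[0] - eps, a[1], a[2]]
    return (f(up) - f(dn)) / (2 * eps)

a = [0.2, -0.1, 0.4]
p = softmax(a)

full = d_da0(lambda v: objective(v, v), a)            # no stop-gradient
with_stop_grad = d_da0(lambda v: objective(a, v), a)  # sampler frozen at a

# Term (A): E[D(y) * d log p_S(y1) / d a0], using d log p(y1)/d a0 = 1[y1=0] - p[0].
term_a = sum(p[y1] * cost(y1, a) * ((1.0 if y1 == 0 else 0.0) - p[0])
             for y1 in range(3))
# Term (B): d D_1 / d a0 for a softmax student, which is p_S(0) - p_T(0).
term_b = p[0] - P_T1[0]

print(full, term_a + term_b, with_stop_grad, term_b)
```

In this toy case the full gradient matches (A) + (B), while freezing the sampler leaves only (B), which is the behaviour the stop-gradient is meant to produce.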

Let's expand this by computing $\nabla_\theta D_t(\theta)$ per token. The teacher-only term $\sum_v p_T \log p_T$ (the teacher's negative entropy) has no $\theta$, so:

$$\nabla_\theta D_t(\theta) = -\sum_{v \in \mathcal{V}} p_T(v \mid y_{<t}, x)\,\nabla_\theta \log p_S^\theta(v \mid y_{<t}, x)$$

Substituting back:

$$\boxed{\;\nabla_\theta L_{OD}(\theta) = -\,\mathbb{E}_{x}\,\mathbb{E}_{y \sim p_S}\!\left[\sum_{t=1}^{T} \sum_{v \in \mathcal{V}} p_T(v \mid y_{<t}, x)\,\nabla_\theta \log p_S^\theta(v \mid y_{<t}, x)\right]\;}$$

This is the gradient you actually compute during on-policy distillation.
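As a worked special case, assume the student's next-token distribution at the visited state is a softmax over logits $z$ (an assumption of this sketch; the derivation above does not depend on it). The conditioning on $y_{<t}, x$ is dropped for brevity. The per-token gradient then collapses to a difference of probabilities:

```latex
% Assumption: p_S^\theta(v) = \mathrm{softmax}(z)_v, with z the student logits at this state.
\begin{aligned}
\frac{\partial \log p_S^\theta(v)}{\partial z_k}
  &= \mathbb{1}[v = k] - p_S^\theta(k) \\
\frac{\partial D_t}{\partial z_k}
  &= -\sum_{v \in \mathcal{V}} p_T(v)\,\big(\mathbb{1}[v = k] - p_S^\theta(k)\big) \\
  &= p_S^\theta(k) - p_T(k)
\end{aligned}
```

This is the familiar softmax cross-entropy gradient with the teacher distribution as a soft label, evaluated at a state the student chose to visit.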

5. Comparing to REINFORCE side by side

The sequential REINFORCE gradient for a language model is:

$$\nabla_\theta L_{RL} = -\,\mathbb{E}_{x}\,\mathbb{E}_{y \sim \pi_\theta}\!\left[\sum_{t=1}^{T} R_t \cdot \nabla_\theta \log \pi_\theta(y_t \mid y_{<t}, x)\right]$$

A quick note on the sum over $t$: this is not added by hand. The log-probability of a trajectory factorizes as $\log \pi_\theta(y|x) = \sum_t \log \pi_\theta(y_t \mid y_{<t}, x)$, so differentiating the trajectory log-prob naturally produces a sum over positions. In single-step problems there's no sum; sequences make it unavoidable.

Now put both gradients side by side:

$$\nabla_\theta L_{RL} = -\,\mathbb{E}_{y \sim \pi_\theta}\!\left[\sum_t \underbrace{R_t}_{\text{scalar reward}} \cdot \nabla_\theta \log \pi_\theta(\underbrace{y_t}_{\text{1 sampled token}} \mid y_{<t})\right]$$

$$\nabla_\theta L_{OD} = -\,\mathbb{E}_{y \sim p_S}\!\left[\sum_t \sum_v \underbrace{p_T(v \mid y_{<t})}_{\text{soft weight}} \cdot \nabla_\theta \log p_S^\theta(\underbrace{v}_{\text{all vocab tokens}} \mid y_{<t})\right]$$

The inner $\sum_v$ is a closed-form expectation over the action space, with the teacher probabilities as weights. REINFORCE estimates its per-step expectation with a single Monte Carlo sample. OD computes the corresponding sum analytically, because the teacher distribution is a known function at any point. The inner sum is not a second temporal dimension - it is variance reduction at each step.
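The "analytic sum versus single sample" contrast can be illustrated at one state. The sketch below reads the inner sum as an expectation over tokens drawn from the teacher, which is an interpretive assumption of this example (textbook REINFORCE draws actions from the policy itself). It also assumes a softmax student, so the score $\nabla_z \log p_S(v)$ with respect to the logits is `onehot(v) - p_S`. The distributions are made-up values.

```python
import math
import random

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

# Hypothetical distributions at a single visited state.
teacher = [0.6, 0.3, 0.1]
student = softmax([0.0, 0.5, 1.0])
K = len(teacher)

def score(v):
    """Gradient of log p_S(v) w.r.t. the student logits: onehot(v) - p_S."""
    return [(1.0 if k == v else 0.0) - student[k] for k in range(K)]

# Analytic inner sum: sum_v p_T(v) * grad log p_S(v). It simplifies to p_T - p_S;
# the OD gradient carries a leading minus sign, giving p_S - p_T.
analytic = [sum(teacher[v] * score(v)[k] for v in range(K)) for k in range(K)]

# Monte Carlo version: draw one token at a time and average the scores.
rng = random.Random(0)
n = 20000
mc = [0.0] * K
for _ in range(n):
    v = rng.choices(range(K), weights=teacher)[0]
    g = score(v)
    mc = [m + gi / n for m, gi in zip(mc, g)]

print(analytic)
print(mc)
```

Any single draw of `score(v)` is a noisy version of `analytic`; only the average converges to it. Computing the sum directly removes that per-step noise.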

The structural skeleton is identical:

| | REINFORCE | On-Policy Distillation |
|---|---|---|
| Trajectories | sampled from current policy | sampled from current policy |
| Signal at each step | scalar reward $R_t$ | KL divergence $D_t$ at that state |
| Per-step gradient | $R_t \cdot \nabla \log \pi(y_t \mid s_t)$ | $\sum_v p_T(v \mid s_t) \cdot \nabla \log p_S^\theta(v \mid s_t)$ |
| Action coverage | one sampled token | full vocabulary, analytically |
| Reward source | environment | teacher model |

6. The KL divergence as a reward signal

The claim "KL acts as a reward" requires a bit more care because, as you will have noticed, $D_t$ doesn't appear as a multiplier in the final gradient formula. Where is it?

The key is that $D_t$ is the objective being minimized. The gradient $\nabla_\theta D_t$ is the signal that reduces $D_t$ at visited states. These are two descriptions of the same thing.

Think about it in RL terms. Define the per-token reward as:

$$r_t \;=\; -\,\mathcal{D}_{KL}\!\big(p_T(\cdot \mid y_{<t}, x) \;\|\; p_S^\theta(\cdot \mid y_{<t}, x)\big) \;=\; -D_t$$

The student earns a negative reward - a penalty - at every token position proportional to how far its distribution is from the teacher's. This reward has three properties that make it ideal:

It is dense. Every single token position in every sampled trajectory produces a real-valued signal. There is no waiting for end-of-sequence; there is no sparsity problem. Compare to human preference RL (like RLHF) where the reward is one number per full sequence.

It is grounded. The reward function is not arbitrary or approximated by a learned reward model. It is exactly the KL to the teacher - a principled, analytic measure of how well the student matches the expert.

It scales with mismatch. Where the student already closely matches the teacher, $D_t \approx 0$ and the gradient $\nabla_\theta D_t \approx 0$: no update needed, none applied. Where the student diverges significantly, $D_t$ is large, the gradient is large, and the update is correspondingly strong. The method automatically focuses capacity on the states where the student needs the most work.

These properties - density, groundedness, adaptive weighting - are precisely what makes a reward signal good for RL training. The KL divergence, evaluated on-policy, provides all three.
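The "scales with mismatch" property is easy to see at a single state. The sketch below assumes a softmax student, so the gradient of $D_t$ with respect to the student logits is $p_S - p_T$. It moves a hypothetical student away from a hypothetical teacher by an amount `alpha` (`teacher_logits` and `offset` are made-up values) and reports the KL together with the L1 norm of that logit gradient.

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical teacher logits and a direction that pushes the student away.
teacher_logits = [1.0, 0.0, -1.0]
offset = [-1.0, 0.0, 1.0]
p_t = softmax(teacher_logits)

rows = []
for alpha in [0.0, 0.5, 1.0, 2.0]:
    # alpha = 0 reproduces the teacher exactly; larger alpha means more mismatch.
    p_s = softmax([t + alpha * o for t, o in zip(teacher_logits, offset)])
    d_t = kl(p_t, p_s)
    # Logit-space gradient of D_t is p_S - p_T (softmax assumption).
    grad_l1 = sum(abs(ps - pt) for ps, pt in zip(p_s, p_t))
    rows.append((alpha, d_t, grad_l1))

for alpha, d_t, grad_l1 in rows:
    print(f"alpha={alpha:.1f}  D_t={d_t:.4f}  |grad|_1={grad_l1:.4f}")
```

For this particular path, both the penalty and the size of the update grow together as the student drifts from the teacher, and both vanish when the two agree.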

Another thing worth noting is that in standard RL, the environment reward $R$ is a black-box scalar that carries no gradient with respect to the policy parameters, so effectively $\nabla_\theta R = 0$. This forces RL to rely entirely on the high-variance REINFORCE estimator (Term A from Section 4).

In On-Policy Distillation, our "reward" ($D_t$) is fully differentiable. By applying stop-gradient to the sampling distribution, we intentionally discard the REINFORCE term and instead backpropagate directly through the loss (Term B). Therefore, $D_t$ acts as a conceptual reward, providing dense, grounded credit assignment, but mathematically we bypass the REINFORCE estimator entirely.

We compute the exact analytic gradient at each visited state (which reduces to soft-label cross-entropy), so there is no variance from sampling actions; the only remaining randomness is in which trajectories the student happens to sample. The gradient is evaluated strictly on the states the student actually visits. You get the exploratory benefits of RL with the optimization stability of supervised learning.

7. The objectives are the same shape

There is one final confusion worth addressing directly. RL is often written compactly as "maximize $\mathbb{E}[R]$," which looks nothing like the nested double-expectation in the OD objective. This is a notational mismatch, not a conceptual one.

In sequential RL, the "reward" of a trajectory is a sum:

$$R(\tau) = \sum_{t=1}^{T} r_t$$

So the RL objective, fully expanded for a prompt-conditioned language model, is:

$$J_{RL}(\theta) = \mathbb{E}_{x \sim X}\,\mathbb{E}_{y \sim \pi_\theta(\cdot|x)}\!\left[\sum_{t=1}^{T} r_t\right]$$

The OD objective:

$$L_{OD}(\theta) = \mathbb{E}_{x \sim X}\,\mathbb{E}_{y \sim p_S(\cdot|x)}\!\left[\sum_{t=1}^{T} D_t\right]$$

These are the same structure. Outer expectation over prompts. Inner expectation over student trajectories. Sum of per-step signals along those trajectories. Setting $r_t = -D_t$ and flipping the sign makes them identical. The compact "$\mathbb{E}[R]$" notation hides the sum-over-time and the outer prompt expectation - both of which are present in real sequential RL, and both of which are present in OD.

8. Conclusion

The argument reduces to this:

  1. The on-policy distillation objective samples trajectories from the student and evaluates a per-token KL cost along each trajectory. This is structurally identical to the RL objective with $r_t = -D_t$ as the per-token reward.
  2. Without stop-gradient, the full gradient decomposes into a REINFORCE term - where $D_t$ explicitly multiplies $\nabla \log p_S^\theta$ - plus a direct pathwise term.
  3. With stop-gradient, only the pathwise term remains. This term, when expanded per-token, is the analytic expectation of the REINFORCE per-step update over the full vocabulary, weighted by teacher probabilities. It is REINFORCE with the Monte Carlo approximation replaced by an exact computation.
  4. The KL divergence plays the role of the reward: dense, grounded, and proportional to how much the student needs to change at each state it visits.

The stop-gradient is the architectural decision that makes this connection precise. Without it, you have a mix of supervised and RL gradients. With it, you have pure policy gradient training where the teacher provides the reward function.

This matters practically. On-policy distillation inherits the benefits of RL training - exploration of the student's own failure modes, credit assignment along self-generated trajectories, adaptive weighting toward difficult states - while retaining the stability of a well-defined, smooth reward function. It also inherits the intuitions of RL: the student is not being spoon-fed correct outputs; it is discovering, through its own exploration, where it falls short of the teacher and learning to close that gap.

When you run on-policy distillation, you are running reinforcement learning. The teacher is the environment. The KL is the reward. The student is the policy. The labels, for once, are optional.

Note: All ideas, math, and technical content are mine, though I have used Claude & GPT for drafting, paragraph polishing, and proofreading.

9. References

  1. Geoffrey Hinton, Oriol Vinyals, Jeff Dean. "Distilling the Knowledge in a Neural Network." arXiv preprint arXiv:1503.02531, 2015.
  2. Yoon Kim, Alexander M. Rush. "Sequence-Level Knowledge Distillation." EMNLP 2016.
  3. Rishabh Agarwal et al. "On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes." arXiv:2306.13649, 2023.
  4. Ronald J. Williams. "Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning." Machine Learning, 1992.
  5. Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour. "Policy Gradient Methods for Reinforcement Learning with Function Approximation." NeurIPS 1999.
  6. Itai Shenfeld et al. "Self-Distillation Enables Continual Learning." arXiv:2601.19897, 2026.
  7. Siyan Zhao et al. "Self-Distilled Reasoner: On-Policy Self-Distillation for Large Language Models." arXiv:2601.18734, 2026.
  8. Zhoutong Zhang. "On-Policy Distillation: Learning from Self-Generated Mistakes." blog post.
