NeurIPS 2026 · preprint

DiLaDiff: Distilled Latent-Augmented Diffusion for Language Modeling

Jean-Marie Lemercier Tomas Geffner Morteza Mardani Karsten Kreis Arash Vahdat Ante Jukić

NVIDIA

tl;dr

Diffusion language models can't decode many tokens in parallel because they assume token independence. DiLaDiff fixes this with a continuous latent that captures token correlations, learned by a latent diffusion model and then self-distilled into a 5-step generative model. The result: 7× wall-time speed-up over masked diffusion baselines at higher generation quality, with the latent computed in negligible time compared to discrete decoding.

DiLaDiff architecture: an encoder produces a latent z, a latent denoiser models its prior with continuous diffusion (Ncont steps), and a discrete decoder samples tokens conditioned on z over Ndisc masked-diffusion steps. A self-distillation branch trains a MeanFlow student that replaces the teacher's instantaneous velocity in few-step sampling.
Figure 1. Hybrid continuous–discrete diffusion with a self-distilled latent. The latent space is crafted with encoder Eϕ and decoder xθ, then learned a posteriori with a diffusion process whose denoiser is zψ. The latent trajectories are further self-distilled with a MeanFlow student uη(zτ, τ, r).

Abstract

Diffusion language models intrinsically fail to capture correlations between decoded tokens, which leads to a harsh trade-off between sampling quality and throughput. To solve this issue, we propose DiLaDiff, a variant of masked diffusion language models with three components: (1) a continuous latent space with semantic capabilities, learned by an auto-encoder fine-tuned from an existing masked diffusion language model; (2) a latent diffusion model learning the prior over the encoder distribution; (3) a consistency model distilling the learned prior into a few-step latent generative model. We show that, even without distillation, our latent-guided diffusion model outperforms the masked diffusion baseline while significantly accelerating inference. Consistency distillation further lowers the computational overhead of continuous diffusion, such that the latent is generated in negligible time compared to discrete decoding.

Why latents fix diffusion LMs

The problem · diffusion LMs ignore token correlations at decoding time

Auto-regressive language models are trained with teacher forcing and recover an exact factorization of the joint likelihood — token by token, left to right. Diffusion language models, in contrast, are trained with an ELBO that gives a factorized denoising posterior: each token is predicted independently from the rest of the noisy context. That factorization is exact only when you sample one token at a time. The moment the model is asked to decode several tokens in parallel, those tokens are treated as conditionally independent — but in natural language they are not.

A denoising Transformer fills two masked positions in a sentence: 'NVIDIA sells several mostly linen items'. Two equations on the right show that the true joint distribution p_true(x0, x4 | xt) is not equal to the product of marginals, which is what the dLM factorization uses.
The factorization gap. Sampling two masked tokens in parallel from independent marginals produces nonsense like "NVIDIA … linen items" because the joint distribution is not separable. This mismatch is exactly the gap between the true joint \(p_{\text{true}}(x_0, x_4 \mid x_t)\) and the dLM factorization \(p_{\text{true}}(x_0 \mid x_t) \cdot p_{\text{true}}(x_4 \mid x_t)\).

The practical consequence is a brutal speed–quality trade-off: pushing more tokens per forward pass tanks generation quality.
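The gap can be made concrete with a two-position toy. The sketch below is illustrative only: the vocabulary and probabilities are hypothetical, not taken from the paper. It compares a correlated joint with the product of its marginals, which is what decoding both positions in one pass samples from.

```python
import numpy as np

# Hypothetical toy vocabulary for two masked positions (illustration only).
first = ["New", "Los"]
second = ["York", "Angeles"]

# Assumed true joint p(x_a, x_b | context): only the coherent pairs have mass.
joint = np.array([[0.5, 0.0],
                  [0.0, 0.5]])

# Per-position marginals, which is all a factorized denoiser can represent.
p_first = joint.sum(axis=1)
p_second = joint.sum(axis=0)

# Decoding both positions in one pass samples from the product of marginals.
product = np.outer(p_first, p_second)

# Probability mass that parallel decoding puts on pairs the joint forbids.
invalid_mass = product[joint == 0].sum()

for i, a in enumerate(first):
    for j, b in enumerate(second):
        print(f"{a} {b}: joint={joint[i, j]:.2f} product={product[i, j]:.2f}")
print("mass on incoherent pairs:", invalid_mass)
```

In this toy, half of the parallel samples are incoherent pairs such as "New Angeles", even though each marginal is individually correct.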

The fix · condition on a latent that absorbs the correlations

If we hand the denoiser an extra channel \(\mathbf{z}\) that already captures the joint structure of the clean sequence, then the factorized posterior becomes exact in expectation over \(\mathbf{z}\).

  1. The conditional denoiser factorizes. Conditional on the latent \(\mathbf{z}\), the clean-data denoiser truly factorizes across tokens — all token correlations have been absorbed into \(\mathbf{z}\): \[ q^{\theta}(\mathbf{x}_0 \mid \mathbf{z}) \;=\; \prod_{\ell} q^{\theta}(\mathbf{x}_0^{\ell} \mid \mathbf{z}). \]
  2. The forward kernel already factorizes. The masking process is independent across tokens (whether a token is masked is decided context-free), so factorization is automatic — no latent required: \[ q(\mathbf{x}_t \mid \mathbf{x}_s) \;=\; \prod_{\ell} q(\mathbf{x}_t^{\ell} \mid \mathbf{x}_s^{\ell}). \]
  3. By Bayes, the reverse posterior factorizes too. Combining the two and exploiting the forward kernel's Markov property yields a fully factorized reverse posterior: \[ q^{\theta}_{s\mid t}(\mathbf{x}_s \mid \mathbf{x}_t, \mathbf{z}) \;=\; \prod_{\ell} q^{\theta}_{s\mid t}(\mathbf{x}_s^{\ell} \mid \mathbf{x}_t^{\ell}, \mathbf{z}). \]

The conditioning on \(\mathbf{z}\) in step 3 is the whole game: we now need an auto-encoder that learns a useful \(\mathbf{z}\), a generative model for sampling \(\mathbf{z}\) from scratch at inference, and a way to make that sampling fast enough not to eat the parallel-decoding gains. The next three stages each tackle one of those problems.
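A minimal sketch of why conditioning helps, using a hypothetical two-valued discrete latent in place of the continuous \(\mathbf{z}\) (all names and probabilities are assumptions for illustration). Given the latent, each token is sampled independently, yet the marginal over the latent is a fully correlated joint.

```python
import numpy as np

# Hypothetical discrete latent with two values, one per coherent token pair.
p_z = np.array([0.5, 0.5])

# Per-token conditionals given z; rows index z, columns index the token id.
p_first_given_z = np.array([[1.0, 0.0],
                            [0.0, 1.0]])
p_second_given_z = np.array([[1.0, 0.0],
                             [0.0, 1.0]])

# Conditional on z, the two-token joint is a product of per-token terms.
joint_given_z = np.einsum("zi,zj->zij", p_first_given_z, p_second_given_z)

# Marginalizing z yields a joint that is not a product of its marginals.
joint = np.einsum("z,zij->ij", p_z, joint_given_z)

# Parallel decoding: draw z once, then sample both tokens independently.
rng = np.random.default_rng(0)
z = rng.choice(2, size=1000, p=p_z)
a = np.array([rng.choice(2, p=p_first_given_z[k]) for k in z])
b = np.array([rng.choice(2, p=p_second_given_z[k]) for k in z])
incoherent_rate = np.mean(a != b)
print(joint)
print("incoherent pairs:", incoherent_rate)
```

The toy is deliberately extreme: the latent fully determines both tokens. A learned latent would only need to carry enough information to make the remaining per-token uncertainty approximately independent.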

Method walkthrough

Stage 1 — Auto-encoding text

We craft the latent space by fine-tuning an auto-encoder on top of pre-trained contextual embeddings, with the decoder warm-started from an existing masked diffusion LM. The auto-encoder is built incrementally, in the four steps below.

Frozen contextual embedding (BERT) feeds a compressor that outputs latents z. Data augmentation injected on the embeddings and on the latents to regularize the latent space. A stack of self-attention layers added as the decoder, initialized from a pretrained masked diffusion language model. Cross-attention layers wrap the decoder so it is guided by the latent z, completing the auto-encoder.
1 · compress contextual embeddings

Pre-trained contextual embeddings (frozen BERT) are passed through a Perceiver-style compressor that produces a sequence of latents z shorter than the token sequence.
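As a rough sketch of what "Perceiver-style" compression means here: a fixed set of learned queries cross-attends to the token embeddings, so the output length is set by the number of queries rather than by the sequence. The sizes, the single attention head, and the random weights below are assumptions for illustration, not the paper's configuration.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def compress(embeddings, queries, w_k, w_v):
    """Cross-attend a fixed set of learned queries to the token embeddings."""
    keys = embeddings @ w_k
    values = embeddings @ w_v
    scores = queries @ keys.T / np.sqrt(queries.shape[-1])
    return softmax(scores, axis=-1) @ values

rng = np.random.default_rng(0)
seq_len, d_model, n_latents = 128, 64, 16  # hypothetical sizes

# Random stand-in for the frozen contextual embeddings of one sequence.
embeddings = rng.normal(size=(seq_len, d_model))

# In a real model these would be trained parameters.
queries = rng.normal(size=(n_latents, d_model))
w_k = rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)
w_v = rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)

z = compress(embeddings, queries, w_k, w_v)
print(z.shape)
```

Here 128 token embeddings are summarized into 16 latent vectors, regardless of the input length.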

2 · regularize with data augmentation

Without regularization the latent space is too sparse for a generative model to learn. We inject Gaussian noise on the embeddings, and random masking on both embeddings and latents, to push the auto-encoder toward a smoother, diffusion-friendly representation.
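The two augmentations can be sketched as follows. The noise scale, the masking probability, and the choice to mask by zeroing whole vectors are all assumptions made for illustration; the page does not specify them.

```python
import numpy as np

def augment(x, rng, noise_std=0.0, mask_prob=0.0):
    """Add Gaussian noise, then zero out whole rows with probability mask_prob."""
    x = x + noise_std * rng.normal(size=x.shape)
    keep = rng.random(x.shape[0]) >= mask_prob
    return x * keep[:, None]

rng = np.random.default_rng(0)

# Random stand-ins for contextual embeddings and compressor outputs.
embeddings = rng.normal(size=(128, 64))
latents = rng.normal(size=(16, 64))

# Embeddings receive both noise and masking (hypothetical rates).
noisy_embeddings = augment(embeddings, rng, noise_std=0.1, mask_prob=0.15)

# Latents receive masking only.
masked_latents = augment(latents, rng, mask_prob=0.15)

print("masked latent rows:", int((masked_latents == 0).all(axis=1).sum()))
```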

3 · reuse a pretrained dLM as decoder

Rather than training a decoder from scratch, we initialize the self-attention stack from a pretrained masked diffusion language model. This gives us robust token-space decoding for free and keeps the latent's job narrow: just capture global context.

4 · guide decoding with cross-attention

Two zero-initialized cross-attention layers wrap the decoder so its hidden state can attend to the latent z. With latent dropout in training, the model degrades gracefully to the original dLM behaviour when no latent is provided.
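A minimal sketch of the two properties described above, assuming that "zero-initialized" means the output projection of each cross-attention layer starts at zero (the page does not say which weights are zeroed). At initialization the block is then an identity on the decoder state, and dropping the latent skips the block entirely. The function names and the dropout rate are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def maybe_drop(z, rng, drop_prob):
    """Latent dropout: return None in place of the latent with probability drop_prob."""
    return None if rng.random() < drop_prob else z

def cross_attention_block(h, z, w_q, w_k, w_v, w_o):
    """Residual cross-attention from decoder states h to latents z."""
    if z is None:  # latent dropped: behave like the unconditional decoder
        return h
    q, k, v = h @ w_q, z @ w_k, z @ w_v
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))
    return h + attn @ v @ w_o

rng = np.random.default_rng(0)
d = 64
h = rng.normal(size=(128, d))  # stand-in decoder hidden states
z = rng.normal(size=(16, d))   # stand-in latents
w_q, w_k, w_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
w_o = np.zeros((d, d))         # assumed zero-initialized output projection

out_init = cross_attention_block(h, z, w_q, w_k, w_v, w_o)
out_dropped = cross_attention_block(h, maybe_drop(z, rng, 1.0), w_q, w_k, w_v, w_o)
print(np.allclose(out_init, h), out_dropped is h)
```

Starting from an identity means fine-tuning begins exactly at the pretrained dLM's behaviour, and the latent's influence grows only as the output projection is learned.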

Stage 2 — Modelling the latent prior (LaDiff)

With the auto-encoder frozen, we learn the prior p(z) with a standard continuous diffusion model trained on encoder samples. At inference, we run the probability-flow ODE in the latent space, then hand the resulting z to the discrete decoder. We call this hybrid model LaDiff.

A frozen encoder produces a latent z from the clean sequence x. The latent z is corrupted by a forward diffusion kernel into a noisy latent z_tau. A learnable latent denoiser z_psi reconstructs z from z_tau under an LDM loss.
1 · sample a latent from the frozen encoder

The clean sequence x is encoded into a latent z by the frozen Stage-1 encoder Eϕ.

2 · diffuse with the forward kernel

A noise level τ is sampled, and z is perturbed to zτ with a standard Gaussian forward kernel — exactly the standard continuous diffusion setup, just in the latent space.

3 · train the latent denoiser

The latent denoiser zψ(zτ, τ) is trained to reconstruct the clean latent with an LDM L2 loss. At inference we solve the reverse probability-flow ODE to draw a fresh z, then hand it off to the discrete decoder for masked-diffusion sampling.
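The three steps can be sketched on a one-dimensional toy. Everything here is an assumption for illustration: the "latent" distribution is a Gaussian, the forward kernel is a linear interpolation between data and noise, and the trained denoiser is replaced by its closed-form optimum for this toy, so no network is trained.

```python
import numpy as np

M, S = 2.0, 0.5  # hypothetical 1-D "latent" distribution N(M, S^2)

def forward_kernel(z, tau, eps):
    """Assumed interpolation path: z_tau = (1 - tau) * z + tau * eps."""
    return (1.0 - tau) * z + tau * eps

def denoiser(z_tau, tau):
    """Closed-form E[z | z_tau] for this toy, standing in for a trained network."""
    var = (1.0 - tau) ** 2 * S**2 + tau**2
    return M + (1.0 - tau) * S**2 / var * (z_tau - (1.0 - tau) * M)

def denoising_loss(z, tau, eps):
    """L2 reconstruction loss that a learnable denoiser would minimize."""
    return np.mean((denoiser(forward_kernel(z, tau, eps), tau) - z) ** 2)

def sample(n_samples, n_steps, rng):
    """Euler solve of the probability-flow ODE from tau=1 (noise) to tau=0."""
    z = rng.normal(size=n_samples)
    taus = np.linspace(1.0, 0.0, n_steps + 1)
    for tau, tau_next in zip(taus[:-1], taus[1:]):
        velocity = (z - denoiser(z, tau)) / tau  # dz/dtau under the assumed path
        z = z + (tau_next - tau) * velocity
    return z

rng = np.random.default_rng(0)
z0 = M + S * rng.normal(size=10_000)
loss = denoising_loss(z0, rng.uniform(size=10_000), rng.normal(size=10_000))
samples = sample(20_000, 200, rng)
print("loss:", loss)
print("sample mean/std:", samples.mean(), samples.std())
```

With 200 Euler steps the samples closely match the target mean and spread, which is the many-step regime the teacher operates in.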

Stage 3 — Self-distilling latent trajectories (DiLaDiff)

At 200 NFEs the latent diffusion alone eats most of the wall-time budget — defeating the speed-up. We distill the LaDiff teacher into a few-step student via MeanFlow. The construction has three sub-steps: identify what's wrong with one-step velocity sampling, fix it with an average-velocity correction, and train the correction by self-distillation.

3.1 · limitations of one-step ODE sampling

The latent denoiser computes an instantaneous velocity, the tangent to the ODE trajectory. Following that tangent for an Euler step is fine when the step is small or the trajectory is straight, but it falls apart when the step is large and the trajectory is curved.

A straight blue arrow representing the teacher's instantaneous velocity v_psi(z_t, t). The same velocity arrow with a small step shown to land on the ODE curve correctly. A long step along a low-curvature ODE trajectory still lands close to the target. A long step along a highly-curved ODE trajectory misses the target (red X).
1 · the teacher's instantaneous velocity

The trained denoiser gives us vψ(zt, t), the tangent to the ODE path at the current point. One Euler step follows that tangent.

2 · small timesteps work fine

For a small enough step Δt, the tangent approximates the chord well and Euler sampling lands close to the true trajectory. This is the many-step regime.

3 · low ODE curvature also works

If the ODE path is nearly straight, even a large step along the tangent stays close to the trajectory. This is what teachers with straight probability-flow ODEs exploit.

4 · neither holds in few-step regimes

In the general few-step ODE sampling case, the steps are big and the trajectory is curved. The tangent misses the target by a wide margin — the red ✗. The single-step Euler velocity is no longer the right object to use.
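The failure can be reproduced on a one-dimensional toy whose probability-flow ODE is known in closed form. The Gaussian target and the linear interpolation path are assumptions chosen for illustration, not the paper's setup. The sketch runs Euler sampling with the exact instantaneous velocity at 1, 5, and 200 steps and reports the spread of the resulting samples.

```python
import numpy as np

M, S = 2.0, 0.5  # hypothetical target distribution N(M, S^2)

def velocity(z, t):
    """Exact instantaneous velocity for the path z_t = (1 - t) * z_0 + t * eps."""
    var = (1.0 - t) ** 2 * S**2 + t**2
    return -M + (t - (1.0 - t) * S**2) / var * (z - (1.0 - t) * M)

def euler_sample(noise, n_steps):
    """Follow the tangent from t=1 (noise) to t=0 in n_steps equal steps."""
    z = noise.copy()
    ts = np.linspace(1.0, 0.0, n_steps + 1)
    for t, t_next in zip(ts[:-1], ts[1:]):
        z = z + (t_next - t) * velocity(z, t)
    return z

# The same initial noise is reused so that only the step count differs.
noise = np.random.default_rng(0).normal(size=20_000)
spread = {n: euler_sample(noise, n).std() for n in (1, 5, 200)}

for n, s in spread.items():
    print(f"{n:>3} steps: sample std = {s:.3f} (target {S})")
```

In this toy a single tangent step sends every sample to the mean, so all diversity is lost, and 5 steps still under-shoot the target spread. Only the many-step solve recovers it.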

3.2 · the MeanFlow correction

Instead of the instantaneous velocity, the student parameterizes the average velocity \(\mathbf{u}_\eta(\mathbf{z}_t, t, r)\) over an interval \([r, t]\), which is the right object for a big step. MeanFlow corrects the teacher's tangent with the total time derivative of the student's average velocity along the trajectory, and that derivative is obtained by self-distillation: a Jacobian-vector product through the student, held under stop-gradient. The training target reads

\[ \mathbf{u}_{\text{tgt}} \;=\; \mathbf{v}_\psi(\mathbf{z}_t, t) \;-\; (t - r)\,\big(\mathbf{v}_\psi(\mathbf{z}_t, t)\,\partial_{\mathbf{z}}\mathbf{u}_\eta \,+\, \partial_t \mathbf{u}_\eta\big), \]

and the student is fit with a simple regression loss

\[ \mathcal{L}_{\text{MeanFlow}}(\eta) \;=\; \mathbb{E}\;\big\|\,\mathbf{u}_\eta(\mathbf{z}_t, t, r) \;-\; \operatorname{stopgrad}(\mathbf{u}_{\text{tgt}})\,\big\|_2^2. \]

The result is a single network that can be plugged in place of the teacher's velocity for big ODE solver steps.
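The target can be sanity-checked on a one-dimensional toy where the exact average velocity is available in closed form. The Gaussian target and the linear interpolation path are assumptions for illustration, and central differences stand in for the Jacobian-vector product. The sketch verifies that the exact average velocity is a fixed point of the target formula, then uses it for one-step sampling.

```python
import numpy as np

M, S = 2.0, 0.5  # hypothetical target distribution N(M, S^2)

def mu(t):
    return (1.0 - t) * M

def sigma(t):
    return np.sqrt((1.0 - t) ** 2 * S**2 + t**2)

def velocity(z, t):
    """Instantaneous velocity of the toy probability-flow ODE."""
    return -M + (t - (1.0 - t) * S**2) / sigma(t) ** 2 * (z - mu(t))

def flow(z_t, t, r):
    """Exact ODE solution carried from time t back to time r."""
    return mu(r) + sigma(r) / sigma(t) * (z_t - mu(t))

def avg_velocity(z_t, t, r):
    """Average velocity over [r, t]: displacement divided by elapsed time."""
    return (z_t - flow(z_t, t, r)) / (t - r)

def meanflow_target(z_t, t, r, h=1e-5):
    """v - (t - r) * (v * du/dz + du/dt), with central-difference derivatives."""
    v = velocity(z_t, t)
    du_dz = (avg_velocity(z_t + h, t, r) - avg_velocity(z_t - h, t, r)) / (2 * h)
    du_dt = (avg_velocity(z_t, t + h, r) - avg_velocity(z_t, t - h, r)) / (2 * h)
    return v - (t - r) * (v * du_dz + du_dt)

# The exact average velocity should reproduce its own target.
z_t, t, r = 1.3, 0.8, 0.2
residual = abs(avg_velocity(z_t, t, r) - meanflow_target(z_t, t, r))
print("identity residual:", residual)

# One-step sampling: z_0 = z_1 - (1 - 0) * u(z_1, 1, 0).
noise = np.random.default_rng(0).normal(size=20_000)
one_step = noise - (1.0 - 0.0) * avg_velocity(noise, 1.0, 0.0)
print("one-step mean/std:", one_step.mean(), one_step.std())
```

A trained student only approximates this average velocity, so in practice a handful of steps is used rather than one.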

3.3 · training the student by self-distillation

Putting it together, DiLaDiff is trained on the same encoder + forward-diffusion setup as LaDiff, with the MeanFlow loss matching the student's average velocity against the teacher-derived target.

A frozen encoder produces a latent z. Forward diffusion to noise level t produces z_t from z. The frozen teacher v_psi computes the instantaneous velocity at (z_t, t). The trainable student u_eta produces an average velocity; the MeanFlow loss compares it to the teacher target.
1 · sample a latent from the frozen encoder

Exactly as in Stage 2: a clean sequence x goes through the frozen encoder Eϕ to give z.

2 · pick two noise levels and diffuse to t

We pick a target noise level t and a smaller one r, then perturb z to zt with the same forward kernel as LaDiff.
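One way to draw the pair is sketched below. It assumes both levels come from a logit-normal distribution (the LogitNormal(-1, 1) schedule named in the results) and are then ordered so that \(r \le t\). The sampling rule, the interpolation kernel, and the function names are assumptions for illustration.

```python
import numpy as np

def sample_noise_levels(n, rng, mean=-1.0, std=1.0):
    """Draw two logit-normal values per example and order them so r <= t."""
    raw = 1.0 / (1.0 + np.exp(-(mean + std * rng.normal(size=(n, 2)))))
    return raw.max(axis=1), raw.min(axis=1)

def perturb(z, t, rng):
    """Assumed interpolation kernel: z_t = (1 - t) * z + t * eps."""
    eps = rng.normal(size=z.shape)
    return (1.0 - t)[:, None] * z + t[:, None] * eps

rng = np.random.default_rng(0)
z = rng.normal(size=(8, 16))  # stand-in for a batch of encoder latents
t, r = sample_noise_levels(8, rng)
z_t = perturb(z, t, rng)

print(np.round(t, 3))
print(np.round(r, 3))
```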

3 · query the frozen teacher

The frozen LaDiff teacher vψ(zt, t) provides the instantaneous velocity needed to assemble the MeanFlow target.

4 · train the student against the target

The student uη(zt, t, r) is updated with the MeanFlow loss against the teacher-derived target. After 25k steps the student matches the teacher with Ncont = 5 instead of 200.

Results

Latent space analysis

Before measuring downstream generation quality, we check that the auto-encoder's latent actually does what we asked of it: capture sentence-level semantics. We perturb a latent with Gaussian noise and decode it in 8 steps, then measure semantic distance against the source.

BERTScore-F1 between sentences decoded from the same latent drops monotonically as we add noise — both against the clean source at \(t = 0\) and against the ground-truth sentence. The latent carries semantics inherited from the contextual embedding.

Semantic distance vs noise level: monotonic curves showing BERTScore-F1 drift as the latent is corrupted.
Semantic distance. BERTScore-F1 between decoded sentences and either the clean source (\(t=0\)) or the ground truth, as a function of the latent's noise level.

Decomposing further: intra-pool sentences (decoded from the same latent with different seeds) stay close in BERTScore, while inter-pool sentences (different latent, same seed) score sharply lower. The latent captures sentence-level meaning while the decoder contributes word-level diversity.

Two plots: the semantic-distance curve and a bar chart showing intra-pool BERTScore stays high while inter-pool drops sharply.
Intra- vs inter-pool similarity. Same-latent samples remain semantically close; different-latent samples diverge sharply.
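The intra- versus inter-pool comparison can be sketched as below. Two loud assumptions: a crude token-overlap F1 stands in for BERTScore-F1, and the sentences are hand-written placeholders, not model outputs. Only the pooling logic is the point.

```python
from itertools import combinations

def overlap_f1(a, b):
    """Token-overlap F1, a crude stand-in for BERTScore-F1."""
    ta, tb = set(a.lower().split()), set(b.lower().split())
    common = len(ta & tb)
    if common == 0:
        return 0.0
    precision, recall = common / len(tb), common / len(ta)
    return 2 * precision * recall / (precision + recall)

def pool_similarity(pools, sim=overlap_f1):
    """pools[i][j] is the sentence decoded from latent i with seed j."""
    # Intra-pool: same latent, different seeds.
    intra = [sim(a, b) for pool in pools for a, b in combinations(pool, 2)]
    # Inter-pool: different latents, same seed.
    inter = [sim(p[j], q[j])
             for p, q in combinations(pools, 2)
             for j in range(len(p))]
    return sum(intra) / len(intra), sum(inter) / len(inter)

# Hand-written placeholder sentences (not model outputs).
pools = [
    ["the cat sat on the mat", "the cat sat on a mat"],
    ["stocks fell sharply on monday", "stocks fell sharply this monday"],
]
intra, inter = pool_similarity(pools)
print("intra-pool:", intra, "inter-pool:", inter)
```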

LaDiff (no distillation)

Even before distillation, conditioning the masked-diffusion decoder on a learned latent pays off. The undistilled LaDiff sits on a strictly better speed–quality Pareto frontier than MDLM, preserves diversity at low temperatures, and lets us exploit a confidence-based token selection that MDLM cannot use reliably.

LaDiff dominates the speed–quality Pareto frontier. At equal MAUVE and GenPPL, LaDiff is roughly 7× faster than MDLM. At equal throughput (16 tokens per forward pass) it is also strictly better in quality.

Speed-quality Pareto: LaDiff dominates MDLM in both MAUVE (higher is better) and GenPPL (lower is better) across throughput regimes.
Speed–quality Pareto frontier. LaDiff (green) vs MDLM (orange). Higher MAUVE and lower GenPPL at the same throughput.

Diversity is captured in the latent. Lowering the decoder temperature \(\zeta\) normally collapses MDLM samples — entropy and MAUVE fall off a cliff. LaDiff stays put because the latent has already absorbed sentence-level diversity, freeing the decoder temperature to do token-level sharpening.

Temperature sampling: LaDiff preserves MAUVE and entropy as temperature drops, while MDLM collapses.
Temperature sampling. MAUVE, GenPPL and Entropy as a function of the decoder temperature \(\zeta\), with \(N_{\text{cont}} = 200\).

Confidence-based token selection is unlocked. Choosing which masked positions to fill in by the decoder's own confidence is unreliable for MDLM at small scales (the dashed confidence curves collapse). With LaDiff, confidence-based selection matches or beats random selection across every \(N_{\text{disc}}\).

Confidence-based token selection: LaDiff with confidence sampling matches its argmax curve while MDLM with confidence collapses to a low MAUVE plateau.
Confidence-based vs random token selection, \(N_{\text{cont}} = 200\). LaDiff (blue) survives confidence sampling; MDLM (green, dashed) does not.
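The two selection rules differ only in how the positions to unmask are chosen at each discrete step. The sketch below uses random stand-in probabilities and a greedy token fill for brevity. The function names, the `-1` mask sentinel, and the use of top-1 probability as the confidence score are assumptions for illustration.

```python
import numpy as np

def select_positions(probs, masked, k, rng, strategy="confidence"):
    """Pick k masked positions to unmask in this step."""
    candidates = np.flatnonzero(masked)
    if strategy == "random":
        return rng.choice(candidates, size=k, replace=False)
    confidence = probs[candidates].max(axis=-1)  # top-1 probability per position
    return candidates[np.argsort(-confidence)[:k]]

def unmask_step(tokens, probs, masked, k, rng, strategy="confidence"):
    """Fill the chosen positions and mark them as no longer masked."""
    chosen = select_positions(probs, masked, k, rng, strategy)
    tokens, masked = tokens.copy(), masked.copy()
    tokens[chosen] = probs[chosen].argmax(axis=-1)  # greedy fill for brevity
    masked[chosen] = False
    return tokens, masked

rng = np.random.default_rng(0)
seq_len, vocab = 8, 5
probs = rng.dirichlet(np.ones(vocab), size=seq_len)  # stand-in decoder output
tokens = np.full(seq_len, -1)                        # -1 marks a masked position
masked = np.ones(seq_len, dtype=bool)

tokens, masked = unmask_step(tokens, probs, masked, k=3, rng=rng)
print(tokens, masked)
```

Passing `strategy="random"` gives the baseline that ignores the decoder's confidence.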

DiLaDiff (with self-distillation)

Self-distilling the latent denoiser removes the continuous-diffusion bottleneck. With only 5 latent steps, DiLaDiff stays on the same Pareto frontier as LaDiff at 200 steps — the latent is essentially free.

DiLaDiff matches the teacher with 40× fewer latent steps. With \(N_{\text{cont}} = 5\) latent diffusion steps, DiLaDiff hits the same Pareto frontier as the LaDiff teacher with \(N_{\text{cont}} = 200\). Latent diffusion now accounts for ~5% of the wall-time — effectively free.

Speed-quality Pareto with DiLaDiff curves overlaid: DiLaDiff matches LaDiff teacher quality at much higher throughput.
Pareto with distillation. DiLaDiff (purple, \(N_{\text{cont}}=5\)) sits on the LaDiff teacher's frontier (green, \(N_{\text{cont}}=200\)) and far above MDLM.

MeanFlow and Terminal Velocity Matching agree. The MeanFlow correction and Terminal Velocity Matching give nearly identical curves across MAUVE, GenPPL, and Entropy as \(N_{\text{cont}}\) sweeps from 5 to 200. Both clearly outperform the dashed teacher curve at the small-\(N_{\text{cont}}\) end.

Three plots (MAUVE, GenPPL, Entropy) comparing MeanFlow and Terminal Velocity Matching distillation across Ncont; both match the teacher with only 5 steps.
MeanFlow vs Terminal Velocity Matching. Two noise schedules each (LogitNormal(-1,1) and (-1,2)) compared against the teacher across the latent NFE sweep.

Citation

@article{lemercier2026diladiff,
  title   = {DiLaDiff: Distilled Latent-Augmented Diffusion for Language Modeling},
  author  = {Lemercier, Jean-Marie and Geffner, Tomas and Mardani, Morteza
             and Kreis, Karsten and Vahdat, Arash and Juki{\'c}, Ante},
  journal = {arXiv preprint},
  year    = {2026}
}