Consistency Models: A New Paradigm for 1-Step Generation

Single-step generation without iterative sampling: OpenAI's approach built on the self-consistency property.

TL;DR

  • Consistency Models: Map all points on the same trajectory to the same output
  • Self-Consistency: $f(x_t, t) = f(x_{t'}, t')$ for all $t, t'$ on same trajectory
  • Two Training Methods: Consistency Distillation (requires teacher) vs Consistency Training (no teacher)
  • Result: High-quality 1-step generation, with optional multi-step for better quality

1. Why Consistency Models?

The Fundamental Limitation of Diffusion

Diffusion models require iterative sampling:

python
z ~ N(0,I) → x_T → x_{T-1} → ... → x_1 → x_0

No matter how much the sampler is optimized, many network evaluations remain:

  • DDPM: 1000 steps
  • DDIM: 50-100 steps
  • DPM-Solver: 10-20 steps

Is 1-step impossible?

Problems with Existing Approaches

| Method | Problem |
|---|---|
| Progressive Distillation | Multiple distillation stages needed |
| Rectified Flow | Multiple reflow iterations needed |
| Direct 1-step training | Severe quality degradation |

The Consistency Models Idea

Key observation:

All points on an ODE trajectory converge to the **same data point**

Therefore:

Learn a function that outputs the **same result** regardless of the starting point on the trajectory!

2. Self-Consistency Property

Definition

A consistency function $f: (x_t, t) \to x_0$ satisfies:

$$f(x_t, t) = f(x_{t'}, t') \quad \forall\, t, t' \in [0, T]$$

when $x_t$ and $x_{t'}$ are on the same ODE trajectory.

Intuitive Understanding

python
Noise                                    Data
  z ─────●─────●─────●─────●─────> x_0
         ↓     ↓     ↓     ↓
        f()   f()   f()   f()
         ↓     ↓     ↓     ↓
         └─────┴─────┴─────┘
              All same x_0

Following the ODE leads to the same $x_0$, so predicting $x_0$ directly from any intermediate point should be possible.

Boundary Condition

At $t = 0$, $f$ must reduce to the identity:

$$f(x_0, 0) = x_0$$

If the input is already clean data, it is returned unchanged.

3. Consistency Model Architecture

Basic Structure

The model is parameterized so the boundary condition holds by construction:

$$f_\theta(x, t) = c_{\text{skip}}(t)\, x + c_{\text{out}}(t)\, F_\theta(x, t)$$

Where:

  • $F_\theta$: Neural network (U-Net, DiT, etc.)
  • $c_{\text{skip}}(t)$, $c_{\text{out}}(t)$: Time-dependent weights

Skip Connection Design

To satisfy $f(x, 0) = x$:

$$c_{\text{skip}}(0) = 1, \quad c_{\text{out}}(0) = 0$$

Common choice:

$$c_{\text{skip}}(t) = \frac{\sigma_{\text{data}}^2}{\sigma_{\text{data}}^2 + t^2}$$

$$c_{\text{out}}(t) = \frac{t \cdot \sigma_{\text{data}}}{\sqrt{\sigma_{\text{data}}^2 + t^2}}$$
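
A quick numerical check of these coefficients (a minimal sketch; $\sigma_{\text{data}} = 0.5$ follows the EDM convention and matches the implementation in Section 7):

python
import math

sigma_data = 0.5  # EDM convention; also used in the ConsistencyModel class below

def c_skip(t):
    return sigma_data**2 / (sigma_data**2 + t**2)

def c_out(t):
    return t * sigma_data / math.sqrt(sigma_data**2 + t**2)

print(c_skip(0.0), c_out(0.0))    # 1.0 0.0  -> f(x, 0) = x (boundary condition)
print(c_skip(80.0), c_out(80.0))  # ~0  ~0.5 -> network output dominates at high noise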

Time Embedding

For numerical stability at small $t$, the time input to the network is transformed:

$$t' = \frac{1}{4} \log(t + 1)$$
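
A one-line helper matching this transform (illustrative only; the constant and the $+1$ offset come straight from the formula above):

python
import torch

def scale_time(t: torch.Tensor) -> torch.Tensor:
    # Compress the time/noise scale before feeding it to the network's time embedding
    return 0.25 * torch.log(t + 1.0)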

4. Consistency Distillation (CD)

Concept

Use a pre-trained diffusion model as teacher:

  1. Generate ODE trajectory with teacher
  2. Train consistency model to map different points on trajectory to same output

Algorithm

python
def consistency_distillation_loss(model, teacher, x0):
    # Sample time
    t = sample_timestep()
    t_next = t - delta_t  # one step earlier

    # Add noise to get x_t
    noise = torch.randn_like(x0)
    x_t = add_noise(x0, t, noise)

    # Teacher takes one ODE step: x_t -> x_{t_next}
    with torch.no_grad():
        x_t_next = teacher_ode_step(teacher, x_t, t, t_next)

    # Consistency loss: f(x_t, t) should equal f(x_{t_next}, t_next)
    pred_t = model(x_t, t)
    pred_t_next = model(x_t_next, t_next).detach()  # stop gradient (in practice an EMA target network is used; see below)

    return F.mse_loss(pred_t, pred_t_next)

Target Network (EMA)

For stable training, use an EMA target network:

$$\theta^- \leftarrow \mu \theta^- + (1-\mu)\, \theta$$

  • $\theta$: Training model
  • $\theta^-$: EMA target (stop gradient)
  • $\mu$: Decay rate (e.g., 0.999)

Loss Function

$$\mathcal{L}_{\text{CD}} = \mathbb{E}\left[\, d\big(f_\theta(x_{t_n}, t_n),\; f_{\theta^-}(x_{t_{n-1}}, t_{n-1})\big) \right]$$

Where $d$ is a distance metric (L2, LPIPS, etc.).
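
A small sketch of what $d$ might look like in code (the options and the pseudo-Huber constant `c` are illustrative; squared L2 is what the simplified losses in this post use, while LPIPS requires a pretrained perceptual network):

python
import torch
import torch.nn.functional as F

def distance(pred, target, kind="l2", c=0.03):
    if kind == "l2":
        return F.mse_loss(pred, target)
    if kind == "pseudo_huber":
        # Smooth, outlier-robust alternative used in later consistency-model work
        return (torch.sqrt((pred - target) ** 2 + c ** 2) - c).mean()
    raise ValueError(f"unknown distance: {kind}")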

5. Consistency Training (CT)

Learning Without a Teacher

Consistency Distillation requires a teacher model. But we can also learn without a teacher!

Key Idea

Instead of solving the ODE exactly, enforce consistency across infinitesimally close time steps:

$$\lim_{\Delta t \to 0} f(x_{t+\Delta t}, t+\Delta t) = f(x_t, t)$$

Algorithm

python
def consistency_training_loss(model, x0):
    # Sample time
    t = sample_timestep()
    t_next = t - delta_t

    # Add noise
    noise = torch.randn_like(x0)
    x_t = add_noise(x0, t, noise)
    x_t_next = add_noise(x0, t_next, noise)  # same noise!

    # Consistency loss
    pred_t = model(x_t, t)
    pred_t_next = model(x_t_next, t_next).detach()

    return F.mse_loss(pred_t, pred_t_next)

Key difference: instead of a teacher ODE step, the same data point is noised at two adjacent times with the same noise sample.

Why Does It Work?

When $\Delta t \to 0$:

$$x_{t+\Delta t} \approx x_t + \text{small perturbation}$$

In expectation, this perturbation points along the probability-flow ODE direction, so enforcing consistency over infinitesimal steps implies consistency along the entire trajectory.
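
Concretely, with the shared-noise construction from the code above ($x_t = x_0 + t\,\varepsilon$ and $x_{t-\Delta t} = x_0 + (t-\Delta t)\,\varepsilon$):

$$x_t - x_{t-\Delta t} = \Delta t \cdot \varepsilon = \Delta t \cdot \frac{x_t - x_0}{t}$$

which is a single-sample estimate of the probability-flow ODE step $\Delta t \cdot \frac{x_t - \mathbb{E}[x_0 \mid x_t]}{t}$ that the teacher would compute in Consistency Distillation; this is why no teacher is needed.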

CD vs CT Comparison

| Property | Consistency Distillation | Consistency Training |
|---|---|---|
| Teacher Model | Required | Not required |
| Training Difficulty | Easier | Harder |
| Final Quality | Higher | Slightly lower |
| Flexibility | Depends on teacher | Independent |

6. Sampling

1-Step Sampling

The simplest method:

python
def sample_one_step(model, z, T):
    # z ~ N(0, T^2 I): noise drawn at the maximum noise level T
    # A single forward pass predicts x_0 directly
    return model(z, T)

Done! Generation without iteration.
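
For example (batch shape and $T = 80$ are illustrative; $T$ should match the maximum noise level the model was trained with):

python
z = torch.randn(16, 3, 32, 32) * 80.0   # sample from the prior N(0, T^2 I)
x0 = sample_one_step(model, z, T=80.0)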

Multi-Step Sampling (Quality Improvement)

For higher quality:

python
def sample_multi_step(model, z, timesteps):
    """
    z: initial noise ~ N(0, T^2 I)
    timesteps: [T, t_1, t_2, ..., 0] (decreasing)
    """
    x = z

    for i in range(len(timesteps) - 1):
        t = timesteps[i]
        t_next = timesteps[i + 1]

        # Denoise to x_0
        x_0 = model(x, t)

        # Add noise back to t_next (if not last step)
        if t_next > 0:
            noise = torch.randn_like(x)
            x = add_noise(x_0, t_next, noise)

    return x_0

Principle:

  1. Predict $x_0$ from current $x_t$
  2. Add noise back to get $x_{t'}$
  3. Repeat

This alternates denoising and noise injection for quality improvement.
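
A usage sketch (the noise levels are illustrative and follow the EDM convention with $T = 80$):

python
z = torch.randn(16, 3, 32, 32) * 80.0                          # prior sample
sample = sample_multi_step(model, z, [80.0, 20.0, 5.0, 0.0])   # 3 model evaluations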

7. Implementation

Consistency Model Class

python
import torch
import torch.nn as nn

class ConsistencyModel(nn.Module):
    def __init__(self, network, sigma_data=0.5):
        super().__init__()
        self.network = network
        self.sigma_data = sigma_data

    def c_skip(self, t):
        return self.sigma_data**2 / (t**2 + self.sigma_data**2)

    def c_out(self, t):
        return t * self.sigma_data / torch.sqrt(t**2 + self.sigma_data**2)

    def forward(self, x, t):
        # Skip connection for boundary condition
        c_skip = self.c_skip(t)
        c_out = self.c_out(t)

        if c_skip.dim() == 1:
            c_skip = c_skip[:, None, None, None]
            c_out = c_out[:, None, None, None]

        F_x = self.network(x, t)

        return c_skip * x + c_out * F_x
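
A quick shape check with a stand-in network (`TinyNet` is a placeholder for the real U-Net/DiT backbone):

python
class TinyNet(nn.Module):
    def __init__(self, ch=3):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, x, t):
        return self.conv(x)  # a real backbone would also condition on t

model = ConsistencyModel(TinyNet())
x = torch.randn(4, 3, 32, 32)
t = torch.full((4,), 80.0)
print(model(x, t).shape)  # torch.Size([4, 3, 32, 32])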

Consistency Distillation Training

python
import copy

import torch
import torch.nn.functional as F

class ConsistencyDistillation:
    def __init__(self, model, teacher, ema_decay=0.999):
        self.model = model
        self.teacher = teacher
        self.target_model = copy.deepcopy(model)
        self.ema_decay = ema_decay

    def ode_step(self, x, t, t_next):
        """One Euler step of the teacher's probability-flow ODE.

        Assumes the teacher network returns dx/dt directly
        (e.g. (x - D(x, t)) / t in the EDM parameterization).
        """
        with torch.no_grad():
            dx_dt = self.teacher(x, t)
            dt = (t_next - t)[:, None, None, None]  # broadcast over image dims
            x_next = x + dx_dt * dt
        return x_next

    def loss(self, x0):
        B = x0.shape[0]

        # Sample timesteps
        t = torch.rand(B, device=x0.device) * (T - eps) + eps
        t_next = t - delta_t
        t_next = t_next.clamp(min=eps)

        # Forward diffusion
        noise = torch.randn_like(x0)
        x_t = x0 + t[:, None, None, None] * noise

        # Teacher ODE step
        x_t_next = self.ode_step(x_t, t, t_next)

        # Consistency loss: target comes from the EMA network (no gradient)
        pred = self.model(x_t, t)
        with torch.no_grad():
            target = self.target_model(x_t_next, t_next)

        return F.mse_loss(pred, target)

    def update_target(self):
        """EMA update of target network."""
        with torch.no_grad():
            for p, p_target in zip(self.model.parameters(),
                                   self.target_model.parameters()):
                p_target.data.mul_(self.ema_decay).add_(
                    p.data, alpha=1 - self.ema_decay)
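
A minimal training-loop sketch tying the pieces together (`dataloader`, the optimizer settings, and the globals `T`, `eps`, `delta_t` are assumptions, not values from the paper):

python
cd = ConsistencyDistillation(model, teacher)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for x0 in dataloader:
    loss = cd.loss(x0)     # consistency loss against the EMA target
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cd.update_target()     # EMA update after every optimizer step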

Consistency Training

python
class ConsistencyTraining:
    def __init__(self, model, ema_decay=0.999):
        self.model = model
        self.target_model = copy.deepcopy(model)
        self.ema_decay = ema_decay

    def loss(self, x0):
        B = x0.shape[0]

        # Sample timesteps
        t = torch.rand(B, device=x0.device) * (T - eps) + eps
        t_next = t - delta_t
        t_next = t_next.clamp(min=eps)

        # Same noise for both timesteps!
        noise = torch.randn_like(x0)
        x_t = x0 + t[:, None, None, None] * noise
        x_t_next = x0 + t_next[:, None, None, None] * noise

        # Consistency loss (target from the EMA network, no gradient)
        pred = self.model(x_t, t)
        with torch.no_grad():
            target = self.target_model(x_t_next, t_next)

        return F.mse_loss(pred, target)
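
The EMA update of `target_model` is identical to `update_target` in the distillation trainer above; only the way $x_{t_{n-1}}$ is constructed differs.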

8. Improved Consistency Training (iCT)

Problems with Original CT

  • Unstable early in training
  • Error accumulation with large $\Delta t$
  • Slow convergence

Improvements

  1. Adaptive $\Delta t$: Decrease $\Delta t$ during training
  2. Improved noise schedule: EDM-style noise schedule
  3. Better loss weighting: Time-dependent weighting
python
def adaptive_delta_t(step, total_steps):
    """Delta t decreases during training."""
    progress = step / total_steps
    return delta_t_max * (1 - progress) + delta_t_min * progress
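
For the time-dependent loss weighting, one choice reported in the improved consistency training work is a weight inversely proportional to the step size, $\lambda = 1/(t - t_{\text{next}})$. A sketch (the function name and the clamp are mine):

python
def loss_weight(t, t_next):
    # Smaller steps get larger weight so fine noise levels are not drowned out
    return 1.0 / (t - t_next).clamp(min=1e-8)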

9. Experimental Results

CIFAR-10 FID

| Model | NFE | FID |
|---|---|---|
| DDPM | 1000 | 3.17 |
| DDIM | 50 | 4.67 |
| Progressive Distillation | 1 | 9.12 |
| Consistency Distillation | 1 | 3.55 |
| Consistency Training | 1 | 5.83 |

ImageNet 64x64

| Model | NFE | FID |
|---|---|---|
| ADM | 250 | 2.07 |
| Consistency Distillation | 1 | 4.70 |
| Consistency Distillation | 2 | 2.93 |

Key Findings

  1. 1-step CD outperforms existing distillation methods
  2. 2-step significantly improves quality
  3. CT slightly lower than CD but requires no teacher

10. Latent Consistency Models (LCM)

Application to Stable Diffusion

Train Consistency Models in latent space:

python
# Encode to latent
z = vae.encode(image)

# Train consistency model in latent space
z_0_pred = consistency_model(z_t, t)

# Decode for visualization
image_pred = vae.decode(z_0_pred)

LCM Achievements

  • 4 steps to match Stable Diffusion quality
  • 5-10x speedup compared to original
  • Compatible with CFG

LCM-LoRA

Efficient training with LoRA:

python
from diffusers import StableDiffusionPipeline, LCMScheduler

# Base SD model + LCM-LoRA
pipe = StableDiffusionPipeline.from_pretrained("...")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)  # LCM-LoRA is used with the LCM scheduler
pipe.load_lora_weights("lcm-lora-sdv1-5")

# Fast generation: few steps, low guidance scale (typical for LCM-LoRA)
image = pipe(prompt, num_inference_steps=4, guidance_scale=1.5).images[0]

Conclusion

| Method | Steps | Teacher Required | Quality |
|---|---|---|---|
| DDPM | 1000 | - | High |
| DDIM | 50 | - | High |
| Progressive Distill | 1-4 | Yes | Medium |
| Consistency Distill | 1-2 | Yes | High |
| Consistency Training | 1-2 | No | Medium-High |
| LCM | 4 | Yes | High |

Key to Consistency Models:

  • Leverage self-consistency property
  • Predict endpoint instead of learning ODE trajectory directly
  • Enable 1-step generation while allowing multi-step for quality improvement

References

  1. Song, Y., et al. "Consistency Models" (ICML 2023)
  2. Song, Y., Dhariwal, P. "Improved Techniques for Training Consistency Models" (2023)
  3. Luo, S., et al. "Latent Consistency Models" (2023)
  4. Karras, T., et al. "Elucidating the Design Space of Diffusion-Based Generative Models" (NeurIPS 2022)
