CFG-free Distillation: Fast Generation Without Guidance

Eliminating the 2x computational cost of Classifier-Free Guidance. Achieving CFG quality with a single forward pass.

TL;DR

  • Problem: CFG requires two forward passes (conditional + unconditional) = 2x cost
  • Solution: Distill CFG effect into a single model
  • Method: Student mimics Teacher's CFG output
  • Result: Same quality, half the computation, faster inference

1. Classifier-Free Guidance Review

What is CFG?

Classifier-Free Guidance (CFG) is the standard technique for improving conditional generation quality in diffusion models (a minimal code sketch follows the definitions below):

$$\tilde{\epsilon}(x_t, c) = \epsilon(x_t, \varnothing) + w \cdot \left(\epsilon(x_t, c) - \epsilon(x_t, \varnothing)\right)$$

  • $\epsilon(x_t, c)$: Conditional prediction
  • $\epsilon(x_t, \varnothing)$: Unconditional prediction
  • $w$: Guidance scale (typically 7.5)
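
To make the two-pass cost concrete, here is a minimal sketch of one guided prediction (hypothetical helper; `model` is any noise-prediction network and `null_cond` its null-condition input):

```python
def cfg_epsilon(model, x_t, t, cond, null_cond, w=7.5):
    # Two forward passes through the same network per denoising step
    eps_cond = model(x_t, t, cond)         # conditional prediction
    eps_uncond = model(x_t, t, null_cond)  # unconditional prediction
    # Combine into one guided prediction
    return eps_uncond + w * (eps_cond - eps_uncond)
```

In practice the two passes are often batched into a single call, which is exactly where the doubled effective batch size in the table below comes from.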

Problems with CFG

| Problem | Description |
| --- | --- |
| 2x computation | Two forward passes per step |
| Memory increase | Effective batch size doubled |
| Latency | Inference speed halved |

Critical for real-time applications!

2. CFG Distillation Idea

Key Insight

What if we train a model to **directly predict** the CFG output?

Teacher (with CFG):

```
2 forward passes → CFG combination → output
```

Student (CFG-free):

```
1 forward pass → same output
```

Distillation Objective

$$\mathcal{L} = \mathbb{E}\left[\left\|\epsilon_\text{student}(x_t, c) - \tilde{\epsilon}_\text{teacher}(x_t, c)\right\|^2\right]$$

Student mimics Teacher's CFG result.

3. Training Method

Basic Algorithm

```python
import torch
import torch.nn.functional as F

def cfg_distillation_loss(student, teacher, x0, c, w=7.5):
    # Sample timesteps and add noise
    # (add_noise: the forward-diffusion helper defined by the noise schedule,
    #  e.g. the VE-style noising in Section 6)
    t = torch.rand(x0.shape[0], device=x0.device)
    noise = torch.randn_like(x0)
    x_t = add_noise(x0, t, noise)

    # Teacher: apply CFG (two forward passes, no gradients)
    with torch.no_grad():
        null_cond = torch.zeros_like(c)  # null condition (zeros, as in Section 6)
        eps_cond = teacher(x_t, t, c)
        eps_uncond = teacher(x_t, t, null_cond)
        eps_cfg = eps_uncond + w * (eps_cond - eps_uncond)

    # Student: single forward pass mimics the CFG output
    eps_student = student(x_t, t, c)

    return F.mse_loss(eps_student, eps_cfg)
```

Guidance Scale Conditioning

To support various guidance scales:

```python
def cfg_distillation_with_scale(student, teacher, x0, c):
    # Sample a random guidance scale per example
    w = torch.rand(x0.shape[0], device=x0.device) * 10 + 1  # w in [1, 11]

    # Forward diffusion (same as above)
    t = torch.rand(x0.shape[0], device=x0.device)
    x_t = add_noise(x0, t, torch.randn_like(x0))

    # Teacher CFG at the sampled scale
    with torch.no_grad():
        eps_cfg = compute_cfg(teacher, x_t, t, c, w)

    # Student: takes the guidance scale as an extra input
    eps_student = student(x_t, t, c, w)

    return F.mse_loss(eps_student, eps_cfg)
```

This allows the guidance scale to be adjusted at inference time!

4. Architecture Modifications

Guidance Scale Embedding

```python
import torch.nn as nn

class CFGFreeUNet(nn.Module):
    def __init__(self, base_unet):
        super().__init__()
        self.unet = base_unet
        # Small MLP embedding for the scalar guidance scale
        self.w_embed = nn.Sequential(
            nn.Linear(1, 256),
            nn.SiLU(),
            nn.Linear(256, 256)
        )

    def forward(self, x, t, c, w):
        # Add the w embedding to the time embedding
        # (assumes the base UNet exposes its time_embed module)
        t_emb = self.unet.time_embed(t)
        w_emb = self.w_embed(w.unsqueeze(-1))
        combined_emb = t_emb + w_emb

        return self.unet(x, combined_emb, c)
```
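
A hypothetical usage sketch (`x_t`, `t`, and `c` as in the earlier training code; assumes the wrapper above):

```python
model = CFGFreeUNet(base_unet)

# Any guidance scale can be chosen at inference, still one forward pass per step
w = torch.full((x_t.shape[0],), 7.5, device=x_t.device)
eps = model(x_t, t, c, w)
```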

A Simpler Approach

If only a fixed guidance scale is needed:

  • No architecture modification is required
  • Distill with one specific w value

5. Combining with Progressive Distillation

CFG + Step Distillation

The approach used by SDXL-Turbo and SDXL-Lightning; in progressive distillation this step-halving is applied repeatedly, with each round's student becoming the next round's teacher:

```python
def combined_distillation(student, teacher, x0, c):
    # sketch: sample_timestep / add_noise / denoise_step / compute_cfg
    # are assumed helpers from the surrounding training setup

    # 1. Step distillation: teacher takes 2 steps, student takes 1
    t = sample_timestep()
    t_mid = t / 2

    x_t = add_noise(x0, t)

    # Teacher: 2 denoising steps, each with CFG
    with torch.no_grad():
        # Step 1: t -> t_mid
        eps1 = compute_cfg(teacher, x_t, t, c, w=7.5)
        x_mid = denoise_step(x_t, eps1, t, t_mid)

        # Step 2: t_mid -> 0
        eps2 = compute_cfg(teacher, x_mid, t_mid, c, w=7.5)
        x_target = denoise_step(x_mid, eps2, t_mid, 0)

    # Student: 1 step, no CFG
    x_pred = student.denoise(x_t, t, c)

    return F.mse_loss(x_pred, x_target)
```

Final Goal

| Model | Steps | CFG | Forward Passes |
| --- | --- | --- | --- |
| SD 1.5 | 50 | Yes | 100 |
| SDXL-Turbo | 1 | No | 1 |
50 steps × 2 passes = 100 network evaluations, down to 1: a 100x reduction!

6. Implementation Example

Simple CFG-free Distillation

```python
class CFGFreeDistillation:
    def __init__(self, student, teacher, guidance_scale=7.5):
        self.student = student
        self.teacher = teacher
        self.w = guidance_scale

        # Freeze the teacher
        for p in teacher.parameters():
            p.requires_grad = False

    def compute_teacher_cfg(self, x_t, t, c):
        """Compute the teacher's CFG output."""
        # Conditional prediction
        eps_cond = self.teacher(x_t, t, c)

        # Unconditional prediction (null condition)
        null_c = torch.zeros_like(c)
        eps_uncond = self.teacher(x_t, t, null_c)

        # CFG combination
        return eps_uncond + self.w * (eps_cond - eps_uncond)

    def loss(self, x0, c):
        B = x0.shape[0]

        # Sample timesteps and noise
        t = torch.rand(B, device=x0.device)
        noise = torch.randn_like(x0)

        # Forward diffusion (simple VE-style noising: sigma = t)
        sigma = t.view(B, 1, 1, 1)
        x_t = x0 + sigma * noise

        # Teacher CFG target (no gradients through the teacher)
        with torch.no_grad():
            target = self.compute_teacher_cfg(x_t, t, c)

        # Student prediction: one forward pass
        pred = self.student(x_t, t, c)

        return F.mse_loss(pred, target)
```

Training Loop

```python
def train_cfg_free(student, teacher, dataloader, epochs=100):
    distiller = CFGFreeDistillation(student, teacher)
    optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4)

    for epoch in range(epochs):
        for images, captions in dataloader:
            # Encode text (text_encoder: an assumed pretrained encoder, e.g. CLIP)
            c = text_encoder(captions)

            loss = distiller.loss(images, c)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch}: loss = {loss.item():.4f}")
```

7. Variational Score Distillation

Limitations of Simple Distillation

Plain regression distillation has several drawbacks:

  • Potential mode collapse
  • Inherits the Teacher's limitations
  • Reduced sample diversity

VSD Approach

It uses a Score Distillation Sampling (SDS)-style gradient instead of a plain regression loss:

$$\nabla_\theta \mathcal{L}_\text{VSD} = \mathbb{E}\left[w(t)\left(\epsilon_\text{teacher} - \epsilon_\text{student}\right)\frac{\partial \epsilon_\text{student}}{\partial \theta}\right]$$
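
A hand-specified gradient like this is commonly realized in PyTorch with a stop-gradient trick; a minimal sketch (the weighting `w_t` and the surrounding training step are assumed context):

```python
def vsd_style_loss(eps_teacher, eps_student, w_t):
    # Detach the coefficient so that backward() yields exactly
    # w(t) * (eps_teacher - eps_student) * d(eps_student)/d(theta)
    coeff = (w_t * (eps_teacher - eps_student)).detach()
    return (coeff * eps_student).mean()
```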

Adversarial Distillation

Another remedy is adding a GAN loss on top of the distillation objective:

```python
def adversarial_distillation_loss(student, teacher, discriminator, x0, c):
    # Regression term: match the teacher's CFG output (Section 3)
    dist_loss = cfg_distillation_loss(student, teacher, x0, c)

    # Generate a sample from noise with the student
    z = torch.randn_like(x0)
    x_gen = student.sample(z, c)

    # Adversarial term: fool the conditional discriminator
    adv_loss = -discriminator(x_gen, c).mean()

    return dist_loss + 0.1 * adv_loss
```

8. Real-World Models

SDXL-Turbo

Stability AI's approach:

  • Adversarial Diffusion Distillation (ADD)
  • CFG-free + 1-4 step generation
  • Uses GAN discriminator

SDXL-Lightning

ByteDance's approach:

  • Progressive distillation
  • CFG distillation
  • Efficient training with LoRA

LCM (Latent Consistency Models)

Consistency distillation combined with CFG (a simplified sketch follows this list):

  • Consistency loss for step reduction
  • CFG effect internalized into the student
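
A minimal sketch of how LCM-style consistency distillation internalizes CFG (simplified; `ode_solver_step`, `compute_cfg`, and the EMA copy of the student are assumed pieces of the standard consistency-distillation setup):

```python
def lcm_style_loss(student, ema_student, teacher, x_t, t, t_prev, c, w):
    # One ODE-solver step along the teacher's trajectory,
    # using the CFG-augmented prediction (compute_cfg as in Section 3)
    with torch.no_grad():
        eps_cfg = compute_cfg(teacher, x_t, t, c, w)
        x_prev = ode_solver_step(x_t, eps_cfg, t, t_prev)
        # Target: EMA student at the earlier point, conditioned on w
        target = ema_student(x_prev, t_prev, c, w)

    # Consistency: both trajectory points should map to the same prediction
    pred = student(x_t, t, c, w)
    return F.mse_loss(pred, target)
```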

9. Quality Comparison

Quantitative Results

| Model | Steps | CFG | FID | CLIP Score |
| --- | --- | --- | --- | --- |
| SDXL | 50 | 7.5 | 23.5 | 0.32 |
| SDXL-Turbo | 1 | No | 24.1 | 0.31 |
| SDXL-Lightning | 4 | No | 23.8 | 0.32 |
| LCM-SDXL | 4 | No | 24.5 | 0.31 |

Nearly identical quality, overwhelming speed improvement!

Speed Comparison (A100)

| Model | Time/Image |
| --- | --- |
| SDXL, 50 steps + CFG | 3.2 s |
| SDXL-Turbo, 1 step | 0.08 s |

40x faster!

10. Limitations and Future

Current Limitations

| Limitation | Description |
| --- | --- |
| Teacher dependency | Teacher quality is the ceiling |
| Training cost | Requires large-scale data |
| Reduced flexibility | Fixed or limited guidance scale |

Future Directions

  1. Self-Distillation: Self-improvement without Teacher
  2. Continuous Guidance: Support arbitrary w values
  3. Multi-Modal Guidance: Handle multiple conditions simultaneously

Conclusion

| Method | CFG Required | Steps | Speed |
| --- | --- | --- | --- |
| Basic diffusion | Yes | 50+ | 1x |
| CFG distillation | No | 50+ | 2x |
| + Step distillation | No | 1-4 | 25-100x |

Key Points of CFG-free Distillation:

  • Distill the Teacher's CFG effect into the Student
  • Eliminate the 2x computational overhead of guidance
  • Combine with step distillation for 10-100x speedups
  • Enable real-time image generation

References

  1. Sauer, A., et al. "Adversarial Diffusion Distillation" (2023)
  2. Lin, S., et al. "SDXL-Lightning" (2024)
  3. Luo, S., et al. "Latent Consistency Models" (2023)
  4. Ho, J., Salimans, T. "Classifier-Free Diffusion Guidance" (2022)
  5. Meng, C., et al. "On Distillation of Guided Diffusion Models" (CVPR 2023)