CFG-free Distillation: Fast Generation Without Guidance

Eliminating the 2x computational cost of Classifier-Free Guidance. Achieving CFG quality with a single forward pass.

TL;DR

  • Problem: CFG requires two forward passes (conditional + unconditional) = 2x cost
  • Solution: Distill CFG effect into a single model
  • Method: Student mimics Teacher's CFG output
  • Result: Same quality, half the computation, faster inference

1. Classifier-Free Guidance Review

What is CFG?

The key technique for improving conditional generation quality:

$$\tilde{\epsilon}(x_t, c) = \epsilon(x_t, \varnothing) + w \cdot \left(\epsilon(x_t, c) - \epsilon(x_t, \varnothing)\right)$$

  • $\epsilon(x_t, c)$: Conditional prediction
  • $\epsilon(x_t, \varnothing)$: Unconditional prediction
  • $w$: Guidance scale (typically 7.5)
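
A single CFG prediction in code, as a minimal sketch (`model`, `x_t`, `t`, `c`, and `null_c` are placeholders for an epsilon-prediction diffusion model and its inputs):

python
def cfg_predict(model, x_t, t, c, null_c, w=7.5):
    eps_cond = model(x_t, t, c)          # conditional forward pass
    eps_uncond = model(x_t, t, null_c)   # unconditional forward pass
    # Extrapolate away from the unconditional prediction
    return eps_uncond + w * (eps_cond - eps_uncond)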

Problems with CFG

| Problem | Description |
|---|---|
| 2x Computation | Two forward passes per denoising step |
| Memory Increase | Effective batch size is doubled |
| Latency | Inference speed is halved |

Critical for real-time applications!

2. CFG Distillation Idea

Key Insight

What if we train a model to **directly predict** the CFG output?

Teacher (with CFG):

2 forward passes → CFG combination → output

Student (CFG-free):

1 forward pass → same output

Distillation Objective

$$\mathcal{L} = \mathbb{E}\left[\left\|\epsilon_\text{student}(x_t, c) - \tilde{\epsilon}_\text{teacher}(x_t, c)\right\|^2\right]$$

Student mimics Teacher's CFG result.

3. Training Method

Basic Algorithm

python
def cfg_distillation_loss(student, teacher, x0, c, w=7.5):
    # Add noise to the clean sample
    t = torch.rand(x0.shape[0], device=x0.device)
    noise = torch.randn_like(x0)
    x_t = add_noise(x0, t, noise)

    # Teacher: apply CFG (two forward passes, no gradients)
    with torch.no_grad():
        null_cond = torch.zeros_like(c)  # null (unconditional) embedding
        eps_cond = teacher(x_t, t, c)
        eps_uncond = teacher(x_t, t, null_cond)
        eps_cfg = eps_uncond + w * (eps_cond - eps_uncond)

    # Student: single conditional prediction
    eps_student = student(x_t, t, c)

    return F.mse_loss(eps_student, eps_cfg)

Guidance Scale Conditioning

To support various guidance scales:

python
def cfg_distillation_with_scale(student, teacher, x0, c):
    B = x0.shape[0]

    # Add noise to the clean sample
    t = torch.rand(B, device=x0.device)
    noise = torch.randn_like(x0)
    x_t = add_noise(x0, t, noise)

    # Sample a random guidance scale per example
    w = torch.rand(B, device=x0.device) * 10 + 1  # [1, 11]

    # Teacher CFG target at that scale
    with torch.no_grad():
        eps_cfg = compute_cfg(teacher, x_t, t, c, w)

    # Student takes the guidance scale as an extra input
    eps_student = student(x_t, t, c, w)

    return F.mse_loss(eps_student, eps_cfg)

This allows the guidance scale to be adjusted at inference time!
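
At sampling time, the desired scale is simply passed to the student; a rough sketch of one sampler step (the variable `t_prev` and the `denoise_step` call are illustrative placeholders, not from a specific library):

python
# One denoising step with a w-conditioned, CFG-free student.
w = torch.full((x_t.shape[0],), 5.0, device=x_t.device)  # any scale in the trained range
eps = student(x_t, t, c, w)                # single forward pass, no unconditional branch
x_t = denoise_step(x_t, eps, t, t_prev)    # hypothetical scheduler update to the next timestep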

4. Architecture Modifications

Guidance Scale Embedding

python
class CFGFreeUNet(nn.Module):
    def __init__(self, base_unet):
        super().__init__()
        self.unet = base_unet
        # Guidance scale embedding (projected to the time-embedding width)
        self.w_embed = nn.Sequential(
            nn.Linear(1, 256),
            nn.SiLU(),
            nn.Linear(256, 256)
        )

    def forward(self, x, t, c, w):
        # Reuse the base UNet's time embedding and add the w embedding to it
        t_emb = self.unet.time_embed(t)
        w_emb = self.w_embed(w.unsqueeze(-1))
        combined_emb = t_emb + w_emb

        return self.unet(x, combined_emb, c)

Simpler Alternative

If only a single, fixed guidance scale is needed:

  • No architecture modification is required
  • Distill with that specific $w$ value baked in

5. Combining with Progressive Distillation

CFG + Step Distillation

Approach used by SDXL-Turbo, SDXL-Lightning:

python
def combined_distillation(student, teacher, x0, c):
    # 1. Step distillation (2 teacher steps → 1 student step)
    t = sample_timestep()
    t_mid = t / 2

    noise = torch.randn_like(x0)
    x_t = add_noise(x0, t, noise)

    # Teacher: 2 steps with CFG
    with torch.no_grad():
        # Step 1
        eps1 = compute_cfg(teacher, x_t, t, c, w=7.5)
        x_mid = denoise_step(x_t, eps1, t, t_mid)

        # Step 2
        eps2 = compute_cfg(teacher, x_mid, t_mid, c, w=7.5)
        x_target = denoise_step(x_mid, eps2, t_mid, 0)

    # Student: 1 step, no CFG
    x_pred = student.denoise(x_t, t, c)

    return F.mse_loss(x_pred, x_target)
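
The helpers `compute_cfg` and `denoise_step` used above are not spelled out in this post. One minimal way they could look, assuming an epsilon-prediction model and the simple forward process $x_t = x_0 + \sigma_t \epsilon$ with $\sigma_t = t$ (the same one used in the implementation in Section 6):

python
def compute_cfg(model, x_t, t, c, w):
    null_c = torch.zeros_like(c)                 # null (unconditional) embedding
    eps_cond = model(x_t, t, c)
    eps_uncond = model(x_t, t, null_c)
    if torch.is_tensor(w):
        w = w.view(-1, 1, 1, 1)                  # broadcast per-example scales
    return eps_uncond + w * (eps_cond - eps_uncond)

def denoise_step(x_t, eps, t, t_next):
    # One Euler step: predict x0, then re-noise to the next (lower) noise level
    sigma = t.view(-1, 1, 1, 1) if torch.is_tensor(t) else t
    sigma_next = t_next.view(-1, 1, 1, 1) if torch.is_tensor(t_next) else t_next
    x0_pred = x_t - sigma * eps
    return x0_pred + sigma_next * eps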

Final Goal

| Model | Steps | CFG | Forward Passes |
|---|---|---|---|
| SD 1.5 | 50 | Yes | 100 |
| SDXL-Turbo | 1 | No | 1 |

100x speedup!

6. Implementation Example

Simple CFG-free Distillation

python
class CFGFreeDistillation:
    def __init__(self, student, teacher, guidance_scale=7.5):
        self.student = student
        self.teacher = teacher
        self.w = guidance_scale

        # Freeze teacher
        for p in teacher.parameters():
            p.requires_grad = False

    def compute_teacher_cfg(self, x_t, t, c):
        """Compute Teacher's CFG output"""
        # Conditional prediction
        eps_cond = self.teacher(x_t, t, c)

        # Unconditional prediction (null condition)
        null_c = torch.zeros_like(c)
        eps_uncond = self.teacher(x_t, t, null_c)

        # CFG combination
        return eps_uncond + self.w * (eps_cond - eps_uncond)

    def loss(self, x0, c):
        B = x0.shape[0]

        # Sample noise
        t = torch.rand(B, device=x0.device)
        noise = torch.randn_like(x0)

        # Forward diffusion (simplified schedule: sigma(t) = t)
        sigma = t.view(B, 1, 1, 1)
        x_t = x0 + sigma * noise

        # Teacher CFG target
        with torch.no_grad():
            target = self.compute_teacher_cfg(x_t, t, c)

        # Student prediction
        pred = self.student(x_t, t, c)

        return F.mse_loss(pred, target)

Training Loop

python
def train_cfg_free(student, teacher, dataloader, epochs=100):
    distiller = CFGFreeDistillation(student, teacher)
    optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4)

    for epoch in range(epochs):
        for images, captions in dataloader:
            # Text encoding
            c = text_encoder(captions)

            loss = distiller.loss(images, c)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch}: loss = {loss.item():.4f}")

7. Variational Score Distillation

Limitations

Limitations of simple distillation:

  • Potential mode collapse
  • Inherits Teacher's limitations
  • Reduced diversity

VSD Approach

Using Score Distillation Sampling:

$$\nabla_\theta \mathcal{L}_\text{VSD} = \mathbb{E}\left[w(t)\,\left(\epsilon_\text{teacher} - \epsilon_\text{student}\right)\frac{\partial \epsilon_\text{student}}{\partial \theta}\right]$$
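
In PyTorch, a gradient of this form is usually implemented with a stop-gradient target rather than a custom backward pass. A minimal sketch (the weighting tensor `w_t` and the function name are illustrative):

python
def vsd_style_loss(eps_student, eps_teacher, w_t):
    # grad matches the term inside the expectation: w(t) * (eps_teacher - eps_student)
    grad = w_t * (eps_teacher - eps_student)
    target = (eps_student - grad).detach()       # constant w.r.t. student parameters
    # d/dtheta of 0.5 * ||eps_student - target||^2 equals grad * d(eps_student)/dtheta
    return 0.5 * ((eps_student - target) ** 2).sum()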

Adversarial Distillation

Adding GAN loss:

python
def adversarial_distillation_loss(student, teacher, discriminator, x0, c):
    # Distillation loss
    dist_loss = cfg_distillation_loss(student, teacher, x0, c)

    # Generate sample
    z = torch.randn_like(x0)
    x_gen = student.sample(z, c)

    # Adversarial loss
    adv_loss = -discriminator(x_gen, c).mean()

    return dist_loss + 0.1 * adv_loss

8. Real-World Models

SDXL-Turbo

Stability AI's approach:

  • Adversarial Diffusion Distillation (ADD)
  • CFG-free + 1-4 step generation
  • Uses GAN discriminator

SDXL-Lightning

ByteDance's approach:

  • Progressive distillation
  • CFG distillation
  • Efficient training with LoRA

LCM (Latent Consistency Models)

Consistency distillation + CFG:

  • Consistency loss for step reduction
  • Internalized CFG effect
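
As a rough illustration of how the CFG effect gets internalized during consistency distillation (a simplification of the LCM recipe, reusing the `compute_cfg` and `denoise_step` sketches from Section 5; the `predict_x0` consistency function and the EMA target network are assumptions of this sketch):

python
def consistency_distillation_loss(student, ema_student, teacher, x0, c, w=7.5):
    noise = torch.randn_like(x0)
    t = torch.rand(x0.shape[0], device=x0.device) * 0.9 + 0.1   # current noise level
    t_prev = t - 0.1                                             # adjacent, lower noise level
    x_t = add_noise(x0, t, noise)

    # One CFG-guided teacher step from t to t_prev (no gradients)
    with torch.no_grad():
        eps_cfg = compute_cfg(teacher, x_t, t, c, w)
        x_prev = denoise_step(x_t, eps_cfg, t, t_prev)
        target = ema_student.predict_x0(x_prev, t_prev, c)       # consistency target

    pred = student.predict_x0(x_t, t, c)                         # map x_t directly to x0
    return F.mse_loss(pred, target)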

9. Quality Comparison

Quantitative Results

| Model | Steps | CFG | FID | CLIP Score |
|---|---|---|---|---|
| SDXL | 50 | Yes (w=7.5) | 23.5 | 0.32 |
| SDXL-Turbo | 1 | No | 24.1 | 0.31 |
| SDXL-Lightning | 4 | No | 23.8 | 0.32 |
| LCM-SDXL | 4 | No | 24.5 | 0.31 |

Nearly identical quality with a dramatic speed improvement!

Speed Comparison (A100)

| Model | Time / Image |
|---|---|
| SDXL, 50 steps + CFG | 3.2 s |
| SDXL-Turbo, 1 step | 0.08 s |

40x faster!

10. Limitations and Future

Current Limitations

| Limitation | Description |
|---|---|
| Teacher dependency | Teacher quality is the ceiling |
| Training cost | Requires large-scale data |
| Reduced flexibility | Fixed or limited guidance scale |

Future Directions

  1. Self-Distillation: Self-improvement without Teacher
  2. Continuous Guidance: Support arbitrary w values
  3. Multi-Modal Guidance: Handle multiple conditions simultaneously

Conclusion

| Method | CFG Required | Steps | Speed |
|---|---|---|---|
| Basic Diffusion | Yes | 50+ | 1x |
| CFG Distillation | No | 50+ | 2x |
| + Step Distillation | No | 1-4 | 25-100x |

Key Points of CFG-free Distillation:

  • Distill Teacher's CFG effect into Student
  • Eliminate 2x computational overhead
  • Combined with step distillation for 10-100x speedup
  • Enables real-time image generation

References

  1. Sauer, A., et al. "Adversarial Diffusion Distillation" (2023)
  2. Lin, S., et al. "SDXL-Lightning" (2024)
  3. Luo, S., et al. "Latent Consistency Models" (2023)
  4. Ho, J., Salimans, T. "Classifier-Free Diffusion Guidance" (2022)
  5. Meng, C., et al. "On Distillation of Guided Diffusion Models" (CVPR 2023)
