Models & Algorithms

CFG-free Distillation: Guidance 없이 빠른 생성

Classifier-Free Guidance의 2배 연산 비용 제거. 단일 forward pass로 동일 품질 달성.

CFG-free Distillation: Guidance 없이 빠른 생성

CFG-free Distillation: Guidance 없이 빠른 생성

Classifier-Free Guidance의 2배 연산 비용을 제거. 단일 forward pass로 CFG 품질 달성.

TL;DR

  • 문제: CFG는 조건부/무조건부 두 번의 forward pass 필요 (2배 비용)
  • 해결: Distillation으로 CFG 효과를 단일 모델에 학습
  • 방법: Teacher의 CFG 출력을 Student가 모방
  • 결과: 동일 품질, 절반의 연산량, 더 빠른 추론

1. Classifier-Free Guidance 복습

CFG란?

조건부 생성의 품질을 높이는 핵심 기법:

ϵ~(xt,c)=ϵ(xt,)+w(ϵ(xt,c)ϵ(xt,))\tilde{\epsilon}(x_t, c) = \epsilon(x_t, \varnothing) + w \cdot (\epsilon(x_t, c) - \epsilon(x_t, \varnothing))

  • $\epsilon(x_t, c)$: 조건부 예측
  • $\epsilon(x_t, \varnothing)$: 무조건부 예측
  • $w$: guidance scale (보통 7.5)

CFG의 문제점

문제설명
2x 연산량매 step마다 두 번의 forward pass
메모리 증가배치 크기 실질적 2배
지연 시간추론 속도 절반으로 감소

실시간 애플리케이션에서 치명적!

2. CFG Distillation 아이디어

핵심 통찰

CFG의 출력을 **직접 예측**하도록 학습하면?

Teacher (CFG 사용):

python
2 forward passes → CFG 결합 → 출력

Student (CFG-free):

python
1 forward pass → 동일한 출력

Distillation 목표

L=E[ϵstudent(xt,c)ϵ~teacher(xt,c)2]\mathcal{L} = \mathbb{E}\left[\|\epsilon_\text{student}(x_t, c) - \tilde{\epsilon}_\text{teacher}(x_t, c)\|^2\right]

Student가 Teacher의 CFG 결과를 모방.

3. 학습 방법

기본 알고리즘

python
def cfg_distillation_loss(student, teacher, x0, c, w=7.5):
    # 노이즈 추가
    t = torch.rand(x0.shape[0], device=x0.device)
    noise = torch.randn_like(x0)
    x_t = add_noise(x0, t, noise)

    # Teacher: CFG 적용
    with torch.no_grad():
        eps_cond = teacher(x_t, t, c)
        eps_uncond = teacher(x_t, t, null_cond)
        eps_cfg = eps_uncond + w * (eps_cond - eps_uncond)

    # Student: 단일 예측
    eps_student = student(x_t, t, c)

    return F.mse_loss(eps_student, eps_cfg)

Guidance Scale 조건화

다양한 guidance scale을 지원하려면:

python
def cfg_distillation_with_scale(student, teacher, x0, c):
    # 랜덤 guidance scale 샘플링
    w = torch.rand(x0.shape[0], device=x0.device) * 10 + 1  # [1, 11]

    # Teacher CFG
    with torch.no_grad():
        eps_cfg = compute_cfg(teacher, x_t, t, c, w)

    # Student: scale도 입력으로
    eps_student = student(x_t, t, c, w)

    return F.mse_loss(eps_student, eps_cfg)

이렇게 하면 추론 시 guidance scale 조절 가능!

4. 아키텍처 수정

Guidance Scale 임베딩

python
class CFGFreeUNet(nn.Module):
    def __init__(self, base_unet):
        super().__init__()
        self.unet = base_unet
        # Guidance scale 임베딩
        self.w_embed = nn.Sequential(
            nn.Linear(1, 256),
            nn.SiLU(),
            nn.Linear(256, 256)
        )

    def forward(self, x, t, c, w):
        # w를 시간 임베딩에 추가
        t_emb = self.time_embed(t)
        w_emb = self.w_embed(w.unsqueeze(-1))
        combined_emb = t_emb + w_emb

        return self.unet(x, combined_emb, c)

또는 간단한 방식

고정 guidance scale만 사용한다면:

  • 아키텍처 수정 불필요
  • 특정 w 값으로만 distillation

5. Progressive Distillation과 결합

CFG + Step Distillation

SDXL-Turbo, SDXL-Lightning 등의 접근:

python
def combined_distillation(student, teacher, x0, c):
    # 1. Step distillation (2 steps → 1 step)
    t = sample_timestep()
    t_mid = t / 2

    x_t = add_noise(x0, t)

    # Teacher: 2 steps with CFG
    with torch.no_grad():
        # Step 1
        eps1 = compute_cfg(teacher, x_t, t, c, w=7.5)
        x_mid = denoise_step(x_t, eps1, t, t_mid)

        # Step 2
        eps2 = compute_cfg(teacher, x_mid, t_mid, c, w=7.5)
        x_target = denoise_step(x_mid, eps2, t_mid, 0)

    # Student: 1 step, no CFG
    x_pred = student.denoise(x_t, t, c)

    return F.mse_loss(x_pred, x_target)

최종 목표

ModelStepsCFGForward Passes
SD 1.550Yes100
SDXL-Turbo1No1

100배 속도 향상!

6. 구현 예제

간단한 CFG-free Distillation

python
class CFGFreeDistillation:
    def __init__(self, student, teacher, guidance_scale=7.5):
        self.student = student
        self.teacher = teacher
        self.w = guidance_scale

        # Teacher는 학습하지 않음
        for p in teacher.parameters():
            p.requires_grad = False

    def compute_teacher_cfg(self, x_t, t, c):
        """Teacher의 CFG 출력 계산"""
        # 조건부 예측
        eps_cond = self.teacher(x_t, t, c)

        # 무조건부 예측 (null condition)
        null_c = torch.zeros_like(c)
        eps_uncond = self.teacher(x_t, t, null_c)

        # CFG 결합
        return eps_uncond + self.w * (eps_cond - eps_uncond)

    def loss(self, x0, c):
        B = x0.shape[0]

        # 노이즈 샘플링
        t = torch.rand(B, device=x0.device)
        noise = torch.randn_like(x0)

        # Forward diffusion
        sigma = t.view(B, 1, 1, 1)
        x_t = x0 + sigma * noise

        # Teacher CFG target
        with torch.no_grad():
            target = self.compute_teacher_cfg(x_t, t, c)

        # Student prediction
        pred = self.student(x_t, t, c)

        return F.mse_loss(pred, target)

학습 루프

python
def train_cfg_free(student, teacher, dataloader, epochs=100):
    distiller = CFGFreeDistillation(student, teacher)
    optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4)

    for epoch in range(epochs):
        for images, captions in dataloader:
            # 텍스트 인코딩
            c = text_encoder(captions)

            loss = distiller.loss(images, c)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch}: loss = {loss.item():.4f}")

7. Variational Score Distillation

문제점

단순 distillation의 한계:

  • Mode collapse 가능성
  • Teacher의 한계를 그대로 상속
  • 다양성 감소

VSD 접근

Score Distillation Sampling을 활용:

θLVSD=E[w(t)(ϵteacherϵstudent)ϵstudentθ]\nabla_\theta \mathcal{L}_\text{VSD} = \mathbb{E}\left[w(t)(\epsilon_\text{teacher} - \epsilon_\text{student}) \frac{\partial \epsilon_\text{student}}{\partial \theta}\right]

Adversarial Distillation

GAN loss 추가:

python
def adversarial_distillation_loss(student, teacher, discriminator, x0, c):
    # Distillation loss
    dist_loss = cfg_distillation_loss(student, teacher, x0, c)

    # Generate sample
    z = torch.randn_like(x0)
    x_gen = student.sample(z, c)

    # Adversarial loss
    adv_loss = -discriminator(x_gen, c).mean()

    return dist_loss + 0.1 * adv_loss

8. 실제 모델들

SDXL-Turbo

Stability AI의 접근:

  • Adversarial Diffusion Distillation (ADD)
  • CFG-free + 1-4 step generation
  • GAN discriminator 사용

SDXL-Lightning

ByteDance의 접근:

  • Progressive distillation
  • CFG distillation
  • LoRA 기반 효율적 학습

LCM (Latent Consistency Models)

Consistency distillation + CFG:

  • Consistency loss로 step 감소
  • CFG 효과 내재화

9. 품질 비교

정량적 결과

ModelStepsCFGFIDCLIP Score
SDXL507.523.50.32
SDXL-Turbo1No24.10.31
SDXL-Lightning4No23.80.32
LCM-SDXL4No24.50.31

거의 동등한 품질, 압도적 속도 향상!

속도 비교 (A100 기준)

ModelTime/Image
SDXL 50 steps + CFG3.2s
SDXL-Turbo 1 step0.08s

40배 빠름!

10. 한계와 미래

현재 한계

한계설명
Teacher 의존Teacher 품질이 상한선
학습 비용대규모 데이터 필요
유연성 감소Guidance scale 고정 또는 제한적

미래 방향

  1. Self-Distillation: Teacher 없이 자체 개선
  2. Continuous Guidance: 임의의 w 값 지원
  3. Multi-Modal Guidance: 여러 조건 동시 처리

결론

방법CFG 필요Steps속도
기본 DiffusionYes50+1x
CFG DistillationNo50+2x
+ Step DistillationNo1-425-100x

CFG-free Distillation의 핵심:

  • Teacher의 CFG 효과를 Student에 증류
  • 2배 연산량 제거
  • Step distillation과 결합하면 수십 배 속도 향상
  • 실시간 이미지 생성 가능

References

  1. Sauer, A., et al. "Adversarial Diffusion Distillation" (2023)
  2. Lin, S., et al. "SDXL-Lightning" (2024)
  3. Luo, S., et al. "Latent Consistency Models" (2023)
  4. Ho, J., Salimans, T. "Classifier-Free Diffusion Guidance" (2022)
  5. Meng, C., et al. "On Distillation of Guided Diffusion Models" (CVPR 2023)