Models & Algorithms

Consistency Models: 1-Step 생성을 위한 새로운 패러다임

Diffusion의 반복 샘플링 없이 단 한 번에 생성. Self-consistency property를 활용한 OpenAI의 혁신적 접근법.

Consistency Models: 1-Step 생성을 위한 새로운 패러다임

Consistency Models: 1-Step 생성을 위한 새로운 패러다임

Diffusion의 반복 샘플링 없이 단 한 번에. OpenAI의 혁신적 접근법.

TL;DR

  • Consistency Models: 동일 trajectory 위의 모든 점을 같은 출력으로 매핑하는 모델
  • Self-Consistency: $f(x_t, t) = f(x_{t'}, t')$ for all $t, t'$ on same trajectory
  • 두 가지 학습법: Consistency Distillation (교사 필요) vs Consistency Training (교사 불필요)
  • 결과: 1-step으로 고품질 생성, 필요시 multi-step으로 품질 향상 가능

1. 왜 Consistency Models인가?

Diffusion의 근본적 한계

Diffusion 모델은 반복적 샘플링이 필수입니다:

python
z ~ N(0,I) → x_T → x_{T-1} → ... → x_1 → x_0

아무리 최적화해도:

  • DDPM: 1000 스텝
  • DDIM: 50-100 스텝
  • DPM-Solver: 10-20 스텝

1-step은 불가능한가?

기존 접근법들의 문제

방법문제점
Progressive Distillation여러 단계 distillation 필요
Rectified FlowReflow 반복 필요
직접 1-step 학습품질 저하 심함

Consistency Models의 아이디어

핵심 관찰:

ODE trajectory 위의 모든 점은 **같은 데이터 포인트**로 수렴한다

따라서:

trajectory 위의 어떤 점에서 시작하든, **같은 출력**을 내는 함수를 학습하자!

2. Self-Consistency Property

정의

Consistency function $f: (x_t, t) \to x_0$는 다음을 만족:

f(xt,t)=f(xt,t)t,t[0,T]f(x_t, t) = f(x_{t'}, t') \quad \forall t, t' \in [0, T]

단, $x_t$와 $x_{t'}$가 같은 ODE trajectory 위에 있을 때.

직관적 이해

python
Noise                                    Data
  z ─────●─────●─────●─────●─────> x_0
         ↓     ↓     ↓     ↓
        f()   f()   f()   f()
         ↓     ↓     ↓     ↓
         └─────┴─────┴─────┘
               모두 같은 x_0

ODE를 따라가면 결국 같은 $x_0$에 도달하므로, 중간 어느 점에서든 바로 $x_0$를 예측할 수 있어야 합니다.

Boundary Condition

$t = 0$에서는 identity가 되어야 합니다:

f(x0,0)=x0f(x_0, 0) = x_0

이미 데이터에 있으면, 그대로 반환.

3. Consistency Model 아키텍처

기본 구조

Boundary condition을 만족하기 위한 설계:

fθ(x,t)=cskip(t)x+cout(t)Fθ(x,t)f_\theta(x, t) = c_{\text{skip}}(t) \cdot x + c_{\text{out}}(t) \cdot F_\theta(x, t)

여기서:

  • $F_\theta$: 신경망 (U-Net, DiT 등)
  • $c_{\text{skip}}(t)$, $c_{\text{out}}(t)$: 시간에 따른 가중치

Skip Connection 설계

Boundary condition $f(x, 0) = x$를 만족하려면:

cskip(0)=1,cout(0)=0c_{\text{skip}}(0) = 1, \quad c_{\text{out}}(0) = 0

일반적인 선택:

cskip(t)=σdata2σdata2+t2c_{\text{skip}}(t) = \frac{\sigma_{\text{data}}^2}{\sigma_{\text{data}}^2 + t^2}

cout(t)=tσdataσdata2+t2c_{\text{out}}(t) = \frac{t \cdot \sigma_{\text{data}}}{\sqrt{\sigma_{\text{data}}^2 + t^2}}

시간 임베딩

$t \to 0$ 근처에서 안정성을 위해 시간을 변환:

t=14log(t+1)t' = \frac{1}{4} \log(t + 1)

4. Consistency Distillation (CD)

개념

미리 학습된 diffusion 모델을 교사로 사용:

  1. 교사 모델로 ODE trajectory 생성
  2. Consistency model이 trajectory의 다른 점들을 같은 출력으로 매핑하도록 학습

알고리즘

python
def consistency_distillation_loss(model, teacher, x0):
    # Sample time
    t = sample_timestep()
    t_next = t - delta_t  # one step earlier

    # Add noise to get x_t
    noise = torch.randn_like(x0)
    x_t = add_noise(x0, t, noise)

    # Teacher takes one ODE step: x_t -> x_{t_next}
    with torch.no_grad():
        x_t_next = teacher_ode_step(teacher, x_t, t, t_next)

    # Consistency loss: f(x_t, t) should equal f(x_{t_next}, t_next)
    pred_t = model(x_t, t)
    pred_t_next = model(x_t_next, t_next).detach()  # stop gradient

    return F.mse_loss(pred_t, pred_t_next)

Target Network (EMA)

안정적 학습을 위해 target network 사용:

θμθ+(1μ)θ\theta^- \leftarrow \mu \theta^- + (1-\mu) \theta

  • $\theta$: 학습되는 모델
  • $\theta^-$: EMA target (stop gradient)
  • $\mu$: decay rate (0.999 등)

Loss Function

LCD=E[d(fθ(xtn,tn),fθ(xtn1,tn1))]\mathcal{L}_{\text{CD}} = \mathbb{E}\left[d(f_\theta(x_{t_n}, t_n), f_{\theta^-}(x_{t_{n-1}}, t_{n-1}))\right]

여기서 $d$는 distance metric (L2, LPIPS 등).

5. Consistency Training (CT)

교사 없이 학습하기

Consistency Distillation은 교사 모델이 필요합니다. 하지만 교사 없이 직접 학습할 수도 있습니다!

핵심 아이디어

ODE를 정확히 풀지 않고, 무한소 스텝에서의 consistency를 강제:

limΔt0f(xt+Δt,t+Δt)=f(xt,t)\lim_{\Delta t \to 0} f(x_{t+\Delta t}, t+\Delta t) = f(x_t, t)

알고리즘

python
def consistency_training_loss(model, x0):
    # Sample time
    t = sample_timestep()
    t_next = t - delta_t

    # Add noise
    noise = torch.randn_like(x0)
    x_t = add_noise(x0, t, noise)
    x_t_next = add_noise(x0, t_next, noise)  # same noise!

    # Consistency loss
    pred_t = model(x_t, t)
    pred_t_next = model(x_t_next, t_next).detach()

    return F.mse_loss(pred_t, pred_t_next)

핵심 차이: 교사의 ODE step 대신, 같은 noise로 다른 시간에서 샘플링.

왜 작동하는가?

$\Delta t \to 0$일 때:

xt+Δtxt+small perturbationx_{t+\Delta t} \approx x_t + \text{small perturbation}

이 perturbation은 ODE의 방향과 일치합니다. 따라서 무한소 단계에서 consistency를 강제하면, 전체 trajectory에서도 consistency가 성립합니다.

CD vs CT 비교

특성Consistency DistillationConsistency Training
교사 모델필요불필요
학습 난이도쉬움어려움
최종 품질더 높음약간 낮음
유연성교사에 의존독립적

6. 샘플링

1-Step 샘플링

가장 단순한 방법:

python
def sample_one_step(model, z):
    # z ~ N(0, I)
    # 바로 x_0 예측
    return model(z, T)

끝! 반복 없이 한 번에 생성.

Multi-Step 샘플링 (품질 향상)

더 높은 품질을 원하면:

python
def sample_multi_step(model, z, timesteps):
    """
    timesteps: [T, t_1, t_2, ..., 0] (decreasing)
    """
    x = z

    for i in range(len(timesteps) - 1):
        t = timesteps[i]
        t_next = timesteps[i + 1]

        # Denoise to x_0
        x_0 = model(x, t)

        # Add noise back to t_next (if not last step)
        if t_next > 0:
            noise = torch.randn_like(x)
            x = add_noise(x_0, t_next, noise)

    return x_0

원리:

  1. 현재 $x_t$에서 $x_0$ 예측
  2. $x_0$에 다시 노이즈 추가하여 $x_{t'}$ 생성
  3. 반복

이렇게 하면 denoising과 noise injection을 번갈아 수행하여 품질 향상.

7. 구현

Consistency Model 클래스

python
class ConsistencyModel(nn.Module):
    def __init__(self, network, sigma_data=0.5):
        super().__init__()
        self.network = network
        self.sigma_data = sigma_data

    def c_skip(self, t):
        return self.sigma_data**2 / (t**2 + self.sigma_data**2)

    def c_out(self, t):
        return t * self.sigma_data / torch.sqrt(t**2 + self.sigma_data**2)

    def forward(self, x, t):
        # Skip connection for boundary condition
        c_skip = self.c_skip(t)
        c_out = self.c_out(t)

        if c_skip.dim() == 1:
            c_skip = c_skip[:, None, None, None]
            c_out = c_out[:, None, None, None]

        F_x = self.network(x, t)

        return c_skip * x + c_out * F_x

Consistency Distillation 학습

python
class ConsistencyDistillation:
    def __init__(self, model, teacher, ema_decay=0.999):
        self.model = model
        self.teacher = teacher
        self.target_model = copy.deepcopy(model)
        self.ema_decay = ema_decay

    def ode_step(self, x, t, t_next):
        """One step of teacher ODE."""
        # Using teacher to estimate velocity/score
        with torch.no_grad():
            score = self.teacher(x, t)
            # Euler step
            dt = t_next - t
            x_next = x + score * dt
        return x_next

    def loss(self, x0):
        B = x0.shape[0]

        # Sample timesteps
        t = torch.rand(B, device=x0.device) * (T - eps) + eps
        t_next = t - delta_t
        t_next = t_next.clamp(min=eps)

        # Forward diffusion
        noise = torch.randn_like(x0)
        x_t = x0 + t[:, None, None, None] * noise

        # Teacher ODE step
        x_t_next = self.ode_step(x_t, t, t_next)

        # Consistency loss
        pred = self.model(x_t, t)
        target = self.target_model(x_t_next, t_next)

        return F.mse_loss(pred, target)

    def update_target(self):
        """EMA update of target network."""
        with torch.no_grad():
            for p, p_target in zip(self.model.parameters(),
                                   self.target_model.parameters()):
                p_target.data.mul_(self.ema_decay).add_(
                    p.data, alpha=1 - self.ema_decay)

Consistency Training 학습

python
class ConsistencyTraining:
    def __init__(self, model, ema_decay=0.999):
        self.model = model
        self.target_model = copy.deepcopy(model)
        self.ema_decay = ema_decay

    def loss(self, x0):
        B = x0.shape[0]

        # Sample timesteps
        t = torch.rand(B, device=x0.device) * (T - eps) + eps
        t_next = t - delta_t
        t_next = t_next.clamp(min=eps)

        # Same noise for both timesteps!
        noise = torch.randn_like(x0)
        x_t = x0 + t[:, None, None, None] * noise
        x_t_next = x0 + t_next[:, None, None, None] * noise

        # Consistency loss
        pred = self.model(x_t, t)
        target = self.target_model(x_t_next, t_next)

        return F.mse_loss(pred, target)

8. Improved Consistency Training (iCT)

원본 CT의 문제점

  • 학습 초기에 불안정
  • 큰 $\Delta t$에서 오류 누적
  • 수렴이 느림

개선 사항

  1. Adaptive $\Delta t$: 학습 진행에 따라 $\Delta t$ 감소
  2. Improved noise schedule: EDM 스타일 noise schedule
  3. Better loss weighting: 시간에 따른 가중치 조정
python
def adaptive_delta_t(step, total_steps):
    """Delta t decreases during training."""
    progress = step / total_steps
    return delta_t_max * (1 - progress) + delta_t_min * progress

9. 실험 결과

CIFAR-10 FID

모델NFEFID
DDPM10003.17
DDIM504.67
Progressive Distillation19.12
Consistency Distillation13.55
Consistency Training15.83

ImageNet 64x64

모델NFEFID
ADM2502.07
Consistency Distillation14.70
Consistency Distillation22.93

핵심 발견

  1. 1-step CD가 기존 distillation 방법들보다 우수
  2. 2-step으로 품질 크게 향상
  3. CT는 CD보다 약간 낮지만, 교사 불필요

10. Latent Consistency Models (LCM)

Stable Diffusion에 적용

Consistency Models를 latent space에서 학습:

python
# Encode to latent
z = vae.encode(image)

# Train consistency model in latent space
z_0_pred = consistency_model(z_t, t)

# Decode for visualization
image_pred = vae.decode(z_0_pred)

LCM의 성과

  • 4 스텝으로 Stable Diffusion 품질 달성
  • 기존 대비 5-10배 속도 향상
  • CFG와 호환 가능

LCM-LoRA

LoRA로 효율적 학습:

python
# Base SD model + LCM-LoRA
pipe = StableDiffusionPipeline.from_pretrained("...")
pipe.load_lora_weights("lcm-lora-sdv1-5")

# Fast generation
image = pipe(prompt, num_inference_steps=4).images[0]

결론

방법스텝교사 필요품질
DDPM1000-높음
DDIM50-높음
Progressive Distill1-4O중간
Consistency Distill1-2O높음
Consistency Training1-2X중상
LCM4O높음

Consistency Models의 핵심:

  • Self-consistency property를 활용
  • ODE trajectory를 직접 학습하지 않고, endpoint를 예측
  • 1-step 생성이 가능하면서도 multi-step으로 품질 향상 가능

References

  1. Song, Y., et al. "Consistency Models" (ICML 2023)
  2. Song, Y., Dhariwal, P. "Improved Techniques for Training Consistency Models" (2023)
  3. Luo, S., et al. "Latent Consistency Models" (2023)
  4. Karras, T., et al. "Elucidating the Design Space of Diffusion-Based Generative Models" (NeurIPS 2022)