Models & Algorithms

Rectified Flow: 1-Step 생성을 향한 경로 직선화

Flow Matching도 느리다면? Reflow로 경로를 펴서 1-step 생성까지. Stable Diffusion 3와 FLUX의 핵심 기술.

Rectified Flow: 1-Step 생성을 향한 경로 직선화

Rectified Flow: 1-Step 생성을 향한 경로 직선화

Flow Matching도 느리다면? Reflow로 경로를 펴서 1-step 생성까지 도달하는 방법.

TL;DR

  • Rectified Flow: Flow Matching의 경로를 반복적으로 "직선화"하는 기법
  • Reflow: 학습된 모델로 (noise, data) 쌍을 생성하고, 이 쌍으로 더 직선적인 경로를 재학습
  • 핵심 이점: Reflow를 반복할수록 경로가 직선에 가까워지고, 최종적으로 1-step 생성 가능
  • 실제 적용: Stable Diffusion 3, FLUX가 Rectified Flow 기반

1. 왜 Flow Matching만으로는 부족한가?

Flow Matching은 DDPM보다 훨씬 적은 스텝(10-50)으로 생성이 가능합니다. 하지만 여전히 한계가 있습니다.

Flow Matching의 한계

Flow Matching의 목표 속도장은:

vt(xtx0,z)=zx0v_t(x_t | x_0, z) = z - x_0

이론적으로는 상수 속도장이지만, 실제 학습에서는 marginal velocity field를 학습합니다:

vt(xt)=Ex0,zxt[zx0]v_t(x_t) = \mathbb{E}_{x_0, z | x_t}[z - x_0]

문제는 서로 다른 $(x_0, z)$ 쌍들이 같은 $x_t$를 지나갈 수 있다는 것입니다. 이 경로들이 교차(crossing)하면, 학습된 속도장은 이들의 평균이 되어 실제로는 곡선을 따라가게 됩니다.

경로 교차 문제

두 데이터 포인트 $x_0^{(1)}, x_0^{(2)}$와 두 노이즈 $z^{(1)}, z^{(2)}$가 있을 때:

xt(1)=(1t)x0(1)+tz(1)x_t^{(1)} = (1-t)x_0^{(1)} + tz^{(1)}

xt(2)=(1t)x0(2)+tz(2)x_t^{(2)} = (1-t)x_0^{(2)} + tz^{(2)}

어떤 $t$에서 $x_t^{(1)} = x_t^{(2)}$가 되면, 신경망은 두 방향의 평균을 예측하게 됩니다. 이것이 transport cost를 증가시키고 샘플링 스텝을 늘려야 하는 원인입니다.

2. Rectified Flow의 핵심 아이디어

Rectified Flow는 간단하지만 강력한 아이디어입니다:

**"학습된 flow로 (z, x₀) 쌍을 만들고, 이 쌍으로 다시 직선 경로를 학습하면 경로가 펴진다"**

Reflow 절차

  1. 초기 Flow Matching 학습: 랜덤 $(x_0, z)$ 쌍으로 기본 모델 $v_{\theta_0}$ 학습
  2. Coupling 생성: 학습된 모델로 noise $z$에서 시작해 data $\hat{x}_0$를 생성

- 이제 $(z, \hat{x}_0)$는 실제로 flow를 따라 연결된 쌍

  1. Reflow 학습: 새로운 모델 $v_{\theta_1}$을 $(z, \hat{x}_0)$ 쌍의 직선 경로로 학습
  2. 반복: 2-3을 반복할수록 경로가 더 직선화

수학적 표현

$k$번째 reflow 후의 coupling을 $\pi_k$라 하면:

Lreflow(k)=E(x0,z)πk,t[(zx0)vθ(xt,t)2]\mathcal{L}_{\text{reflow}}^{(k)} = \mathbb{E}_{(x_0, z) \sim \pi_k, t} \left[ \| (z - x_0) - v_{\theta}(x_t, t) \|^2 \right]

여기서 $x_t = (1-t)x_0 + tz$이고, $\pi_k$는 $k$번째 모델이 생성한 coupling입니다.

3. 왜 Reflow가 경로를 직선화하는가?

직관적 이해

처음에는 랜덤 coupling $(x_0, z)$를 사용합니다. 이 경로들은 서로 교차할 수 있습니다.

하지만 학습된 flow $\phi_1$을 따라가면:

  • $z$에서 출발한 경로는 특정 $\hat{x}_0$에 도착
  • 이 $(z, \hat{x}_0)$ 쌍은 이미 flow를 따라 연결되어 있음
  • 따라서 이 쌍들의 직선 경로는 덜 교차

Transport Cost 감소

Reflow의 핵심은 transport cost를 줄이는 것입니다:

Cost(π)=E(x0,z)π[zx02]\text{Cost}(\pi) = \mathbb{E}_{(x_0, z) \sim \pi} \left[ \| z - x_0 \|^2 \right]

Reflow를 반복하면:

Cost(π0)Cost(π1)Cost(π2)\text{Cost}(\pi_0) \geq \text{Cost}(\pi_1) \geq \text{Cost}(\pi_2) \geq \cdots

경로가 직선화되면서 transport cost가 감소합니다.

이론적 보장

논문에서 증명된 중요한 성질:

  1. Causality: Reflow된 coupling은 인과적(causal)입니다. 즉, $z$가 주어지면 $x_0$가 결정됨
  2. Straightness: Reflow를 무한히 반복하면 경로가 완전히 직선이 됨
  3. 1-step 가능성: 완전히 직선화되면 Euler 1-step으로 정확한 샘플링 가능

4. 1-Step Distillation

Reflow만으로도 경로가 직선화되지만, 실용적으로 1-step 생성을 위해서는 distillation이 필요합니다.

Progressive Distillation

스텝 수를 점진적으로 줄이는 방법:

  1. Teacher 모델: N steps
  2. Student 모델: N/2 steps로 teacher 출력 모방
  3. 반복하여 1-step까지 도달

Ldistill=Ez[ϕteacher(z)Gθ(z)2]\mathcal{L}_{\text{distill}} = \mathbb{E}_{z} \left[ \| \phi_{\text{teacher}}(z) - G_{\theta}(z) \|^2 \right]

Direct Distillation

Rectified Flow의 장점은 경로가 이미 직선에 가깝기 때문에 직접 1-step distillation이 가능하다는 것:

L1-step=Ez[x0(zvθ(z,1))2]\mathcal{L}_{\text{1-step}} = \mathbb{E}_{z} \left[ \| x_0 - (z - v_{\theta}(z, 1)) \|^2 \right]

여기서 $v_{\theta}(z, 1)$은 $t=1$에서의 속도 예측입니다.

5. 구현

Reflow 학습

python
class RectifiedFlow:
    def __init__(self, model):
        self.model = model

    def loss(self, x0, z):
        """Reflow loss with fixed coupling."""
        t = torch.rand(x0.shape[0], device=x0.device)

        # Linear interpolation
        x_t = (1 - t[:, None]) * x0 + t[:, None] * z

        # Target velocity (straight line)
        v_target = z - x0

        # Predicted velocity
        v_pred = self.model(x_t, t)

        return F.mse_loss(v_pred, v_target)

    @torch.no_grad()
    def sample(self, z, n_steps=1):
        """Sample with Euler method."""
        x = z
        dt = 1.0 / n_steps

        for i in range(n_steps):
            t = 1.0 - i * dt
            t_batch = torch.full((x.shape[0],), t, device=x.device)
            v = self.model(x, t_batch)
            x = x - v * dt

        return x

    @torch.no_grad()
    def generate_coupling(self, z, n_steps=50):
        """Generate (z, x0) coupling pairs."""
        x0 = self.sample(z, n_steps=n_steps)
        return z, x0

Reflow 학습 루프

python
def train_reflow(data, n_reflows=3, n_epochs=500):
    """Train with multiple reflow iterations."""

    # Initial Flow Matching
    model = create_model()
    rf = RectifiedFlow(model)

    # Train on random coupling
    for epoch in range(n_epochs):
        x0 = sample_data(data)
        z = torch.randn_like(x0)
        loss = rf.loss(x0, z)
        loss.backward()
        optimizer.step()

    # Reflow iterations
    for k in range(n_reflows):
        print(f"Reflow {k+1}")

        # Generate coupling from current model
        z_all = torch.randn(len(data), dim)
        z_all, x0_all = rf.generate_coupling(z_all, n_steps=50)

        # Train new model on this coupling
        new_model = create_model()
        new_rf = RectifiedFlow(new_model)

        for epoch in range(n_epochs):
            idx = torch.randperm(len(x0_all))[:batch_size]
            loss = new_rf.loss(x0_all[idx], z_all[idx])
            loss.backward()
            optimizer.step()

        rf = new_rf

    return rf

1-Step Distillation

python
def distill_to_one_step(teacher_rf, student_model, data, n_epochs=1000):
    """Distill to 1-step generator."""
    optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4)

    for epoch in range(n_epochs):
        z = torch.randn(batch_size, dim)

        # Teacher generates target
        with torch.no_grad():
            x0_teacher = teacher_rf.sample(z, n_steps=10)

        # Student predicts in 1 step
        # x0 = z - v(z, t=1)
        v_pred = student_model(z, torch.ones(batch_size))
        x0_student = z - v_pred

        loss = F.mse_loss(x0_student, x0_teacher)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return student_model

6. Stable Diffusion 3와 FLUX

SD3의 Rectified Flow 적용

Stable Diffusion 3는 Rectified Flow를 채택했습니다:

  1. MMDiT 아키텍처: 텍스트와 이미지를 동시에 처리하는 Multimodal DiT
  2. Rectified Flow: 기존 DDPM 대신 직선 경로 학습
  3. 결과: 동일 품질에서 더 적은 스텝 필요

FLUX의 발전

FLUX (by Black Forest Labs)는 SD3를 더 발전시켰습니다:

  • Guidance Distillation: CFG를 모델에 내재화
  • 더 적은 스텝: 4-8 스텝으로 고품질 생성
  • FLUX.1-schnell: 1-4 스텝 생성 가능한 distilled 버전

왜 Rectified Flow인가?

기존 Stable Diffusion (DDPM 기반)에서 전환한 이유:

특성DDPM/DDIMRectified Flow
이론적 기반Score matchingOptimal transport
경로곡선직선 (또는 직선에 가까움)
최소 스텝~20-50~4-10 (distill 후 1)
Distillation어려움상대적으로 쉬움

7. Reflow 횟수와 품질

몇 번의 Reflow가 필요한가?

실험적으로:

  • 1-Reflow: 상당한 직선화, 10-step으로 좋은 품질
  • 2-Reflow: 더 직선화, 5-step 가능
  • 3-Reflow: 거의 직선, 1-2 step 가능

하지만 reflow를 많이 할수록:

  • 학습 시간 증가
  • Coupling 생성에 시간 소요
  • 수렴 속도 감소 가능

실용적 선택

대부분의 경우 1-2회 reflow + distillation이 가장 효율적입니다.

8. 한계와 주의사항

Coupling 품질 의존성

Reflow는 이전 모델의 생성 품질에 의존합니다:

  • 초기 모델이 나쁘면 → 나쁜 coupling → 나쁜 reflow 결과
  • 해결책: 초기 Flow Matching을 충분히 학습

Mode Collapse 위험

Reflow를 너무 많이 하면:

  • Coupling이 특정 모드에 집중될 수 있음
  • 다양성(diversity) 감소 가능
  • 해결책: 적절한 reflow 횟수 선택, regularization

계산 비용

각 reflow 단계마다:

  • 전체 데이터셋에 대해 coupling 생성 필요
  • 새 모델 학습 필요
  • 총 비용 = (1 + n_reflows) × 기본 학습 비용

결론

방법스텝 수특징
DDPM1000원본, 느림
DDIM50-100결정론적 샘플링
Flow Matching10-50직선 경로 학습
Rectified Flow5-10Reflow로 경로 직선화
Rectified + Distill1-41-step 생성 가능

Rectified Flow는 "경로를 펴면 빨라진다"는 직관적 아이디어를 실현한 방법입니다. Stable Diffusion 3와 FLUX의 성공이 이 접근법의 실용성을 증명했습니다.

References

  1. Liu, X., et al. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow" (ICLR 2023)
  2. Esser, P., et al. "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" (Stable Diffusion 3, 2024)
  3. Lipman, Y., et al. "Flow Matching for Generative Modeling" (ICLR 2023)
  4. Salimans, T. & Ho, J. "Progressive Distillation for Fast Sampling of Diffusion Models" (ICLR 2022)