Models & Algorithms

Stable Diffusion 3 & FLUX: MMDiT 아키텍처 완전 분석

U-Net을 버리고 Transformer로. Text와 Image를 동등하게 처리하는 MMDiT 아키텍처와 Rectified Flow, Guidance Distillation까지.

Stable Diffusion 3 & FLUX: MMDiT 아키텍처 완전 분석

Stable Diffusion 3 & FLUX: MMDiT 아키텍처 완전 분석

U-Net을 버리고 Transformer로. Text와 Image를 동등하게 처리하는 새로운 패러다임.

TL;DR

  • MMDiT (Multimodal DiT): 텍스트와 이미지를 하나의 Transformer에서 동시 처리
  • Rectified Flow 채택: DDPM 대신 직선 경로로 빠른 생성
  • FLUX 발전: Guidance Distillation으로 CFG 없이 4-8 스텝 생성
  • 핵심 혁신: Text-Image 간 양방향 attention으로 더 정확한 프롬프트 따르기

1. 왜 U-Net을 버렸는가?

U-Net의 한계

Stable Diffusion 1.x/2.x는 U-Net 기반이었습니다:

python
Text Encoder (CLIP) → Cross-Attention → U-Net → Image

문제점:

  • 일방향 정보 흐름: 텍스트 → 이미지만 가능, 이미지 → 텍스트 피드백 없음
  • Cross-attention 병목: 텍스트 정보가 특정 레이어에서만 주입
  • Scaling 한계: U-Net은 모델 크기 증가 시 성능 향상이 수확체감

DiT의 등장

DiT (Diffusion Transformer)가 보여준 것:

  • Transformer는 scaling law를 따름
  • 모델이 클수록 일관되게 FID 개선
  • 하지만 DiT도 텍스트 처리는 cross-attention 방식

MMDiT: 진정한 Multimodal

SD3의 MMDiT는 텍스트와 이미지를 동등한 시퀀스로 처리:

python
[Text Tokens] + [Image Tokens] → Joint Transformer → [Text'] + [Image']

양방향 attention으로 텍스트가 이미지를 보고, 이미지도 텍스트를 봅니다.

2. MMDiT 아키텍처 상세

입력 처리

이미지 입력:

  1. VAE Encoder로 latent 추출: (H, W, 3)(h, w, 16)
  2. Patchify: (h, w, 16)(N_img, D)
  3. Position embedding 추가

텍스트 입력:

  1. 세 가지 텍스트 인코더 사용:

- CLIP-L (OpenAI)

- CLIP-G (OpenCLIP)

- T5-XXL (Google)

  1. Pooled + Sequence embeddings 결합
  2. (N_txt, D) 형태로 변환

Joint Attention Block

핵심은 MM-DiT Block입니다:

python
class MMDiTBlock(nn.Module):
    def __init__(self, dim):
        self.norm1_img = AdaLayerNorm(dim)
        self.norm1_txt = AdaLayerNorm(dim)
        self.attn = JointAttention(dim)

        self.norm2_img = AdaLayerNorm(dim)
        self.norm2_txt = AdaLayerNorm(dim)
        self.ff_img = FeedForward(dim)
        self.ff_txt = FeedForward(dim)

    def forward(self, img, txt, timestep):
        # Separate normalization
        img_norm = self.norm1_img(img, timestep)
        txt_norm = self.norm1_txt(txt, timestep)

        # Joint attention (핵심!)
        img_attn, txt_attn = self.attn(img_norm, txt_norm)

        img = img + img_attn
        txt = txt + txt_attn

        # Separate feedforward
        img = img + self.ff_img(self.norm2_img(img, timestep))
        txt = txt + self.ff_txt(self.norm2_txt(txt, timestep))

        return img, txt

Joint Attention 메커니즘

python
class JointAttention(nn.Module):
    def forward(self, img, txt):
        # 이미지와 텍스트를 concat
        x = torch.cat([img, txt], dim=1)

        # Q, K, V 계산
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        # Self-attention (모든 토큰이 서로를 봄)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = attn @ v

        # 다시 분리
        img_out, txt_out = out.split([img.shape[1], txt.shape[1]], dim=1)

        return img_out, txt_out

핵심 포인트:

  • 이미지 토큰이 텍스트 토큰에 attend
  • 텍스트 토큰이 이미지 토큰에 attend
  • 양방향 정보 흐름으로 더 정확한 text-image alignment

AdaLN (Adaptive Layer Normalization)

timestep 정보를 주입하는 방법:

python
class AdaLayerNorm(nn.Module):
    def __init__(self, dim):
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.proj = nn.Linear(dim, dim * 2)

    def forward(self, x, timestep_emb):
        # timestep에서 scale, shift 예측
        scale, shift = self.proj(timestep_emb).chunk(2, dim=-1)

        # Adaptive normalization
        x = self.norm(x)
        x = x * (1 + scale) + shift

        return x

3. Rectified Flow in SD3

SD3는 DDPM 대신 Rectified Flow를 사용합니다.

왜 Rectified Flow인가?

특성DDPMRectified Flow
경로곡선 (SDE)직선 (ODE)
필요 스텝20-504-10
학습 목표노이즈 예측속도장 예측
Distillation어려움용이

SD3의 Flow Formulation

python
def flow_matching_loss(model, x0, text_emb):
    # Sample time
    t = torch.rand(x0.shape[0])

    # Sample noise
    z = torch.randn_like(x0)

    # Linear interpolation
    x_t = (1 - t) * x0 + t * z

    # Target velocity
    v_target = z - x0

    # Predict velocity
    v_pred = model(x_t, t, text_emb)

    return F.mse_loss(v_pred, v_target)

Logit-Normal Sampling

SD3의 특별한 점: timestep을 uniform이 아닌 logit-normal 분포에서 샘플링

python
def logit_normal_sample(batch_size, m=0.0, s=1.0):
    """중간 timestep에 더 집중"""
    u = torch.randn(batch_size) * s + m
    t = torch.sigmoid(u)  # (0, 1) 범위로 변환
    return t

이유: 중간 timestep이 학습에 더 중요하기 때문

4. FLUX: SD3의 진화

FLUX는 Black Forest Labs (SD3 개발자들이 설립)에서 만든 모델입니다.

FLUX vs SD3 비교

특성SD3FLUX
개발사Stability AIBlack Forest Labs
아키텍처MMDiTMMDiT (개선)
GuidanceCFG 필요Distilled (CFG-free 가능)
모델 크기2B, 8B12B
최소 스텝20-304 (schnell)

FLUX 변형들

  1. FLUX.1-pro: 최고 품질, API only
  2. FLUX.1-dev: 연구/개발용, 오픈 웨이트
  3. FLUX.1-schnell: 1-4 스텝 생성, 가장 빠름

Guidance Distillation

FLUX.1-schnell의 핵심 기술:

기존 CFG (Classifier-Free Guidance):

python
# 추론 시 2배 계산 필요
pred_uncond = model(x_t, t, null_text)
pred_cond = model(x_t, t, text)
pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

Guidance Distillation 후:

python
# 1번의 forward pass로 CFG 효과
pred = model(x_t, t, text)  # guidance가 내재화됨

학습 방법:

python
def guidance_distillation_loss(student, teacher, x_t, t, text):
    # Teacher: CFG 적용
    with torch.no_grad():
        pred_uncond = teacher(x_t, t, null_text)
        pred_cond = teacher(x_t, t, text)
        target = pred_uncond + cfg_scale * (pred_cond - pred_uncond)

    # Student: 단일 forward
    pred = student(x_t, t, text)

    return F.mse_loss(pred, target)

5. 텍스트 인코더 전략

SD3의 Triple Text Encoder

SD3는 세 가지 텍스트 인코더를 사용:

python
class TripleTextEncoder:
    def __init__(self):
        self.clip_l = CLIPTextModel("openai/clip-vit-large")
        self.clip_g = CLIPTextModel("laion/CLIP-ViT-bigG")
        self.t5 = T5EncoderModel("google/t5-v1_1-xxl")

    def encode(self, text):
        # CLIP embeddings (pooled)
        clip_l_pooled, clip_l_seq = self.clip_l(text)
        clip_g_pooled, clip_g_seq = self.clip_g(text)

        # T5 embedding (sequence only)
        t5_seq = self.t5(text)

        # Pooled: conditioning용
        pooled = torch.cat([clip_l_pooled, clip_g_pooled], dim=-1)

        # Sequence: cross-attention용
        seq = torch.cat([clip_l_seq, clip_g_seq, t5_seq], dim=1)

        return pooled, seq

왜 세 개인가?

인코더강점토큰 제한
CLIP-L일반적 시각 개념77
CLIP-G더 큰 용량77
T5-XXL긴 텍스트, 복잡한 관계512

T5의 추가로 긴 프롬프트복잡한 관계 이해가 크게 향상되었습니다.

FLUX의 텍스트 처리

FLUX는 더 단순화:

  • CLIP-L + T5-XXL 조합
  • T5에 더 의존 (더 긴 컨텍스트 활용)

6. VAE 개선

SD3의 16채널 VAE

기존 SD 1.x/2.x: 4채널 latent

SD3/FLUX: 16채널 latent

python
# SD 1.x/2.x
latent = vae.encode(image)  # (B, 4, H/8, W/8)

# SD3/FLUX
latent = vae.encode(image)  # (B, 16, H/8, W/8)

장점:

  • 더 많은 정보 보존
  • 세밀한 디테일 재구성
  • 텍스트 렌더링 품질 향상

단점:

  • 메모리 사용량 증가
  • 계산량 증가

7. 실제 사용 예시

Diffusers로 SD3 사용

python
from diffusers import StableDiffusion3Pipeline
import torch

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium",
    torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

image = pipe(
    prompt="A cat holding a sign that says 'Hello World'",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]

Diffusers로 FLUX 사용

python
from diffusers import FluxPipeline
import torch

# FLUX.1-schnell (빠른 버전)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")

image = pipe(
    prompt="A cat holding a sign that says 'Hello World'",
    num_inference_steps=4,  # 4 스텝만!
    guidance_scale=0.0,     # CFG 불필요
).images[0]

메모리 최적화

python
# CPU offload
pipe.enable_model_cpu_offload()

# Attention slicing
pipe.enable_attention_slicing()

# VAE tiling (고해상도용)
pipe.enable_vae_tiling()

8. 성능 비교

텍스트 렌더링 능력

SD3/FLUX의 가장 큰 개선점: 텍스트를 이미지에 정확히 렌더링

모델"Hello World" 정확도
SD 1.5~10%
SD 2.1~20%
SDXL~40%
SD3~80%
FLUX~90%

프롬프트 따르기

복잡한 프롬프트 테스트: "A red cube on top of a blue sphere, with a green pyramid to the left"

모델정확도
SDXL중간
SD3높음
FLUX매우 높음

생성 속도 (A100 기준)

모델스텝시간
SDXL30~3s
SD328~4s
FLUX-dev20~5s
FLUX-schnell4~1s

9. 한계와 주의사항

메모리 요구사항

모델VRAM (fp16)
SDXL~8GB
SD3-medium~12GB
FLUX-dev~24GB
FLUX-schnell~12GB

라이선스

  • SD3: Stability AI Community License (상업적 제한)
  • FLUX.1-dev: 연구/개발용 (상업적 제한)
  • FLUX.1-schnell: Apache 2.0 (상업적 사용 가능)

알려진 문제

  1. 인체 해부학: 여전히 손가락 등에서 오류 발생
  2. 텍스트 일관성: 긴 텍스트에서 가끔 오류
  3. 스타일 다양성: 특정 스타일에 편향될 수 있음

결론

특성SD 1.x/2.xSDXLSD3FLUX
아키텍처U-NetU-NetMMDiTMMDiT
Text-ImageCross-attnCross-attnJoint-attnJoint-attn
FlowDDPMDDPMRectifiedRectified
텍스트 렌더링나쁨보통좋음매우 좋음
최소 스텝20+20+20+4

SD3와 FLUX는 U-Net에서 Transformer로, DDPM에서 Rectified Flow로의 패러다임 전환을 보여줍니다. MMDiT의 양방향 attention은 text-image alignment를 크게 개선했고, Rectified Flow는 빠른 생성을 가능하게 했습니다.

References

  1. Esser, P., et al. "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" (SD3 Paper, 2024)
  2. Black Forest Labs. "FLUX.1 Technical Report" (2024)
  3. Peebles, W. & Xie, S. "Scalable Diffusion Models with Transformers" (DiT, 2023)
  4. Liu, X., et al. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow" (2023)