Models & Algorithms

DiT: U-Net 버리고 Transformer 쓰니까 Scaling Law가 적용됐다 (Sora 기반기술)

U-Net은 크기 키워도 성능 향상이 수확체감. DiT는 모델이 클수록 일관되게 좋아집니다. Sora의 기반이 된 아키텍처 완전 분석.

DiT: U-Net 버리고 Transformer 쓰니까 Scaling Law가 적용됐다 (Sora 기반기술)

DiT: Diffusion Transformer, U-Net을 넘어선 새로운 패러다임

Blog Image
TL;DR: DiT는 Diffusion 모델의 backbone을 U-Net에서 Vision Transformer로 교체합니다. Scaling law가 적용되어 모델이 커질수록 성능이 일관되게 향상됩니다. Sora의 기반 기술입니다.

1. U-Net의 한계

1.1 왜 U-Net이었나?

DDPM부터 Stable Diffusion까지, 모든 주요 Diffusion 모델이 U-Net을 사용한 이유:

  1. Skip Connections: 고해상도 정보 보존
  2. Multi-scale Processing: 다양한 해상도의 특징 추출
  3. Proven Architecture: 세그멘테이션에서 검증됨

1.2 U-Net의 문제점

하지만 U-Net에는 근본적인 한계가 있습니다:

1. Scaling이 어려움

U-Net channels ↑ → 파라미터 ∝ channels²
연산량이 quadratically 증가

2. Inductive Bias

  • CNN의 local connectivity 가정
  • 전역적 정보 처리에 비효율적
  • Attention 블록으로 보완하지만 완벽하지 않음

3. 비일관적 Scaling

U-Net SizeParametersFID 개선
Small100Mbaseline
Medium400M-15%
Large900M-8%
XL2B-3%

수확 체감 현상 발생

1.3 Transformer의 가능성

반면 Vision Transformer는:

  • 일관된 Scaling: 크기에 비례하여 성능 향상
  • 전역 처리: Self-attention으로 모든 패치 간 관계 학습
  • 검증된 Scaling Law: GPT, LLaMA에서 증명됨

2. DiT 아키텍처

2.1 핵심 아이디어

"Diffusion 모델에서 U-Net을 Vision Transformer로 대체하자"

2.2 입력 처리: Patchify

이미지(또는 latent)를 패치로 분할:

python
class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch_size=2, in_channels=4, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)

    def forward(self, x):
        # x: (B, C, H, W) → (B, num_patches, embed_dim)
        x = self.proj(x)  # (B, embed_dim, H', W')
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        return x

예시 (Stable Diffusion latent):

  • 입력: 64×64×4 latent
  • 패치 크기: 2×2
  • 패치 수: 32×32 = 1024개
  • 각 패치: 2×2×4 = 16 → embed_dim으로 projection

2.3 Condition 주입: AdaLN

기존 U-Net: Time embedding을 ResBlock에 더함

DiT: Adaptive Layer Normalization (AdaLN)

AdaLN(h,y)=ysLayerNorm(h)+yb\text{AdaLN}(h, y) = y_s \odot \text{LayerNorm}(h) + y_b

여기서 ys,yby_s, y_b는 condition에서 생성된 scale/shift 파라미터

python
class AdaLN(nn.Module):
    def __init__(self, hidden_size, condition_dim):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(condition_dim, 6 * hidden_size)
        )

    def forward(self, x, c):
        # c: condition (time + class embedding)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(c).chunk(6, dim=-1)

        # 각 블록에서 사용
        return shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp

2.4 DiT Block

python
class DiTBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, condition_dim, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.attn = Attention(hidden_size, num_heads)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.mlp = MLP(hidden_size, int(hidden_size * mlp_ratio))

        # AdaLN modulation
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(condition_dim, 6 * hidden_size)
        )

    def forward(self, x, c):
        # c: condition embedding
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(c).chunk(6, dim=-1)

        # Self-attention with AdaLN
        x_norm = self.norm1(x)
        x_norm = x_norm * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        x = x + gate_msa.unsqueeze(1) * self.attn(x_norm)

        # MLP with AdaLN
        x_norm = self.norm2(x)
        x_norm = x_norm * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(x_norm)

        return x

2.5 출력 처리: Unpatchify

패치들을 다시 이미지로 재구성:

python
class FinalLayer(nn.Module):
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = self.linear(x)
        return x

def unpatchify(x, patch_size, img_size):
    """
    x: (B, num_patches, patch_size²×C)
    → (B, C, H, W)
    """
    p = patch_size
    h = w = img_size // p
    c = x.shape[-1] // (p * p)

    x = x.reshape(x.shape[0], h, w, p, p, c)
    x = torch.einsum('nhwpqc->nchpwq', x)
    x = x.reshape(x.shape[0], c, h * p, w * p)
    return x

3. 전체 DiT 모델

3.1 모델 정의

python
class DiT(nn.Module):
    def __init__(
        self,
        input_size=32,          # Latent size
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        num_classes=1000,       # For class conditioning
        learn_sigma=True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads

        # Patch embedding
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size)
        num_patches = self.x_embedder.num_patches

        # Time embedding
        self.t_embedder = TimestepEmbedder(hidden_size)

        # Class embedding
        self.y_embedder = LabelEmbedder(num_classes, hidden_size)

        # Positional embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size))

        # DiT blocks
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, hidden_size, mlp_ratio)
            for _ in range(depth)
        ])

        # Final layer
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)

        self.initialize_weights()

    def forward(self, x, t, y):
        """
        x: (B, C, H, W) noisy latent
        t: (B,) timesteps
        y: (B,) class labels
        """
        # Patchify + positional embedding
        x = self.x_embedder(x) + self.pos_embed

        # Condition embedding
        t = self.t_embedder(t)
        y = self.y_embedder(y)
        c = t + y  # Combined condition

        # DiT blocks
        for block in self.blocks:
            x = block(x, c)

        # Final layer
        x = self.final_layer(x, c)

        # Unpatchify
        x = unpatchify(x, self.patch_size, self.input_size)

        return x

3.2 모델 Variants

ModelLayersHidden SizeHeadsParameters
DiT-S12384633M
DiT-B1276812130M
DiT-L24102416458M
DiT-XL28115216675M

4. Scaling Law 분석

4.1 모델 크기 vs 성능

DiT 논문의 핵심 발견:

DiT: FID vs 모델 크기 (ImageNet 256×256)

일관된 개선: 파라미터 증가에 따라 FID가 지속적으로 감소

4.2 Compute vs 성능

Compute (GFLOPs)DiT FIDU-Net FID
5043.552.3
10025.131.8
20012.418.6
5004.99.3

같은 연산량에서 DiT가 더 효율적

4.3 Scaling Law Formula

경험적으로 발견된 관계:

FIDACα\text{FID} \approx A \cdot C^{-\alpha}

여기서:

  • CC: 연산량 (GFLOPs)
  • α0.4\alpha \approx 0.4 for DiT
  • α0.25\alpha \approx 0.25 for U-Net

DiT의 scaling exponent가 더 큼 → 스케일링에 더 유리

5. 학습 및 샘플링

5.1 학습 코드

python
def train_step(model, vae, x, y, noise_scheduler):
    # 1. Encode to latent
    with torch.no_grad():
        z = vae.encode(x).latent_dist.sample() * 0.18215

    # 2. Sample timestep
    t = torch.randint(0, noise_scheduler.num_train_timesteps, (z.shape[0],))

    # 3. Add noise
    noise = torch.randn_like(z)
    z_t = noise_scheduler.add_noise(z, noise, t)

    # 4. Predict noise (or v-prediction)
    model_output = model(z_t, t, y)

    # 5. Compute loss
    if model.learn_sigma:
        noise_pred, _ = model_output.chunk(2, dim=1)
    else:
        noise_pred = model_output

    loss = F.mse_loss(noise_pred, noise)

    return loss

5.2 Classifier-Free Guidance

python
@torch.no_grad()
def sample(model, vae, num_samples, num_classes, cfg_scale=4.0, num_steps=250):
    # Random class labels
    y = torch.randint(0, num_classes, (num_samples,))
    y_null = torch.full_like(y, num_classes)  # Null class

    # Initial noise
    z = torch.randn(num_samples, 4, 32, 32)

    # Sampling loop
    for t in tqdm(reversed(range(num_steps))):
        t_batch = torch.full((num_samples,), t)

        # CFG: predict both conditional and unconditional
        z_input = torch.cat([z, z], dim=0)
        t_input = torch.cat([t_batch, t_batch], dim=0)
        y_input = torch.cat([y, y_null], dim=0)

        model_output = model(z_input, t_input, y_input)
        eps_cond, eps_uncond = model_output.chunk(2, dim=0)

        # Guidance
        eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond)

        # DDPM step
        z = ddpm_step(z, eps, t)

    # Decode
    z = z / 0.18215
    images = vae.decode(z).sample

    return images

6. DiT의 응용

6.1 Sora (OpenAI)

Sora는 DiT를 비디오 생성으로 확장:

Video DiT:
- 입력: 3D latent (T × H × W × C)
- Patchify: Spacetime patches
- Attention: Spatial + Temporal
- Output: Video frames

핵심 변경점:

  • 2D patches → 3D patches
  • 2D positional encoding → 3D positional encoding
  • Cross-frame attention 추가

6.2 Flux (Black Forest Labs)

Flux는 DiT를 T2I에 최적화:

  • MMDiT (Multimodal DiT): Text-image joint attention
  • Rectified Flow: 더 빠른 샘플링
  • 더 큰 스케일: 12B 파라미터

6.3 PixArt 시리즈

PixArt-α, PixArt-Σ:

  • 효율적인 학습 (10% 비용)
  • T5 text encoder 사용
  • Class-to-text 전이 학습

7. 구현 최적화

7.1 Flash Attention

python
from flash_attn import flash_attn_func

class FlashAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)

        # Flash Attention
        out = flash_attn_func(q, k, v)

        out = out.reshape(B, N, C)
        return self.proj(out)

7.2 Gradient Checkpointing

python
class DiTWithCheckpoint(DiT):
    def forward(self, x, t, y):
        x = self.x_embedder(x) + self.pos_embed
        c = self.t_embedder(t) + self.y_embedder(y)

        # Gradient checkpointing for memory efficiency
        for block in self.blocks:
            x = checkpoint(block, x, c, use_reentrant=False)

        x = self.final_layer(x, c)
        return unpatchify(x, self.patch_size, self.input_size)

7.3 Mixed Precision Training

python
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    with autocast():
        loss = train_step(model, vae, batch['image'], batch['label'])

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

8. 실험 결과

8.1 ImageNet 256×256

ModelFID ↓IS ↑Parameters
ADM10.94100.98554M
LDM-410.56103.49400M
**DiT-XL/2****2.27****278.24**675M

DiT-XL이 SOTA 달성!

8.2 ImageNet 512×512

ModelFID ↓Parameters
ADM-G7.72608M
**DiT-XL/2****3.04**675M

8.3 Scaling 실험

ModelGFLOPsFID
DiT-S/2668.4
DiT-B/22343.5
DiT-L/28023.3
DiT-XL/21199.6
DiT-XL/2 (더 긴 학습)1192.27

9. 결론

DiT는 Diffusion 모델의 새로운 시대를 열었습니다:

  1. Scalable Architecture: Transformer의 scaling law 활용
  2. 일관된 성능 향상: 크기에 비례하는 품질
  3. 범용성: 이미지, 비디오, 3D 등 다양한 modality 지원
  4. 효율성: 같은 연산량에서 더 좋은 성능

Sora, Flux 등 최신 생성 모델들이 DiT 기반인 이유입니다.

다음 글에서는 PixArt-α를 다룹니다: DiT를 효율적으로 학습하는 방법과 T5 text encoder의 활용.

References

  1. Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023
  2. Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021
  3. Chen, J., et al. (2023). PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis. arXiv
  4. Esser, P., et al. (2024). Scaling Rectified Flow Transformers for High-Resolution Image Synthesis. arXiv

Tags: #DiT #Diffusion-Transformer #Scaling-Law #Sora #Vision-Transformer #이미지생성

이 글의 전체 코드는 첨부된 Jupyter Notebook에서 확인할 수 있습니다.