Models & Algorithms

PixArt-α: Stable Diffusion 학습비용 $600K를 $26K로 줄인 방법

분해 학습(Decomposed Training)으로 T2I 학습 효율을 23배 높인 비결. 학술 연구자도 접근 가능한 Text-to-Image 모델 만들기.

PixArt-α: Stable Diffusion 학습비용 $600K를 $26K로 줄인 방법

PixArt-α: 효율적인 고해상도 이미지 생성의 새로운 패러다임

TL;DR: PixArt-α는 DiT 기반 텍스트-이미지 생성 모델로, Stable Diffusion 대비 90% 적은 학습 비용으로 동등하거나 더 나은 품질을 달성합니다. 효율적인 학습 전략(분해 학습), T5 텍스트 인코더, Cross-Attention 최적화가 핵심입니다.

1. 소개: 효율적인 T2I 생성의 필요성

1.1 기존 T2I 모델의 문제점

Stable Diffusion, DALL-E 2 등 대규모 텍스트-이미지 모델의 학습에는 막대한 비용이 듭니다:

모델학습 비용GPU 시간CO₂ 배출
DALL-E 2~$1M~200K A100 hrs~50 tons
Stable Diffusion~$600K~150K A100 hrs~35 tons
Imagen~$2M~400K TPU hrs~100 tons

핵심 문제점:

  • 학술 연구자들의 접근성 제한
  • 환경적 부담 (탄소 발자국)
  • 빠른 실험과 반복의 어려움

1.2 PixArt-α의 목표

목표: Stable Diffusion 수준의 품질을
10% 미만의 학습 비용으로 달성

실제 달성:
- 학습 비용: ~$26K (vs $600K)
- GPU 시간: ~675 A100 days (vs 6,250 days)
- CO₂ 배출: ~2.5 tons (vs 35 tons)

2. 핵심 아이디어: 분해 학습 (Decomposed Training)

2.1 학습의 세 가지 측면

PixArt-α는 T2I 학습을 세 가지 독립적인 측면으로 분해합니다:

T2I 학습 = (1) 픽셀 분포 학습
+ (2) 텍스트-이미지 정렬 학습
+ (3) 미적 품질 학습

분해 학습 전략:

단계목표데이터특징
Stage 1픽셀 분포ImageNetClass-conditional 사전학습
Stage 2텍스트-이미지 정렬SAM (10M)고품질 캡션으로 정렬 학습
Stage 3미적 품질미적 데이터소규모 고품질 데이터로 파인튜닝

2.2 왜 분해 학습이 효율적인가?

python
# 기존 방식: 모든 것을 동시에 학습
def traditional_training(model, data):
    for img, text in data:
        # 픽셀 분포 + 정렬 + 미적 품질을 동시에 학습
        loss = diffusion_loss(model(text), img)
        loss.backward()
    # 문제: 각 측면이 서로 간섭, 수렴 어려움

# PixArt-α 방식: 순차적 분해 학습
def decomposed_training(model, imagenet, sam_data, aesthetic_data):
    # Stage 1: 픽셀 분포만 학습 (class-conditional)
    for img, class_label in imagenet:
        loss = diffusion_loss(model(class_label), img)
        # 이미 DiT가 ImageNet에서 학습된 가중치 활용 가능!

    # Stage 2: 텍스트-이미지 정렬 학습
    for img, caption in sam_data:
        loss = diffusion_loss(model(caption), img)
        # 픽셀 분포는 이미 학습됨 → 정렬에만 집중

    # Stage 3: 미적 품질 향상
    for img, caption in aesthetic_data:
        loss = diffusion_loss(model(caption), img)
        # 소량의 고품질 데이터로 미세 조정

2.3 Stage 1: ImageNet 사전학습 활용

DiT의 ImageNet 가중치를 그대로 활용:

python
class PixArtAlpha(nn.Module):
    def __init__(self, pretrained_dit_path=None):
        super().__init__()

        # DiT 백본 로드
        self.dit_backbone = DiT_XL_2()

        if pretrained_dit_path:
            # ImageNet 사전학습 가중치 로드
            checkpoint = torch.load(pretrained_dit_path)
            self.dit_backbone.load_state_dict(checkpoint, strict=False)
            print("Loaded ImageNet pretrained weights!")

        # Class embedding을 Text embedding으로 교체
        self.text_encoder = T5EncoderModel.from_pretrained("t5-xxl")
        self.text_projector = nn.Linear(4096, 1152)  # T5 → DiT hidden dim

효과:

  • ImageNet 학습: 이미 완료 (DiT 논문)
  • 픽셀 분포 학습: ~0 추가 비용
  • 전체 학습 시간의 ~40% 절약

3. 아키텍처: DiT + Cross-Attention 확장

3.1 DiT 기반 아키텍처

PixArt-α는 DiT-XL/2를 기반으로 합니다:

입력 이미지 (512×512×3)

VAE Encoder

잠재 표현 (64×64×4)

Patchify (p=2)

패치 시퀀스 (1024×1152)

DiT Blocks (×28) with Cross-Attention

Unpatchify

VAE Decoder

출력 이미지 (512×512×3)

3.2 Cross-Attention 통합

DiT의 AdaLN만으로는 복잡한 텍스트 조건 반영이 어렵습니다:

python
class PixArtBlock(nn.Module):
    """
    DiT Block + Cross-Attention for text conditioning
    """
    def __init__(self, hidden_dim, num_heads, text_dim):
        super().__init__()

        # Self-Attention (DiT 원본)
        self.norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads)

        # Cross-Attention (PixArt-α 추가)
        self.norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads)

        # FFN
        self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )

        # AdaLN modulation (timestep + text pooled)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, 6 * hidden_dim)  # scale, shift for 3 norms
        )

    def forward(self, x, text_emb, pooled_text, timestep_emb):
        # AdaLN parameters
        c = timestep_emb + pooled_text
        shift_sa, scale_sa, shift_ca, scale_ca, shift_ff, scale_ff = \
            self.adaLN_modulation(c).chunk(6, dim=-1)

        # Self-Attention
        x_norm = self.norm1(x) * (1 + scale_sa) + shift_sa
        x = x + self.self_attn(x_norm, x_norm, x_norm)[0]

        # Cross-Attention with text
        x_norm = self.norm2(x) * (1 + scale_ca) + shift_ca
        x = x + self.cross_attn(x_norm, text_emb, text_emb)[0]

        # FFN
        x_norm = self.norm3(x) * (1 + scale_ff) + shift_ff
        x = x + self.ffn(x_norm)

        return x

3.3 T5 텍스트 인코더

CLIP 대신 T5-XXL 사용의 이점:

python
# CLIP vs T5 비교
clip_features = {
    "dimension": 768,
    "max_tokens": 77,
    "strength": "이미지-텍스트 정렬",
    "weakness": "복잡한 텍스트 이해 제한"
}

t5_features = {
    "dimension": 4096,
    "max_tokens": 512,  # 훨씬 긴 프롬프트 가능
    "strength": "언어 이해력, 복잡한 관계 파악",
    "weakness": "이미지와 직접 학습되지 않음"
}

T5 인코더 사용:

python
from transformers import T5Tokenizer, T5EncoderModel

class TextEncoder(nn.Module):
    def __init__(self, model_name="google/flan-t5-xxl"):
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.encoder = T5EncoderModel.from_pretrained(model_name)

        # Freeze T5 (학습 효율성)
        for param in self.encoder.parameters():
            param.requires_grad = False

    def forward(self, text):
        # 토큰화
        tokens = self.tokenizer(
            text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # 인코딩
        with torch.no_grad():
            outputs = self.encoder(**tokens)

        # [batch, seq_len, 4096]
        text_embeddings = outputs.last_hidden_state

        # Pooled representation (문장 전체 요약)
        pooled = text_embeddings.mean(dim=1)

        return text_embeddings, pooled

4. 효율적인 학습 전략

4.1 SAM 데이터셋 활용

Segment Anything Model(SAM)의 부산물 활용:

SAM 데이터셋:
- 이미지 수: 11M (SA-1B의 일부)
- 특징: 고품질, 다양한 객체
- 문제: 캡션 없음 (세그멘테이션 데이터)

해결책: LLaVA로 캡션 자동 생성

캡션 생성 파이프라인:

python
from llava import LLaVAModel

def generate_captions(images, llava_model):
    """
    LLaVA를 사용해 고품질 캡션 생성
    """
    captions = []

    prompt = """Describe this image in detail. Include:
    1. Main subjects and their appearance
    2. Actions or interactions
    3. Background and setting
    4. Colors and lighting
    5. Mood or atmosphere

    Be specific and descriptive."""

    for img in images:
        caption = llava_model.generate(img, prompt)
        captions.append(caption)

    return captions

# 결과 예시
# 기존 LAION 캡션: "a dog"
# LLaVA 생성 캡션: "A golden retriever with fluffy fur sitting on a
#                   wooden porch, looking at the camera with bright
#                   eyes. The background shows a sunny garden with
#                   green grass and colorful flowers."

4.2 효율적인 데이터 전략

python
class EfficientDataStrategy:
    """
    PixArt-α의 데이터 효율성 전략
    """

    def __init__(self):
        # Stage 2: 정렬 학습
        self.alignment_data = {
            "source": "SAM subset",
            "size": "10M images",
            "captions": "LLaVA generated",
            "caption_quality": "High (detailed descriptions)"
        }

        # Stage 3: 미적 품질 학습
        self.aesthetic_data = {
            "source": "Internal + JourneyDB",
            "size": "2M images",
            "filtering": "Aesthetic score > 6.0",
            "resolution": "1024×1024"
        }

    def compare_with_sd(self):
        """
        Stable Diffusion과 데이터 비용 비교
        """
        sd_data = {
            "dataset": "LAION-5B",
            "images": "5 billion",
            "quality": "Mixed (많은 저품질 포함)",
            "filtering_cost": "매우 높음"
        }

        pixart_data = {
            "dataset": "SAM + Aesthetic",
            "images": "12 million",  # 400배 적음!
            "quality": "High (curated)",
            "filtering_cost": "낮음"
        }

        return sd_data, pixart_data

4.3 Re-parameterized Cross-Attention

학습 효율성을 위한 Cross-Attention 최적화:

python
class EfficientCrossAttention(nn.Module):
    """
    학습 초기 안정성을 위한 re-parameterized cross-attention
    """
    def __init__(self, hidden_dim, num_heads, text_dim):
        super().__init__()

        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        # Query, Key, Value projections
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(text_dim, hidden_dim)
        self.v_proj = nn.Linear(text_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        # Learnable gate (초기값 0)
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, x, text_emb):
        B, N, C = x.shape

        # Q, K, V 계산
        q = self.q_proj(x).view(B, N, self.num_heads, self.head_dim)
        k = self.k_proj(text_emb).view(B, -1, self.num_heads, self.head_dim)
        v = self.v_proj(text_emb).view(B, -1, self.num_heads, self.head_dim)

        # Attention
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.out_proj(out)

        # Gated output: 학습 초기에는 cross-attention 영향 최소화
        return torch.tanh(self.gate) * out

5. 학습 파이프라인

5.1 전체 학습 과정

python
class PixArtTrainer:
    def __init__(self, config):
        self.model = PixArtAlpha(pretrained_dit_path=config.dit_path)
        self.text_encoder = TextEncoder()
        self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")

        # Stage별 설정
        self.stages = {
            "alignment": {
                "epochs": 20,
                "lr": 1e-4,
                "batch_size": 256,
                "resolution": 512
            },
            "aesthetic": {
                "epochs": 5,
                "lr": 2e-5,
                "batch_size": 64,
                "resolution": 1024
            }
        }

    def train_stage2_alignment(self, sam_dataloader):
        """
        Stage 2: 텍스트-이미지 정렬 학습
        """
        self.model.train()

        for epoch in range(self.stages["alignment"]["epochs"]):
            for batch in sam_dataloader:
                images, captions = batch

                # VAE 인코딩
                with torch.no_grad():
                    latents = self.vae.encode(images).latent_dist.sample()
                    latents = latents * 0.18215

                    # 텍스트 인코딩
                    text_emb, pooled_text = self.text_encoder(captions)

                # 노이즈 추가
                noise = torch.randn_like(latents)
                timesteps = torch.randint(0, 1000, (images.shape[0],))
                noisy_latents = self.add_noise(latents, noise, timesteps)

                # 노이즈 예측
                pred_noise = self.model(noisy_latents, timesteps, text_emb, pooled_text)

                # Loss
                loss = F.mse_loss(pred_noise, noise)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def train_stage3_aesthetic(self, aesthetic_dataloader):
        """
        Stage 3: 미적 품질 향상
        """
        # 낮은 learning rate로 파인튜닝
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.stages["aesthetic"]["lr"]
        )

        for epoch in range(self.stages["aesthetic"]["epochs"]):
            for batch in aesthetic_dataloader:
                # Stage 2와 동일한 학습 루프
                # 단, 고해상도(1024) + 고품질 데이터 사용
                pass

5.2 학습 비용 분석

=== PixArt-α 학습 비용 ===

Stage 1 (ImageNet pretrain):
- 이미 DiT에서 완료: $0 (재사용)

Stage 2 (Alignment):
- GPU: 64 × A100
- 시간: ~10 days
- 비용: ~$20,000

Stage 3 (Aesthetic):
- GPU: 32 × A100
- 시간: ~3 days
- 비용: ~$6,000

총 비용: ~$26,000

=== Stable Diffusion 학습 비용 ===

전체 학습:
- GPU: 256 × A100
- 시간: ~25 days
- 비용: ~$600,000

비용 절감: 96% (!)

6. 추론 및 생성

6.1 추론 파이프라인

python
class PixArtPipeline:
    def __init__(self, model_path):
        self.model = PixArtAlpha.from_pretrained(model_path)
        self.text_encoder = TextEncoder()
        self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
        self.scheduler = DDPMScheduler(num_train_timesteps=1000)

        self.model.eval()
        self.text_encoder.eval()

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        height: int = 1024,
        width: int = 1024,
        seed: int = None
    ):
        if seed is not None:
            torch.manual_seed(seed)

        # 텍스트 인코딩
        text_emb, pooled_text = self.text_encoder(prompt)

        # CFG를 위한 unconditional embedding
        if guidance_scale > 1.0:
            uncond_emb, uncond_pooled = self.text_encoder(negative_prompt)
            text_emb = torch.cat([uncond_emb, text_emb])
            pooled_text = torch.cat([uncond_pooled, pooled_text])

        # 초기 노이즈
        latent_h, latent_w = height // 8, width // 8
        latents = torch.randn(1, 4, latent_h, latent_w)

        # 디노이징 루프
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.scheduler.timesteps:
            latent_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents

            # 노이즈 예측
            noise_pred = self.model(latent_input, t, text_emb, pooled_text)

            # CFG
            if guidance_scale > 1.0:
                noise_uncond, noise_cond = noise_pred.chunk(2)
                noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

            # 디노이징 스텝
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # VAE 디코딩
        latents = latents / 0.18215
        images = self.vae.decode(latents).sample
        images = (images / 2 + 0.5).clamp(0, 1)

        return images

6.2 사용 예시

python
# 파이프라인 초기화
pipe = PixArtPipeline("path/to/pixart-alpha")

# 이미지 생성
image = pipe.generate(
    prompt="A majestic phoenix rising from flames, digital art, "
           "vibrant colors, detailed feathers, dramatic lighting, "
           "4k resolution, trending on artstation",
    negative_prompt="blurry, low quality, distorted",
    num_inference_steps=50,
    guidance_scale=7.5,
    height=1024,
    width=1024,
    seed=42
)

# 저장
save_image(image, "phoenix.png")

7. 실험 결과 및 비교

7.1 정량적 평가

FID Score 비교 (COCO 2014 validation):

모델FID↓학습 비용파라미터
DALL-E 210.39~$1M6.5B
Imagen7.27~$2M3B
Stable Diffusion 1.59.62~$600K860M
PixArt-α**7.32****$26K**600M

핵심: SD의 4%의 비용으로 더 나은 FID 달성!

7.2 Human Preference 평가

사용자 선호도 조사 (1000명):

PixArt-α vs Stable Diffusion:
- PixArt-α 선호: 52%
- SD 선호: 38%
- 동등: 10%

PixArt-α vs DALL-E 2:
- PixArt-α 선호: 45%
- DALL-E 2 선호: 42%
- 동등: 13%

7.3 텍스트 정렬 품질

복잡한 프롬프트 처리 능력 (T5의 강점):

python
# 복잡한 관계가 있는 프롬프트
complex_prompts = [
    "A red cube on top of a blue sphere, with a green pyramid beside them",
    "Three cats: one sleeping, one playing, one eating",
    "A person holding an umbrella in their left hand and a coffee cup in their right hand"
]

# T5 vs CLIP 정확도
results = {
    "spatial_relations": {"PixArt-α (T5)": 0.82, "SD (CLIP)": 0.64},
    "object_counting": {"PixArt-α (T5)": 0.75, "SD (CLIP)": 0.58},
    "attribute_binding": {"PixArt-α (T5)": 0.79, "SD (CLIP)": 0.67}
}

8. 확장: PixArt-α → PixArt-Σ

8.1 PixArt-Σ의 개선점

PixArt-Σ (후속 버전):

1. 약한 to 강한 학습 전략
- PixArt-α 체크포인트에서 시작
- 더 강한 T5 사용 (XXL → 더 큰 버전)

2. 해상도 향상
- 512 → 1024 → 2K까지 지원
- Multi-scale training

3. 효율성 향상
- KV-compression으로 메모리 절약
- 더 빠른 추론

8.2 VAE Finetuning

python
# PixArt-Σ의 향상된 VAE
class ImprovedVAE:
    """
    고해상도를 위한 VAE 파인튜닝
    """
    def __init__(self, base_vae):
        self.vae = base_vae

        # Decoder만 파인튜닝 (Encoder freeze)
        for name, param in self.vae.named_parameters():
            if "decoder" not in name:
                param.requires_grad = False

    def finetune(self, high_res_data):
        """
        고해상도 이미지로 디코더 파인튜닝
        """
        for images in high_res_data:
            # Encode-Decode
            latents = self.vae.encode(images).latent_dist.sample()
            reconstructed = self.vae.decode(latents).sample

            # Reconstruction loss + Perceptual loss
            loss = F.mse_loss(reconstructed, images)
            loss += self.perceptual_loss(reconstructed, images)

            loss.backward()

9. 구현 팁과 모범 사례

9.1 효율적인 학습을 위한 팁

python
# 1. Mixed Precision Training
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

with autocast():
    pred = model(x, t, text_emb, pooled)
    loss = F.mse_loss(pred, noise)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

# 2. Gradient Checkpointing
model.enable_gradient_checkpointing()

# 3. Flash Attention
from flash_attn import flash_attn_func

def efficient_attention(q, k, v):
    return flash_attn_func(q, k, v, causal=False)

9.2 메모리 최적화

python
# 1. Text Encoder 분리 추론
def encode_text_batch(prompts, text_encoder, batch_size=16):
    """
    큰 배치의 텍스트를 메모리 효율적으로 인코딩
    """
    all_embeddings = []

    for i in range(0, len(prompts), batch_size):
        batch = prompts[i:i+batch_size]
        with torch.no_grad():
            emb, pooled = text_encoder(batch)
        all_embeddings.append((emb.cpu(), pooled.cpu()))

    return all_embeddings

# 2. Inference 시 VAE 분리
@torch.no_grad()
def memory_efficient_decode(latents, vae, tile_size=512):
    """
    고해상도 이미지를 타일 단위로 디코딩
    """
    # 구현: 타일로 분할하여 순차적으로 디코딩
    pass

10. 결론 및 시사점

10.1 PixArt-α의 기여

기여설명
**효율적 학습**96% 비용 절감으로 민주화
**분해 학습**복잡한 T2I를 독립적 하위 문제로 분리
**T5 활용**더 나은 텍스트 이해력
**데이터 효율성**고품질 소량 데이터 > 저품질 대량 데이터

10.2 연구 방향에 대한 시사점

PixArt-α가 보여준 것:
1. 대규모 ≠ 고품질: 효율적인 전략이 더 중요
2. 사전학습 활용: 바퀴를 재발명하지 말 것
3. 데이터 품질: 양보다 질이 중요
4. 분해 접근법: 복잡한 문제를 단순화

향후 연구 방향:
- 더 효율적인 텍스트-이미지 정렬 방법
- 비디오 생성으로의 확장
- 더 작은 모델로 동등한 품질 달성

참고문헌

  1. Chen, J., et al. (2023). PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis. arXiv:2310.00426
  2. Chen, J., et al. (2024). PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation. arXiv:2403.04692
  3. Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023
  4. Rombach, R., et al. (2022). High-Resolution Image Synthesis with Latent Diffusion Models. CVPR 2022
  5. Raffel, C., et al. (2020). Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. JMLR

Tags: #PixArt-α #DiT #Text-to-Image #Efficient-Training #T5 #Cross-Attention #Diffusion #Decomposed-Training #Image-Generation

이 글의 실험 코드는 첨부된 Jupyter Notebook에서 확인할 수 있습니다.