[Analyzing nanochat] 4. RoPE, MQA, Flash Attention

Read time: 2 minutes

The Secrets of Modern LLMs: RoPE, MQA, Flash Attention

**Series**: Building an LLM with nanochat - Part 4/9

The Evolution of Attention

In the seven years since the 2017 Transformer paper, attention has come a long way:

  • 2017: Vanilla Attention
  • 2019: Learned Positional Embeddings (GPT-2)
  • 2021: RoPE (Rotary Position Embedding)
  • 2022: Flash Attention (memory efficient)
  • 2023: Multi-Query / Grouped-Query Attention (KV cache optimization)
  • 2023: Flash Attention 2 (2x faster!)

nanochat uses all of these modern techniques. Let's go through them one by one.

Problem 1: No Position Information

The big problem with vanilla attention:

  • sentence1 = ["cat", "chased", "mouse"]
    sentence2 = ["mouse", "chased", "cat"]
    
    # Attention cannot tell these two apart!
    # {cat, chased, mouse} = {mouse, chased, cat}
    # Token order is ignored entirely ❌
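
You can verify this in a few lines; a minimal sketch without a causal mask, using arbitrary random weights:

  • import torch
    import torch.nn.functional as F
    
    torch.manual_seed(0)
    x = torch.randn(3, 8)                      # 3 "tokens", 8-dim embeddings, no position info
    W_q, W_k, W_v = (torch.randn(8, 8) for _ in range(3))
    
    def attend(x):
        q, k, v = x @ W_q, x @ W_k, x @ W_v
        att = F.softmax(q @ k.T / 8 ** 0.5, dim=-1)
        return att @ v
    
    perm = torch.tensor([2, 1, 0])             # reverse the "sentence"
    
    # Permuting the input only permutes the output rows the same way:
    # attention sees a set of tokens, not a sequence.
    print(torch.allclose(attend(x)[perm], attend(x[perm]), atol=1e-6))  # True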

Solution 1: Learned Position Embeddings (the GPT-2 approach)

  • import torch
    import torch.nn as nn
    
    class OldStyleGPT(nn.Module):
        def __init__(self, vocab_size, n_embd, max_len):
            super().__init__()
            # Token embedding
            self.token_emb = nn.Embedding(vocab_size, n_embd)
    
            # Position embedding (learned)
            self.pos_emb = nn.Embedding(max_len, n_embd)
    
        def forward(self, idx):
            B, T = idx.shape
    
            # Token embeddings
            tok_emb = self.token_emb(idx)  # (B, T, n_embd)
    
            # Position embeddings
            pos = torch.arange(T, device=idx.device)
            pos_emb = self.pos_emb(pos)  # (T, n_embd)
    
            # Add them together
            x = tok_emb + pos_emb  # (B, T, n_embd)
            return x

The problem:

  • # Training: max_len = 1024
    # Inference: "Can I use 2048 tokens?"
    # → NO! Position embeddings were only learned up to 1024 ❌
    
    # Position 2000:
    pos_emb[2000]  # Error: index out of range!

Solution 2: RoPE (Rotary Position Embedding) ⭐

RoPE's core idea: encode position as rotation.

Think of rotation in the complex plane:

  • # Position 0: rotate by 0
    # Position 1: rotate by θ
    # Position 2: rotate by 2θ
    # Position 3: rotate by 3θ
    # ...

The 2D rotation matrix:

  • [cos(θ)  -sin(θ)]   [x]
    [sin(θ)   cos(θ)]   [y]

Implementing RoPE

  • def precompute_rope_embeddings(seq_len, head_dim, base=10000):
        """
        RoPE 사전 계산
        seq_len: 최대 시퀀스 길이
        head_dim: Attention head 차원
        base: Frequency base (default 10000)
        """
        # 1. Frequency 계산
        # 낮은 차원 = 높은 주파수 (빠르게 회전)
        # 높은 차원 = 낮은 주파수 (천천히 회전)
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
    
        # 2. Angle for each position
        t = torch.arange(seq_len, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)  # (seq_len, head_dim//2)
    
        # 3. Store cos and sin
        cos = freqs.cos()  # (seq_len, head_dim//2)
        sin = freqs.sin()  # (seq_len, head_dim//2)
    
        # 4. Reshape for broadcasting against (B, T, n_head, head_dim//2)
        cos = cos[None, :, None, :]  # (1, seq_len, 1, head_dim//2)
        sin = sin[None, :, None, :]
    
        return cos, sin

Understanding the frequency computation:

  • head_dim = 64
    base = 10000
    
    # Pair 0 (dims 0 and 32): inv_freq[0] = 1 / (10000^(0/64)) = 1.0
    # → θ = position * 1.0 (rotates quickly!)
    
    # Pair 31 (dims 31 and 63): inv_freq[31] = 1 / (10000^(62/64)) ≈ 0.00013
    # → θ = position * 0.00013 (rotates slowly!)

Why is this useful?

  • # Fast pair (inv_freq = 1.0):
    Position 1: θ = 1.0 * 1 = 1.0 (short-range)
    Position 2: θ = 1.0 * 2 = 2.0
    
    # Slow pair (inv_freq ≈ 0.00013):
    Position 1:   θ = 0.00013 * 1   = 0.00013 (long-range)
    Position 100: θ = 0.00013 * 100 = 0.013
    
    # Low-index pairs:  distinguish adjacent positions (fine-grained)
    # High-index pairs: encode even distant positions without wrapping around (coarse-grained)

Applying RoPE

  • def apply_rotary_emb(x, cos, sin):
        """
        x: query or key tensor of shape (B, T, n_head, head_dim)
        cos, sin: precomputed tables of shape (1, T, 1, head_dim//2)
        """
        # 1. Split into two halves
        d = x.shape[-1] // 2
        x1 = x[..., :d]   # First half
        x2 = x[..., d:]   # Second half
    
        # 2. Apply rotation
        # [x1]   →  [x1*cos - x2*sin]
        # [x2]      [x1*sin + x2*cos]
        y1 = x1 * cos - x2 * sin
        y2 = x1 * sin + x2 * cos
    
        # 3. Concatenate back
        out = torch.cat([y1, y2], dim=-1)
        return out
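
To see the two helpers working together, here is a small sanity check (assuming both functions above are in scope; the sizes are arbitrary):

  • import torch
    
    B, T, n_head, head_dim = 2, 16, 4, 64
    q = torch.randn(B, T, n_head, head_dim)
    
    cos, sin = precompute_rope_embeddings(seq_len=T, head_dim=head_dim)
    q_rot = apply_rotary_emb(q, cos, sin)
    
    print(q_rot.shape)  # torch.Size([2, 16, 4, 64]) -- same shape, positions now baked in
    
    # Rotations preserve vector norms, so attention magnitudes are unaffected:
    print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True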

Example:

  • # Head dim = 4 (simplified)
    x = [0.5, 0.3, -0.2, 0.8]  # Query at position 5
    
    # Position 5, frequency 0 (fast rotation):
    θ = 5 * 1.0 = 5.0
    cos(5.0) = 0.28
    sin(5.0) = -0.96
    
    # First pair (dims 0, 2):
    x1 = 0.5, x2 = -0.2
    y1 = 0.5 * 0.28 - (-0.2) * (-0.96) = 0.14 - 0.19 = -0.05
    y2 = 0.5 * (-0.96) + (-0.2) * 0.28 = -0.48 - 0.056 = -0.54
    
    # Result: [-0.05, ..., -0.54, ...]
    # The position information is now encoded!

RoPE's magic property:

Relative position is encoded automatically!

  • q_5 = rotate(q, position=5)
    k_3 = rotate(k, position=3)
    
    # Attention score:
    # q_5 · k_3 = rotate(q, 5θ) · rotate(k, 3θ)
    #           = q · rotate(k, (3-5)θ)   # only the angle difference survives
    
    # The relative distance (5 - 3 = 2) is encoded automatically!
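
You can check this numerically by reusing the two RoPE helpers from above: the score between a query and a key depends only on their relative offset, not on their absolute positions (the positions and sizes below are arbitrary):

  • import torch
    
    head_dim = 64
    cos, sin = precompute_rope_embeddings(seq_len=128, head_dim=head_dim)
    
    q = torch.randn(1, 1, 1, head_dim)   # a single query vector
    k = torch.randn(1, 1, 1, head_dim)   # a single key vector
    
    def score(q_pos, k_pos):
        # Rotate q as if at position q_pos and k as if at position k_pos, then dot them
        q_rot = apply_rotary_emb(q, cos[:, q_pos:q_pos+1], sin[:, q_pos:q_pos+1])
        k_rot = apply_rotary_emb(k, cos[:, k_pos:k_pos+1], sin[:, k_pos:k_pos+1])
        return (q_rot * k_rot).sum().item()
    
    print(score(5, 3), score(105, 103))   # same offset (2) → (nearly) identical scores
    print(score(5, 4))                    # different offset → different score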

Problem 2: The KV Cache Is Too Big

A big problem at inference time:

  • # Model: 32 layers, 32 heads, head_dim=128
    # Sequence: 2048 tokens
    
    # KV cache size:
    cache_size = 2048 * 32 * 128 * 2 (K+V) * 32 layers * 2 bytes (bf16)
               = 1.07 GB per sample!
    
    # Batch 16:
    total = 1.07 * 16 = 17 GB just for KV cache! 😱
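
To make that arithmetic concrete, here is a tiny hypothetical helper (the function name is mine, not nanochat's) that reproduces those numbers:

  • def kv_cache_bytes(seq_len, n_layers, n_kv_head, head_dim, batch=1, bytes_per_val=2):
        """Size of the K and V caches in bytes (bf16 by default)."""
        per_layer = seq_len * n_kv_head * head_dim * 2 * bytes_per_val  # K and V
        return per_layer * n_layers * batch
    
    print(kv_cache_bytes(2048, 32, 32, 128) / 1e9)            # ≈ 1.07 GB per sample
    print(kv_cache_bytes(2048, 32, 32, 128, batch=16) / 1e9)  # ≈ 17 GB for a batch of 16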

Solution: Multi-Query Attention (MQA)

The key idea: share K and V across the query heads!

  • class MultiQueryAttention(nn.Module):
        def __init__(self, n_embd, n_head, n_kv_head):
            super().__init__()
            self.n_head = n_head          # 8 heads for Q
            self.n_kv_head = n_kv_head    # 1 or 2 heads for K, V
    
            self.head_dim = n_embd // n_head
    
            # Q has the full number of heads
            self.q_proj = nn.Linear(n_embd, self.n_head * self.head_dim, bias=False)
    
            # K, V have fewer heads
            self.k_proj = nn.Linear(n_embd, self.n_kv_head * self.head_dim, bias=False)
            self.v_proj = nn.Linear(n_embd, self.n_kv_head * self.head_dim, bias=False)
    
        def forward(self, x):
            B, T, C = x.shape
    
            # Compute Q, K, V
            q = self.q_proj(x).view(B, T, self.n_head, self.head_dim)
            k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim)
            v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim)
    
            # Repeat K, V to match Q heads
            if self.n_head != self.n_kv_head:
                n_rep = self.n_head // self.n_kv_head
                k = repeat_kv(k, n_rep)  # Expand K
                v = repeat_kv(v, n_rep)  # Expand V
    
            # Now do attention as normal
            # ...

Implementing repeat_kv:

  • def repeat_kv(x, n_rep):
        """
        x: (B, T, n_kv_head, head_dim)
        n_rep: repetition factor
        Returns: (B, T, n_kv_head * n_rep, head_dim)
        """
        if n_rep == 1:
            return x
    
        bs, seq_len, n_kv_heads, head_dim = x.shape
    
        # Insert new dimension and expand
        x = x[:, :, :, None, :].expand(bs, seq_len, n_kv_heads, n_rep, head_dim)
    
        # Reshape to merge
        return x.reshape(bs, seq_len, n_kv_heads * n_rep, head_dim)

Example:

  • # 8 Q heads, 2 KV heads
    n_head = 8
    n_kv_head = 2
    
    # K shape: (B, T, 2, 128)
    k = compute_keys(x)
    
    # Repeat 4 times (8 / 2 = 4)
    k_repeated = repeat_kv(k, 4)  # (B, T, 8, 128)
    
    # Now Q and K have matching heads!
    # Q heads 0,1,2,3 use K head 0
    # Q heads 4,5,6,7 use K head 1
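
And a runnable version of that check, assuming the repeat_kv defined above is in scope:

  • import torch
    
    B, T, n_kv_head, head_dim = 1, 4, 2, 8
    k = torch.randn(B, T, n_kv_head, head_dim)
    
    k_rep = repeat_kv(k, n_rep=4)
    print(k_rep.shape)  # torch.Size([1, 4, 8, 8])
    
    # Q heads 0-3 all see a copy of KV head 0; Q heads 4-7 see KV head 1
    print(torch.equal(k_rep[:, :, 0], k[:, :, 0]))  # True
    print(torch.equal(k_rep[:, :, 3], k[:, :, 0]))  # True
    print(torch.equal(k_rep[:, :, 4], k[:, :, 1]))  # True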

Memory savings:

  • # Standard MHA (n_kv_head = 8):
    KV cache = 2048 * 8 * 128 * 2 (K+V) * 2 bytes ≈ 8 MB per layer
    
    # MQA (n_kv_head = 2):
    KV cache = 2048 * 2 * 128 * 2 (K+V) * 2 bytes ≈ 2 MB per layer
    
    # Savings: 4x smaller! 🎉

The quality/memory trade-off:

  • n_kv_head = n_head (standard MHA):
      - best quality
      - most memory
    
    n_kv_head = n_head // 4 (GQA - Grouped-Query Attention):
      - ~95% of the quality
      - 4x less KV-cache memory
    
    n_kv_head = 1 (MQA):
      - ~90% of the quality
      - 8x less KV-cache memory (with 8 Q heads)

nanochat uses n_kv_head = n_head (no KV sharing). Why?
- At the $100 budget, memory is not yet the bottleneck
- Maximum quality comes first!

Problem 3: Attention Needs O(T²) Memory

The memory problem with standard attention:

  • # Attention matrix
    Q @ K.T = (B, H, T, head_dim) @ (B, H, head_dim, T)
            = (B, H, T, T)  ← this is the problem!
    
    # T = 2048:
    attention_matrix = 2048 * 2048 = 4M values per head
                     = 4M * 32 heads = 128M values
                     = 256 MB per layer! (bf16)
    
    # 32 layers:
    total = 256 MB * 32 = 8 GB just for attention! 😱

Solution: Flash Attention

Flash Attention's key insight:

"Don't store the attention matrix; recompute it in tiles, on the fly."

  • # Standard Attention:
    def standard_attention(Q, K, V):
        # Step 1: Compute and STORE full attention
        scores = Q @ K.T / sqrt(d)  # (T, T) stored in memory
        attn = softmax(scores)       # (T, T) stored in memory
    
        # Step 2: Apply to V
        out = attn @ V
    
        # Memory:  O(T²)  (the full T×T matrix is materialized)
        # Compute: O(T²)
        return out
  • # Flash Attention:
    def flash_attention(Q, K, V):
        # Process in tiles
        for block_q in range(num_blocks):
            for block_kv in range(num_blocks):
                # Load small tiles to SRAM
                Q_tile = Q[block_q]  # (block_size, d)
                K_tile = K[block_kv]  # (block_size, d)
                V_tile = V[block_kv]
    
                # Compute attention for this tile
                scores_tile = Q_tile @ K_tile.T
                # (the real kernel computes softmax "online": a running max and
                #  running denominator are carried across the KV tiles)
                attn_tile = softmax(scores_tile)
                out_tile = attn_tile @ V_tile
    
                # Accumulate the result
                # (the full T×T matrix is never stored!)
    
        # Memory: O(block_size²) << O(T²)
        # Compute: O(T²) but recompute in backward
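
The pseudocode hand-waves the "accumulate" step, so here is a sketch that fills it in with the actual online-softmax bookkeeping, in plain PyTorch. It is purely illustrative and single-head: the real Flash Attention is a fused CUDA kernel that also tiles over queries and handles the causal mask.

  • import torch
    
    def tiled_attention(Q, K, V, block=32):
        """Single-head attention computed tile by tile with an online softmax."""
        T, d = Q.shape
        scale = d ** -0.5
        out = torch.zeros_like(Q)
        row_max = torch.full((T, 1), float("-inf"))
        row_sum = torch.zeros(T, 1)
    
        for j in range(0, T, block):
            Kj, Vj = K[j:j+block], V[j:j+block]
            scores = (Q @ Kj.T) * scale                          # only a (T, block) tile
            new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
            correction = torch.exp(row_max - new_max)            # rescale earlier partial results
            p = torch.exp(scores - new_max)
            row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
            out = out * correction + p @ Vj
            row_max = new_max
    
        return out / row_sum                                     # normalize at the very end
    
    # Matches standard attention without ever materializing the T×T matrix:
    T, d = 128, 64
    Q, K, V = torch.randn(T, d), torch.randn(T, d), torch.randn(T, d)
    ref = torch.softmax((Q @ K.T) / d ** 0.5, dim=-1) @ V
    print(torch.allclose(tiled_attention(Q, K, V), ref, atol=1e-5))  # True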

IO-aware design:

  • GPU Memory Hierarchy (A100 figures):
    HBM (High Bandwidth Memory): ~2 TB/s, 40-80 GB
    SRAM (on-chip memory): ~19 TB/s, ~20 MB
    
    Standard Attention:
    - Store (T×T) in HBM
    - Slow access (2 TB/s)
    
    Flash Attention:
    - Keep tiles in SRAM
    - Fast access (19 TB/s, 9x faster!)
    - Result: 3-4x speedup!

Using it from PyTorch:

  • # Old way (manual):
    scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
    scores = scores.masked_fill(mask == 0, float('-inf'))
    attn = F.softmax(scores, dim=-1)
    out = attn @ V
    
    # New way (Flash Attention built-in):
    out = F.scaled_dot_product_attention(
        Q, K, V,
        is_causal=True  # Automatic causal masking!
    )
    
    # PyTorch automatically uses Flash Attention 2 if available!
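
As a quick sanity check that the two paths compute the same thing (a small sketch with arbitrary shapes):

  • import math
    import torch
    import torch.nn.functional as F
    
    B, H, T, d = 2, 4, 16, 32
    Q, K, V = (torch.randn(B, H, T, d) for _ in range(3))
    
    # Manual attention with an explicit causal mask
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d)
    scores = scores.masked_fill(~mask, float('-inf'))
    manual = F.softmax(scores, dim=-1) @ V
    
    # Built-in SDPA (dispatches to Flash / memory-efficient kernels when it can)
    fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
    
    print(torch.allclose(manual, fused, atol=1e-5))  # True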

Benchmark:

  • # Sequence length = 2048
    
    Standard Attention:
    - Forward: 15 ms
    - Memory: 8 GB
    - Backward: 30 ms
    
    Flash Attention 2:
    - Forward: 4 ms (3.75x faster!)
    - Memory: 2 GB (4x less!)
    - Backward: 8 ms (3.75x faster!)

nanochat's Modern Attention

Putting it all together:

  • # nanochat/gpt.py (simplified)
    class Attention(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.n_head = config.n_head
            self.n_kv_head = config.n_kv_head
            self.head_dim = config.n_embd // config.n_head
    
            # Combined QKV projection
            self.c_attn = nn.Linear(
                config.n_embd,
                config.n_embd + 2 * self.n_kv_head * self.head_dim,
                bias=False
            )
    
            # RoPE (a small helper that caches the cos/sin tables,
            # conceptually the same as precompute_rope_embeddings above)
            self.rope = RotaryEmbedding(self.head_dim)
    
        def forward(self, x):
            B, T, C = x.shape
    
            # 1. Compute Q, K, V
            q, k, v = self.c_attn(x).split(
                [self.n_head * self.head_dim,
                 self.n_kv_head * self.head_dim,
                 self.n_kv_head * self.head_dim], dim=2)
    
            # 2. Reshape for multi-head, keeping (B, T, heads, head_dim) for now
            q = q.view(B, T, self.n_head, self.head_dim)
            k = k.view(B, T, self.n_kv_head, self.head_dim)
            v = v.view(B, T, self.n_kv_head, self.head_dim)
    
            # 3. Apply RoPE (position encoding)
            freqs_cos, freqs_sin = self.rope(T)
            q = apply_rotary_emb(q, freqs_cos, freqs_sin)
            k = apply_rotary_emb(k, freqs_cos, freqs_sin)
    
            # 4. Repeat K, V if using MQA/GQA (shapes are still (B, T, heads, head_dim))
            if self.n_kv_head != self.n_head:
                k = repeat_kv(k, self.n_head // self.n_kv_head)
                v = repeat_kv(v, self.n_head // self.n_kv_head)
    
            # Move heads in front of the time dimension: (B, n_head, T, head_dim)
            q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    
            # 5. Flash Attention!
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    
            # 6. Reshape back
            y = y.transpose(1, 2).contiguous().view(B, T, C)
    
            return y
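
If you want to run something end to end without the real nanochat classes, the helpers defined earlier in this post (precompute_rope_embeddings, apply_rotary_emb, repeat_kv) are enough. The following is a rough functional sketch with made-up sizes and untrained weights:

  • import torch
    import torch.nn.functional as F
    
    B, T, n_head, n_kv_head, head_dim = 2, 32, 8, 2, 64
    n_embd = n_head * head_dim
    x = torch.randn(B, T, n_embd)
    
    # Separate projections here for clarity (nanochat fuses them into c_attn)
    q_proj = torch.nn.Linear(n_embd, n_head * head_dim, bias=False)
    k_proj = torch.nn.Linear(n_embd, n_kv_head * head_dim, bias=False)
    v_proj = torch.nn.Linear(n_embd, n_kv_head * head_dim, bias=False)
    
    q = q_proj(x).view(B, T, n_head, head_dim)
    k = k_proj(x).view(B, T, n_kv_head, head_dim)
    v = v_proj(x).view(B, T, n_kv_head, head_dim)
    
    # RoPE on Q and K
    cos, sin = precompute_rope_embeddings(T, head_dim)
    q = apply_rotary_emb(q, cos, sin)
    k = apply_rotary_emb(k, cos, sin)
    
    # GQA: repeat K, V to match the Q heads, then move heads to dim 1
    k = repeat_kv(k, n_head // n_kv_head)
    v = repeat_kv(v, n_head // n_kv_head)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    
    # Flash/memory-efficient attention via the built-in SDPA
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    y = y.transpose(1, 2).contiguous().view(B, T, n_embd)
    print(y.shape)  # torch.Size([2, 32, 512])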

Additional Optimizations

QK Normalization

  • # Problem: Q·K can have large variance
    # → unstable training
    
    # Solution: Normalize Q and K
    q = F.normalize(q, dim=-1)
    k = F.normalize(k, dim=-1)
    
    # Now Q·K is bounded: [-1, 1]
    # → stable training!

Grouped Query Attention (GQA)

A middle ground between MHA and MQA:

  • n_head = 32      # Q heads
    n_kv_head = 8    # KV heads (4x sharing)
    
    # Better than MQA (quality)
    # Smaller than MHA (memory)

Llama 2, Mistral, and others use GQA!

Performance Summary

  • # Baseline (vanilla Transformer, 2017):
    Training time: 10 hours
    Memory: 80 GB (full)
    Sequence length: 512 (max)
    
    # + RoPE:
    Training time: 10 hours
    Memory: 80 GB
    Sequence length: no hard limit (no learned position table)
    
    # + MQA (n_kv_head = 4):
    Training time: 10 hours
    Memory: 60 GB (-25%)
    Sequence length: no hard limit
    
    # + Flash Attention:
    Training time: 4 hours (-60%!)
    Memory: 40 GB (-50%!)
    Sequence length: 2048+
    
    # nanochat d20:
    Training time: 4 hours
    Memory: 40 GB
    Cost: $100
    CORE score: 40 (comparable to GPT-2!)

Key Takeaways

RoPE = rotation-based position encoding
- Encodes relative positions automatically
- Supports arbitrary sequence lengths
- No learned position embeddings needed

MQA = shared K and V
- 4-8x smaller KV cache
- Optimizes inference memory
- Minimal quality loss

Flash Attention = tiling + recomputation
- O(T²) → O(T) memory
- 3-4x speedup
- Exploits fast on-chip SRAM

Combined = the modern LLM standard
- Used by Llama 2, Mistral, and (reportedly) GPT-4
- Implemented in nanochat too!

What's Next

Part 5, "Assembling the Complete GPT Model", will cover:
- The feed-forward network (ReLU²)
- RMSNorm vs LayerNorm
- Residual connections
- The complete Transformer block
- Putting the whole model together

Now that you've mastered attention, let's assemble the remaining components! 🚀

---

📘 Notes

This post is based on Andrej Karpathy's open-source nanochat project and was written to analyze and explain its code structure and training process. The original code is distributed under the MIT License. All code examples are simplified for educational purposes.

References:
- [RoPE paper](https://arxiv.org/abs/2104.09864)
- [FlashAttention paper](https://arxiv.org/abs/2205.14135)
- [GQA paper](https://arxiv.org/abs/2305.13245)

Tags: #RoPE #FlashAttention #MQA #Optimization #nanochat #ModernLLM