[Analyzing Nanochat] 4. RoPE, MQA, Flash Attention
The Secrets of Modern LLMs: RoPE, MQA, Flash Attention
**Series**: Building an LLM with nanochat - Part 4/9
The Evolution of Attention
Since the 2017 Transformer paper, attention has evolved enormously:
```
2017: Vanilla attention (Transformer)
  ↓
2018: Learned positional embeddings (GPT-1, BERT)
  ↓
2019: Multi-Query Attention (smaller KV cache)
  ↓
2021: RoPE (Rotary Position Embedding)
  ↓
2022: Flash Attention (memory-efficient, IO-aware)
  ↓
2023: Grouped-Query Attention, Flash Attention 2 (about 2x faster than v1)
```
nanochat builds on these modern techniques. Let's go through them one by one.
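Problem 1: No Position Information
The big problem with basic attention:
```python
sentence1 = ["cat", "chased", "mouse"]
sentence2 = ["mouse", "chased", "cat"]
# Without position information, self-attention cannot tell these apart!
# Permuting the input tokens only permutes the outputs:
# {cat, chased, mouse} = {mouse, chased, cat}
# Word order is ignored entirely ❌
```
Solution 1: Learned Position Embeddings (GPT-2 style)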
```python
import torch
import torch.nn as nn

class OldStyleGPT(nn.Module):
    def __init__(self, vocab_size, n_embd, max_len):
        super().__init__()
        # Token embedding
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        # Position embedding (learned): one row per position, 0..max_len-1
        self.pos_emb = nn.Embedding(max_len, n_embd)

    def forward(self, idx):
        B, T = idx.shape
        # Token embeddings
        tok_emb = self.token_emb(idx)  # (B, T, n_embd)
        # Position embeddings
        pos = torch.arange(T, device=idx.device)
        pos_emb = self.pos_emb(pos)  # (T, n_embd)
        # Add them together (broadcast over the batch)
        x = tok_emb + pos_emb  # (B, T, n_embd)
        return x
```
The problem:
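```python
# Training: max_len = 1024
# Inference: "Can I use 2048 tokens?"
# → NO! Position embeddings were only learned for positions 0..1023 ❌
model = OldStyleGPT(vocab_size=50257, n_embd=768, max_len=1024)
model.pos_emb(torch.tensor([2000]))  # Position 2000 → IndexError: index out of range
```
Solution 2: RoPE (Rotary Position Embedding) ⭐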
The core idea of RoPE: encode position as a rotation.
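A minimal sketch of this idea for a single 2D pair: rotating (x, y) with the 2D rotation matrix gives the same result as multiplying the complex number x + iy by e^{i·angle}. The per-position angle `theta = 0.5` and the helper names are hypothetical, chosen only for illustration; position m is rotated by m·θ.

```python
import cmath
import math

def rotate_matrix(x, y, angle):
    # Apply [[cos, -sin], [sin, cos]] to the column vector (x, y)
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)

def rotate_complex(x, y, angle):
    # Same rotation: multiply x + iy by e^{i * angle}
    z = complex(x, y) * cmath.exp(1j * angle)
    return (z.real, z.imag)

theta = 0.5  # hypothetical per-position angle
pair = (0.5, -0.2)
for m in range(4):  # position m is rotated by m * theta
    a = rotate_matrix(*pair, m * theta)
    b = rotate_complex(*pair, m * theta)
    # Both forms agree, and the length of the pair never changes
    print(m, [round(v, 4) for v in a], [round(v, 4) for v in b], round(math.hypot(*a), 4))
```

Because a rotation changes only the angle and never the length, position is injected without rescaling the query or key vectors.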
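Think of it as rotation in the complex plane:
```python
# Position 0: rotate by 0
# Position 1: rotate by θ
# Position 2: rotate by 2θ
# Position 3: rotate by 3θ
# ...
```
The 2D rotation matrix that rotates a pair (x, y) by angle θ:
$$
\begin{bmatrix} x' \\ y' \end{bmatrix}
=
\begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix}
\begin{bmatrix} x \\ y \end{bmatrix}
$$
RoPE Implementation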
```python
import torch

def precompute_rope_embeddings(seq_len, head_dim, base=10000):
    """
    Precompute the RoPE cos/sin tables.
    seq_len: maximum sequence length
    head_dim: attention head dimension
    base: frequency base (default 10000)
    """
    # 1. Compute the frequencies
    # Lower channel pairs = higher frequency (rotate fast)
    # Higher channel pairs = lower frequency (rotate slowly)
    channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32)
    inv_freq = 1.0 / (base ** (channel_range / head_dim))  # (head_dim//2,)
    # 2. Angle for every (position, frequency) pair
    t = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)  # (seq_len, head_dim//2)
    # 3. Store cos and sin
    cos = freqs.cos()  # (seq_len, head_dim//2)
    sin = freqs.sin()  # (seq_len, head_dim//2)
    # 4. Reshape to broadcast against (B, seq_len, n_head, head_dim//2)
    cos = cos[None, :, None, :]  # (1, seq_len, 1, head_dim//2)
    sin = sin[None, :, None, :]
    return cos, sin
```
Understanding the frequency calculation:
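```python
head_dim = 64
base = 10000
# 32 frequencies; frequency i rotates the pair (dim i, dim i + 32)
# Frequency 0: inv_freq[0] = 1 / 10000^(0/64) = 1.0
# → θ = position * 1.0 (rotates fast!)
# Frequency 24: inv_freq[24] = 1 / 10000^(48/64) = 0.001
# → θ = position * 0.001 (rotates slowly!)
# Frequency 31 (slowest): inv_freq[31] = 1 / 10000^(62/64) ≈ 0.00013
```
Why is this useful?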
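```
Position 1:   θ = 1.0 * 1 = 1.0          (fast frequency, short range)
Position 2:   θ = 1.0 * 2 = 2.0
Position 1:   θ = 0.001 * 1 = 0.001      (slow frequency, long range)
Position 100: θ = 0.001 * 100 = 0.1
# Fast (low-index) frequencies: tell neighboring positions apart (fine-grained),
#   but wrap around every 2π ≈ 6.3 positions
# Slow (high-index) frequencies: stay distinct over long distances (coarse-grained)
```
Applying RoPE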
```python
import torch

def apply_rotary_emb(x, cos, sin):
    """
    x: query or key tensor (B, seq_len, n_head, head_dim)
    cos, sin: precomputed, (1, seq_len, 1, head_dim//2)
    """
    # 1. Split into two halves: dim i is paired with dim i + d
    d = x.shape[-1] // 2
    x1 = x[..., :d]   # First half
    x2 = x[..., d:]   # Second half
    # 2. Rotate each (x1, x2) pair:
    # [x1]  →  [x1*cos - x2*sin]
    # [x2]     [x1*sin + x2*cos]
    y1 = x1 * cos - x2 * sin
    y2 = x1 * sin + x2 * cos
    # 3. Concatenate back
    out = torch.cat([y1, y2], dim=-1)
    return out
```
Example:
```
# Head dim = 4 (simplified), so d = 2
x = [0.5, 0.3, -0.2, 0.8]  # Query at position 5
# Position 5, frequency 0 (fast rotation, inv_freq = 1.0):
θ = 5 * 1.0 = 5.0
cos(5.0) ≈ 0.28
sin(5.0) ≈ -0.96
# First pair (dims 0 and 2):
x1 = 0.5, x2 = -0.2
y1 = 0.5 * 0.28 - (-0.2) * (-0.96) = 0.14 - 0.19 = -0.05
y2 = 0.5 * (-0.96) + (-0.2) * 0.28 = -0.48 - 0.056 ≈ -0.54
# Result: [-0.05, ..., -0.54, ...]
# (dims 1 and 3 use frequency 1: inv_freq = 1/10000^(2/4) = 0.01, θ = 0.05)
# The position is now encoded in the vector!
```
RoPE's magic property:
Relative position is encoded automatically!
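This property can be checked numerically. The sketch below is a standalone re-implementation of the same half-split rotation that `apply_rotary_emb` uses (the helper name `rope_rotate` is hypothetical), comparing two query/key pairs with the same offset (2) at very different absolute positions:

```python
import torch

def rope_rotate(x, pos, base=10000.0):
    # Rotate a single vector x (even length) to position `pos`:
    # dim i is paired with dim i + d, and pair i uses frequency base^(-2i/len(x))
    d = x.shape[-1] // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, 2 * d, 2, dtype=torch.float64) / (2 * d)))
    angle = pos * inv_freq
    cos, sin = angle.cos(), angle.sin()
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

torch.manual_seed(0)
q = torch.randn(64, dtype=torch.float64)
k = torch.randn(64, dtype=torch.float64)

score_5_3 = rope_rotate(q, 5) @ rope_rotate(k, 3)          # offset 2
score_105_103 = rope_rotate(q, 105) @ rope_rotate(k, 103)  # offset 2 again
print(torch.isclose(score_5_3, score_105_103).item())  # True
```

The absolute positions cancel out and only the offset survives in the score, so relative position is available to attention without a separate relative-position bias.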
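```
q_5 = rotate(q, position=5)
k_3 = rotate(k, position=3)
```
The attention score depends only on the offset between the two positions:
$$
q_5^\top k_3 = \big(R(5\theta)\,q\big)^\top R(3\theta)\,k = q^\top R(5\theta)^\top R(3\theta)\,k = q^\top R\big((3-5)\theta\big)\,k
$$
Because $R(a)^\top R(b) = R(b-a)$ for rotations (applied pair by pair, one angle per frequency), the score depends only on the relative distance 5 - 3 = 2, not on the absolute positions.
Problem 2: The KV Cache Is Too Large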
The big problem at inference time:
```
# Model: 32 layers, 32 heads, head_dim = 128
# Sequence: 2048 tokens
# KV cache size:
cache_size = 2048 * 32 * 128 * 2 (K+V) * 32 layers * 2 bytes (bf16)
           = 1,073,741,824 bytes ≈ 1.07 GB per sample!
# Batch of 16:
total = 1.07 GB * 16 ≈ 17 GB just for the KV cache! 😱
```
Solution: Multi-Query Attention (MQA)
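Core idea: share K and V across query heads!
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiQueryAttention(nn.Module):
    # n_kv_head = 1 is MQA; 1 < n_kv_head < n_head is GQA
    def __init__(self, n_embd, n_head, n_kv_head):
        super().__init__()
        assert n_embd % n_head == 0 and n_head % n_kv_head == 0
        self.n_head = n_head          # e.g. 8 heads for Q
        self.n_kv_head = n_kv_head    # e.g. 1 or 2 heads for K, V
        self.head_dim = n_embd // n_head
        # Q has full heads
        self.q_proj = nn.Linear(n_embd, n_head * self.head_dim, bias=False)
        # K, V have fewer heads
        self.k_proj = nn.Linear(n_embd, n_kv_head * self.head_dim, bias=False)
        self.v_proj = nn.Linear(n_embd, n_kv_head * self.head_dim, bias=False)
        self.o_proj = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x):
        B, T, C = x.shape
        # Compute Q, K, V
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim)
        # Repeat K, V to match the number of Q heads
        if self.n_head != self.n_kv_head:
            n_rep = self.n_head // self.n_kv_head
            k = repeat_kv(k, n_rep)  # Expand K
            v = repeat_kv(v, n_rep)  # Expand V
        # Now do attention as normal: (B, T, H, D) -> (B, H, T, D)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(y)
```
repeat_kv implementation: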
```python
def repeat_kv(x, n_rep):
    """
    x: (B, T, n_kv_head, head_dim)
    n_rep: repetition factor
    Returns: (B, T, n_kv_head * n_rep, head_dim)
    """
    if n_rep == 1:
        return x
    bs, seq_len, n_kv_heads, head_dim = x.shape
    # Insert a new dimension and expand (a view, no copy yet)
    x = x[:, :, :, None, :].expand(bs, seq_len, n_kv_heads, n_rep, head_dim)
    # Reshape to merge (this is where the copy happens)
    return x.reshape(bs, seq_len, n_kv_heads * n_rep, head_dim)
```
Example:
```python
import torch

# 8 Q heads, 2 KV heads
n_head = 8
n_kv_head = 2
B, T = 1, 16
# K shape: (B, T, 2, 128)
k = torch.randn(B, T, n_kv_head, 128)
# Repeat 4 times (8 / 2 = 4)
k_repeated = repeat_kv(k, n_head // n_kv_head)  # (B, T, 8, 128)
# Now Q and K have matching head counts!
# Q heads 0, 1, 2, 3 use K head 0
# Q heads 4, 5, 6, 7 use K head 1
```
Memory savings:
```
# Standard MHA (n_kv_head = 8), 2048 tokens, bf16:
KV cache = 2048 * 8 * 128 * 2 (K+V) * 2 bytes ≈ 8.4 MB (8 MiB) per layer
# GQA (n_kv_head = 2):
KV cache = 2048 * 2 * 128 * 2 (K+V) * 2 bytes ≈ 2.1 MB (2 MiB) per layer
# Savings: 4x smaller! 🎉
```
Performance trade-off:
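n_kv_head = n_head (standard MHA):
  - Best quality
  - Most memory
n_kv_head = n_head // 4 (GQA, Grouped-Query Attention):
  - Quality close to MHA
  - 4x less KV cache memory
n_kv_head = 1 (MQA):
  - Some quality loss
  - n_head times less KV cache memory (8x for n_head = 8)

nanochat uses n_kv_head = n_head (no sharing). Why?
- On a $100 budget, memory is still sufficient at this model size
- Best quality comes first!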
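The KV cache figures in this section all follow from one formula: 2 (K and V) × tokens × KV heads × head_dim × layers × batch × bytes per value. A small helper as a sketch (the name `kv_cache_bytes` is hypothetical; bf16, i.e. 2 bytes per value, is assumed by default):

```python
def kv_cache_bytes(seq_len, n_kv_head, head_dim, n_layer=1, batch=1, bytes_per_value=2):
    # K and V each hold seq_len * n_kv_head * head_dim values per layer per sample
    return 2 * seq_len * n_kv_head * head_dim * n_layer * batch * bytes_per_value

GB = 1e9
# 32 layers, 32 KV heads (plain MHA), head_dim 128, 2048 tokens
print(kv_cache_bytes(2048, 32, 128, n_layer=32) / GB)            # ≈ 1.07 GB per sample
print(kv_cache_bytes(2048, 32, 128, n_layer=32, batch=16) / GB)  # ≈ 17.2 GB for a batch of 16

# One layer, 8 query heads: MHA (8 KV heads) vs GQA (2) vs MQA (1)
for n_kv in (8, 2, 1):
    print(n_kv, kv_cache_bytes(2048, n_kv, 128) / 2**20, "MiB")  # 8.0, 2.0, 1.0 MiB
```

The ratio between two configurations is always n_head / n_kv_head, independent of sequence length, batch size, or layer count.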
Problem 3: Attention Needs O(T²) Memory
The memory problem of standard attention:
```
# Attention matrix
Q @ K.transpose(-2, -1) = (B, H, T, head_dim) @ (B, H, head_dim, T)
                        = (B, H, T, T)  ← the problem!
# T = 2048, batch size 1:
attention_matrix = 2048 * 2048 = 4 Mi values per head
                 = 4 Mi * 32 heads = 128 Mi values
                 = 256 MiB per layer! (bf16)
# 32 layers:
total = 256 MiB * 32 = 8 GiB just for attention! 😱
```
Solution: Flash Attention
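Flash Attention's key insight:
"Don't store the attention matrix; compute it tile by tile, and recompute it when needed (in the backward pass)."
```python
import math
import torch

# Standard Attention:
def standard_attention(Q, K, V):
    # Q, K, V: (T, d)
    d = Q.shape[-1]
    # Step 1: compute and STORE the full attention matrix
    scores = Q @ K.T / math.sqrt(d)       # (T, T) stored in memory
    attn = torch.softmax(scores, dim=-1)  # (T, T) stored in memory
    # Step 2: apply to V
    # Memory: O(T²), compute: O(T²)
    return attn @ V

# Flash Attention (plain-PyTorch illustration of what the kernel does in SRAM):
def flash_attention(Q, K, V, block_size=64):
    T, d = Q.shape
    out = torch.empty_like(Q)
    for qs in range(0, T, block_size):
        Q_tile = Q[qs:qs + block_size]  # (block_size, d), loaded into SRAM
        n = Q_tile.shape[0]
        m = torch.full((n, 1), float("-inf"), dtype=Q.dtype, device=Q.device)  # running row max
        l = torch.zeros(n, 1, dtype=Q.dtype, device=Q.device)  # running softmax denominator
        acc = torch.zeros_like(Q_tile)  # running (unnormalized) output
        for ks in range(0, T, block_size):
            K_tile = K[ks:ks + block_size]  # (block_size, d)
            V_tile = V[ks:ks + block_size]
            # Scores for this tile only
            s = Q_tile @ K_tile.T / math.sqrt(d)
            # Online softmax: a per-tile softmax alone would be wrong, so rescale
            # the earlier partial results whenever the row max grows
            m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
            p = torch.exp(s - m_new)
            scale = torch.exp(m - m_new)
            l = l * scale + p.sum(dim=-1, keepdim=True)
            acc = acc * scale + p @ V_tile
            m = m_new
        # Accumulated result (no storage of the full T×T matrix!)
        out[qs:qs + block_size] = acc / l
    # Extra memory: O(block_size²) per tile + O(T) row statistics, never O(T²)
    # Compute: still O(T²); the backward pass recomputes tiles instead of storing them
    return out
```
IO-aware design: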
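GPU memory hierarchy (approximate A100-class figures):
- HBM (High Bandwidth Memory): ~2 TB/s, 80 GB
- SRAM (on-chip memory): ~19 TB/s, ~20 MB

Standard attention:
- Stores the (T×T) matrix in HBM
- Reads and writes it through slow HBM (~2 TB/s)

Flash Attention:
- Keeps tiles in SRAM
- Fast access (~19 TB/s, roughly 9x the HBM bandwidth)
- Result: far fewer HBM reads/writes and a several-fold speedup (the FlashAttention paper reports about 3x end-to-end training speedup on GPT-2 at sequence length 1K)

Using it in PyTorch: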
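```python
import math
import torch
import torch.nn.functional as F

B, H, T, head_dim = 2, 4, 128, 64
Q, K, V = (torch.randn(B, H, T, head_dim) for _ in range(3))

# Old way (manual):
mask = torch.tril(torch.ones(T, T))  # causal mask: 1 = may attend
scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
out_manual = attn @ V

# New way (fused kernel, same result up to floating-point error):
out = F.scaled_dot_product_attention(
    Q, K, V,
    is_causal=True  # Automatic causal masking!
)
# On supported CUDA GPUs with fp16/bf16 inputs, PyTorch dispatches to a
# FlashAttention kernel (FlashAttention-2 since PyTorch 2.2); otherwise it
# falls back to the memory-efficient or plain math implementation.
```
Benchmark (illustrative figures; the GPU and model configuration are not specified):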
```
# Sequence length = 2048
Standard attention:
- Forward: 15 ms
- Memory: 8 GB
- Backward: 30 ms
Flash Attention 2:
- Forward: 4 ms (3.75x faster!)
- Memory: 2 GB (4x less!)
- Backward: 8 ms (3.75x faster!)
```
nanochat's Modern Attention
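Putting it all together:
```python
# nanochat/gpt.py (simplified for illustration)
import torch
import torch.nn as nn
import torch.nn.functional as F
# precompute_rope_embeddings, apply_rotary_emb, repeat_kv: defined earlier in this post

class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.head_dim = config.n_embd // config.n_head
        self.kv_dim = self.n_kv_head * self.head_dim
        self.dropout = getattr(config, "dropout", 0.0)
        # Combined QKV projection
        self.c_attn = nn.Linear(config.n_embd, config.n_embd + 2 * self.kv_dim, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # RoPE tables (assumes config.sequence_len = maximum context length)
        cos, sin = precompute_rope_embeddings(config.sequence_len, self.head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def forward(self, x):
        B, T, C = x.shape
        # 1. Compute Q, K, V
        q, k, v = self.c_attn(x).split([C, self.kv_dim, self.kv_dim], dim=2)
        # 2. Reshape for multi-head: (B, T, n_heads, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim)
        k = k.view(B, T, self.n_kv_head, self.head_dim)
        v = v.view(B, T, self.n_kv_head, self.head_dim)
        # 3. Apply RoPE (position encoding) in the (B, T, H, D) layout
        cos, sin = self.cos[:, :T], self.sin[:, :T]
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        # 4. Repeat K, V if using MQA/GQA
        if self.n_kv_head != self.n_head:
            k = repeat_kv(k, self.n_head // self.n_kv_head)
            v = repeat_kv(v, self.n_head // self.n_kv_head)
        # 5. Flash Attention! (SDPA expects (B, H, T, D))
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        y = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=self.dropout if self.training else 0.0
        )
        # 6. Reshape back and project
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
```
Additional Optimizations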
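QK Normalization
```python
import torch.nn.functional as F

# Problem: q·k can have large variance
# → attention logits blow up → unstable training
# Solution: normalize q and k
q = F.normalize(q, dim=-1)  # L2 normalization: each head vector gets length 1
k = F.normalize(k, dim=-1)
# Now q·k is a cosine similarity, bounded in [-1, 1]
# → stable training!
# Caveat: after the usual 1/sqrt(head_dim) scaling, such small logits give nearly
# uniform attention, so L2 QK norm is typically paired with a (learned) scale.
# RMSNorm, which nanochat applies to q and k, instead keeps each vector's RMS
# at 1, giving |q·k| <= head_dim.
```
Grouped Query Attention (GQA)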
A middle ground between MHA and MQA:
```python
n_head = 32      # Q heads
n_kv_head = 8    # KV heads (each shared by 4 Q heads)
# Better quality than MQA
# Smaller KV cache than MHA
```
Used by Llama 2 (70B) and Mistral 7B, among others!
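GQA does not actually require copying K and V. The sketch below (helper names are hypothetical; tensors use the (B, heads, T, head_dim) layout that `F.scaled_dot_product_attention` expects) folds each group of query heads onto its shared KV head with `einsum`, then checks the result against the repeat-then-attend approach (here via `repeat_interleave`, which orders heads the same way as `repeat_kv`):

```python
import math
import torch
import torch.nn.functional as F

def gqa_attention_grouped(q, k, v):
    # q: (B, n_head, T, D); k, v: (B, n_kv_head, T, D)
    B, H, T, D = q.shape
    n_kv = k.shape[1]
    group = H // n_kv
    # Consecutive query heads g*group .. g*group + group - 1 share KV head g
    qg = q.view(B, n_kv, group, T, D)
    scores = torch.einsum("bhgtd,bhsd->bhgts", qg, k) / math.sqrt(D)
    causal = torch.tril(torch.ones(T, T)).bool()
    scores = scores.masked_fill(~causal, float("-inf"))
    out = torch.einsum("bhgts,bhsd->bhgtd", scores.softmax(dim=-1), v)
    return out.reshape(B, H, T, D)

def gqa_attention_repeated(q, k, v):
    # Reference: materialize n_head / n_kv_head copies of each KV head, then attend
    group = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

torch.manual_seed(0)
q = torch.randn(2, 32, 16, 8)  # 32 query heads
k = torch.randn(2, 8, 16, 8)   # 8 KV heads
v = torch.randn(2, 8, 16, 8)
print(torch.allclose(gqa_attention_grouped(q, k, v), gqa_attention_repeated(q, k, v), atol=1e-5))
```

Attention kernels with native GQA support can use the same grouping, so the shared K/V never has to be materialized n_head / n_kv_head times.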
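Overall Performance
Cumulative, illustrative figures (the hardware and model configuration are not specified):
```
# Baseline (vanilla Transformer, 2017):
Training time: 10 hours
Memory: 80 GB (full)
Sequence length: 512 (max)
# + RoPE:
Training time: 10 hours
Memory: 80 GB
Sequence length: no hard cap from a position table
  (quality past the training length still degrades without tricks such as position interpolation)
# + MQA (n_kv_head = 4; mainly shrinks the inference KV cache):
Training time: 10 hours
Memory: 60 GB (-25%)
Sequence length: same as above
# + Flash Attention:
Training time: 4 hours (-60%!)
Memory: 40 GB (-50%!)
Sequence length: 2048+
# nanochat d20:
Training time: 4 hours
Memory: 40 GB
Cost: $100
CORE score: 40 (comparable to GPT-2!)
```
Key Takeaways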
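✅ RoPE = rotation-based position encoding
- Encodes relative position automatically
- No fixed-size position table (extrapolating past the training length still needs care)
- No learned position embeddings needed
✅ MQA = shared K, V
- KV cache shrinks by n_head / n_kv_head (4-8x in the examples above)
- Optimizes inference memory
- Small quality loss (GQA keeps quality close to MHA)
✅ Flash Attention = tiling + recomputation
- O(T²) → O(T) memory
- Several-fold speedup
- Exploits fast on-chip SRAM
✅ Combined = the modern LLM standard
- Used by open models such as Llama 2 and Mistral
- nanochat implements them too!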
Next Steps
In Part 5, "Assembling the Complete GPT Model", we will cover:
- Feed-Forward Network (ReLU²)
- RMSNorm vs LayerNorm
- Residual connections
- The complete Transformer block
- Integrating the full model
Now that we've mastered attention, let's assemble the remaining components! 🚀
---
📘 Note
This post is based on Andrej Karpathy's open-source nanochat project and was written to analyze and explain its code structure and training process. The original code is distributed under the MIT License. All code examples are simplified for educational purposes.
References:
- [RoPE paper](https://arxiv.org/abs/2104.09864)
- [Flash Attention paper](https://arxiv.org/abs/2205.14135)
- [GQA paper](https://arxiv.org/abs/2305.13245)
Tags: #RoPE #FlashAttention #MQA #Optimization #nanochat #ModernLLM