[Nanochat 분석하기] 3. Self-Attention 완벽 이해

Read time: 2 minutes

Transformer의 핵심: Self-Attention 완벽 이해

**시리즈**: 나노챗(nanochat)으로 배우는 LLM 구축 - Part 3/9

Attention이란 무엇인가?

2017년 Google의 "Attention Is All You Need" 논문은 AI 역사를 바꿨습니다. Attention 메커니즘의 핵심 아이디어는 놀랍도록 단순합니다:

"모든 단어가 다른 모든 단어를 볼 수 있게 하자"

예문을 봅시다:

  • "The cat sat on the mat because it was tired"

    "it"이 무엇을 가리킬까요? 사람은 "cat"을 가리킨다는 걸 압니다. Attention은 모델이 이런 연결을 학습하게 합니다.

    Baseline: Bigram Model (Attention 없이)

    Attention의 가치를 이해하려면, 먼저 Attention 없는 모델을 봐야 합니다.

  • class BigramLanguageModel(nn.Module):
        def __init__(self, vocab_size):
            super().__init__()
            # 모델 전체가 이것뿐!
            self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
    
        def forward(self, idx):
            # idx: [batch, seq_len]
            logits = self.token_embedding_table(idx)  # [batch, seq_len, vocab_size]
            return logits

    문제점:

  • text = "The cat sat on the mat"
    # Bigram model:
    # P("cat" | "The") - only看 "The"
    # P("sat" | "cat") - only看 "cat"
    # P("on" | "sat") - only看 "sat"
    
    # Context가 1개 토큰으로 제한됨! ❌

    "The cat"을 보고 "sat"을 예측하는 것과 "A dog"를 보고 "sat"을 예측하는 것이 다른데, Bigram은 이를 구분 못합니다.

    Self-Attention: 모든 것을 보기

    Self-Attention은 전체 문맥을 봅니다:

  • def simple_attention_example():
        sentence = ["The", "cat", "sat"]
    
        # "sat"을 예측할 때:
        # - "The": 0.1 가중치 (약간 참고)
        # - "cat": 0.7 가중치 (주로 참고!)
        # - "sat": 0.2 가중치 (자기 자신)
    
        # Weighted average:
        # representation_of_sat = 0.1*emb("The") + 0.7*emb("cat") + 0.2*emb("sat")

    이 가중치를 어떻게 계산하느냐가 Attention의 핵심입니다!

    Query, Key, Value: 도서관 비유

    Attention을 이해하는 가장 좋은 방법은 도서관 비유입니다:

    Query (Q): "나는 무엇을 찾고 있는가?"

  • # "it" 토큰의 Query
    query_it = "What animal is this referring to?"

    Key (K): "나는 무슨 정보를 제공하는가?"

  • # "cat" 토큰의 Key
    key_cat = "I provide information about animals"
    
    # "the" 토큰의 Key
    key_the = "I provide information about articles"

    Value (V): "실제로 전달할 정보"

  • # "cat" 토큰의 Value
    value_cat = [vector representing "cat" meaning]

    Attention 수식: 단계별 분해

    완전한 수식은:

    $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

    복잡해 보이지만 단계별로 나누면 간단합니다.

    Step 1: QKV 계산

  • class SingleHeadAttention(nn.Module):
        def __init__(self, n_embd, head_size):
            super().__init__()
            self.query = nn.Linear(n_embd, head_size, bias=False)
            self.key = nn.Linear(n_embd, head_size, bias=False)
            self.value = nn.Linear(n_embd, head_size, bias=False)
    
        def forward(self, x):
            # x: [batch, seq_len, n_embd]
            B, T, C = x.shape
    
            q = self.query(x)  # [B, T, head_size]
            k = self.key(x)    # [B, T, head_size]
            v = self.value(x)  # [B, T, head_size]

    예시:

  • x = embeddings  # [2, 4, 64] (batch=2, seq_len=4, n_embd=64)
    
    # 각 토큰마다 Q, K, V 계산
    q = W_q @ x  # [2, 4, 16] (head_size=16)
    k = W_k @ x  # [2, 4, 16]
    v = W_v @ x  # [2, 4, 16]

    Step 2: Attention Scores (QK^T)

  •         # Attention scores
            scores = q @ k.transpose(-2, -1)  # [B, T, T]

    이게 핵심입니다! Dot product가 유사도를 계산합니다:

  • # "it" (position 3) query
    q[3] = [0.2, 0.5, -0.3, 0.8, ...]
    
    # "cat" (position 1) key
    k[1] = [0.3, 0.4, -0.2, 0.9, ...]
    
    # Similarity (dot product)
    similarity = q[3] · k[1]
               = 0.2*0.3 + 0.5*0.4 + (-0.3)*(-0.2) + 0.8*0.9 + ...
               = 1.12  # High score! "it"는 "cat"과 관련 있음

    Step 3: Scaling (1/√d_k)

  •         # Scale by sqrt(head_size)
            scores = scores / math.sqrt(self.head_size)

    왜 스케일링?

  • # Without scaling:
    head_size = 64
    q · k = sum of 64 multiplications
    
    # 예시:
    # q = [1, 1, 1, ..., 1] (64 ones)
    # k = [1, 1, 1, ..., 1]
    # q · k = 64  ← 너무 큼!
    
    # After softmax:
    scores = [64, 5, 10]
    softmax(scores) = [0.9999, 0.0000, 0.0001]
    # Gradient vanishes! ❌
    
    # With scaling:
    scores = [64, 5, 10] / sqrt(64)
           = [8, 0.625, 1.25]
    softmax(scores) = [0.997, 0.001, 0.002]
    # Better gradients! ✅

    Step 4: Causal Masking

    LLM은 미래를 보면 안 됩니다!

  •         # Causal mask
            mask = torch.tril(torch.ones(T, T))  # Lower triangular
            scores = scores.masked_fill(mask == 0, float('-inf'))

    Mask 시각화:

  • T = 4
    mask = torch.tril(torch.ones(4, 4))
    
    # mask:
    # [[1, 0, 0, 0],   1 = 볼 수 있음
    #  [1, 1, 0, 0],   0 = 볼 수 없음 → -inf
    #  [1, 1, 1, 0],
    #  [1, 1, 1, 1]]
    
    # scores after masking:
    # [[5.2, -inf, -inf, -inf],   ← Position 0은 자기만 봄
    #  [3.1,  4.5, -inf, -inf],   ← Position 1은 0, 1 봄
    #  [2.8,  3.9,  5.1, -inf],   ← Position 2는 0, 1, 2 봄
    #  [4.2,  6.1,  3.7,  4.9]]   ← Position 3은 모두 봄

    Step 5: Softmax

  •         # Softmax to get probabilities
            attn = F.softmax(scores, dim=-1)  # [B, T, T]

    -inf는 확률 0이 됩니다:

  • scores = [[5.2, -inf, -inf, -inf]]
    
    after softmax:
    attn = [[1.0, 0.0, 0.0, 0.0]]
    # Position 0은 100% 자기 자신만 attend!

    Step 6: Value와 곱하기

  •         # Weighted sum of values
            output = attn @ v  # [B, T, head_size]
            return output

    최종 계산:

  • # For "it" at position 3:
    attn[3] = [0.05, 0.70, 0.10, 0.15]  # Attention weights
    #          The   cat   sat   on
    
    v[0] = value_of_The
    v[1] = value_of_cat
    v[2] = value_of_sat
    v[3] = value_of_on
    
    output[3] = 0.05 * v[0] + 0.70 * v[1] + 0.10 * v[2] + 0.15 * v[3]
    #           ← "cat"의 정보가 70%!

    완전한 구현

  • class SingleHeadAttention(nn.Module):
        def __init__(self, n_embd, head_size, block_size):
            super().__init__()
            self.head_size = head_size
    
            # Q, K, V projections
            self.query = nn.Linear(n_embd, head_size, bias=False)
            self.key = nn.Linear(n_embd, head_size, bias=False)
            self.value = nn.Linear(n_embd, head_size, bias=False)
    
            # Causal mask (not a parameter)
            self.register_buffer(
                'tril',
                torch.tril(torch.ones(block_size, block_size))
            )
    
        def forward(self, x):
            B, T, C = x.shape
    
            # 1. Compute Q, K, V
            q = self.query(x)  # (B, T, head_size)
            k = self.key(x)    # (B, T, head_size)
            v = self.value(x)  # (B, T, head_size)
    
            # 2. Compute attention scores
            scores = q @ k.transpose(-2, -1)  # (B, T, T)
    
            # 3. Scale
            scores = scores / math.sqrt(self.head_size)
    
            # 4. Apply causal mask
            scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
    
            # 5. Softmax
            attn = F.softmax(scores, dim=-1)  # (B, T, T)
    
            # 6. Weighted sum of values
            out = attn @ v  # (B, T, head_size)
    
            return out

    Multi-Head Attention: 여러 관점

    Single-head는 하나의 "질문"만 할 수 있습니다. Multi-head는 여러 질문을 동시에!

  • # Head 1: "What's the subject?"
    # → "cat" focuses on nouns
    
    # Head 2: "What's the action?"
    # → "sat" focuses on verbs
    
    # Head 3: "What's the location?"
    # → "on the mat" focuses on prepositions
    
    # Head 4: "What's the temporal context?"
    # → focuses on time indicators

    구현:

  • class MultiHeadAttention(nn.Module):
        def __init__(self, n_embd, n_head, head_size, block_size):
            super().__init__()
    
            # Multiple attention heads
            self.heads = nn.ModuleList([
                SingleHeadAttention(n_embd, head_size, block_size)
                for _ in range(n_head)
            ])
    
            # Output projection
            self.proj = nn.Linear(n_head * head_size, n_embd)
    
        def forward(self, x):
            # Run all heads in parallel
            head_outputs = [head(x) for head in self.heads]
    
            # Concatenate all heads
            out = torch.cat(head_outputs, dim=-1)  # (B, T, n_head * head_size)
    
            # Project back to n_embd
            out = self.proj(out)  # (B, T, n_embd)
    
            return out

    예시:

  • n_embd = 64
    n_head = 4
    head_size = 16
    
    # Each head outputs (B, T, 16)
    # 4 heads concatenate to (B, T, 64)
    # Project keeps (B, T, 64)
    
    # Shape preservation allows stacking layers!

    nanochat의 Attention

    nanochat는 더 현대적인 구현을 사용합니다:

  • # nanochat/gpt.py (line 110-180)
    class Attention(nn.Module):
        def __init__(self, config):
            # One linear layer for all Q, K, V (more efficient!)
            self.c_attn = nn.Linear(
                config.n_embd,
                config.n_embd + 2 * config.n_kv_head * config.head_dim,
                bias=False
            )
    
        def forward(self, x, freqs_cos, freqs_sin):
            B, T, C = x.shape
    
            # Compute Q, K, V in one shot
            qkv = self.c_attn(x)
            q, k, v = qkv.split([self.n_embd, ...], dim=2)
    
            # Reshape for multi-head
            q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
            k = k.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
            v = v.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
    
            # Apply RoPE (more on this in next post!)
            q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin)
    
            # PyTorch's optimized attention (Flash Attention!)
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    
            # Reshape back
            y = y.transpose(1, 2).contiguous().view(B, T, C)
    
            return y

    주요 차이점:
    - 하나의 Linear layer로 Q, K, V 계산 (빠름!)
    - Flash Attention 사용 (메모리 효율)
    - RoPE 적용 (다음 포스트에서!)

    성능 비교

  • # Test: Sentence completion
    text = "The cat sat on the"
    
    # Bigram (no attention):
    # Only看 "the" → predicts "mat" 20% accuracy
    
    # Single-head Attention:
    # 看 full context → predicts "mat" 60% accuracy
    
    # Multi-head Attention (4 heads):
    # Multiple perspectives → predicts "mat" 80% accuracy!

    핵심 요약

    Attention = 문맥을 보는 메커니즘
    - 모든 토큰이 다른 모든 토큰을 참조

    Q, K, V = 정보 검색 시스템
    - Query: 무엇을 찾는가?
    - Key: 무엇을 제공하는가?
    - Value: 실제 정보

    Causal Masking = 미래 차단
    - 학습과 추론을 동일하게
    - Lower triangular matrix

    Multi-Head = 여러 관점
    - 병렬로 다른 패턴 학습
    - 성능 향상

    수식 단계:
    1. QK^T: 유사도 계산
    2. /√d_k: 스케일링
    3. mask: 미래 차단
    4. softmax: 확률 변환
    5. ×V: 가중 합

    다음 단계

    Part 4: "현대 LLM의 비밀"에서 다룰 내용:
    - RoPE (Rotary Positional Embeddings)
    - MQA (Multi-Query Attention)
    - Flash Attention (3-4배 빠른 구현)
    - QK Normalization

    Attention의 기초를 마스터했으니, 이제 최신 기법들로 업그레이드합시다! 🚀

    ---

    📘 참고

    본 포스트는 Andrej Karpathy의 nanochat 오픈소스 프로젝트를 기반으로, 코드 구조와 학습 과정을 분석/설명하기 위해 작성되었습니다. 원본 코드는 MIT License 하에 배포됩니다. 모든 코드 예제는 교육 목적으로 단순화되었습니다.

    참고 자료:
    - [Attention Is All You Need (원논문)](https://arxiv.org/abs/1706.03762)
    - [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/)
    - [Attention Visualization Tool](http://exbert.net/)

    태그: #Attention #Transformer #SelfAttention #DeepLearning #nanochat #QKV