[Analyzing Nanochat] 8. Building a Conversational AI

Read time: 2 minutes

Inference and Fine-tuning: From Model to Conversational AI

**Series**: Building an LLM with nanochat - Part 8/9

Training vs Inference

We've spent the previous parts training the model; now it's time to actually use it!

  • Training (parallel processing):
    Input:  "The cat sat on the"
    Target: "cat sat on the mat"
    → Process all tokens at once
    → Compute loss
    → Update weights
    
    Inference (sequential generation):
    Input: "The cat"
    → Generate "sat" (1 token)
    → Generate "on" (1 token)
    → Generate "the" (1 token)
    → Autoregressive, one at a time

    Problem: naive inference has O(N²) complexity!

    KV Cache: a 100x Speedup

    The O(N²) Problem

  • # Naive generation:
    tokens = ["The", "cat"]  # Start
    
    Step 1: Process ["The", "cat"] → predict "sat"
    Step 2: Process ["The", "cat", "sat"] → predict "on"
    Step 3: Process ["The", "cat", "sat", "on"] → predict "the"
    ...
    
    # Computation:
    # 2 + 3 + 4 + ... + N = O(N²) tokens processed! ❌
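
    To make the gap concrete, here is a count of token positions processed for a 1024-token sequence (the wall-clock speedup from the cache is smaller than this ratio, but the trend is the point):

  • N = 1024
    naive  = sum(range(2, N + 1))   # every step reprocesses the whole prefix → ≈ N²/2
    cached = N                      # with a KV cache, each token is processed once
    print(naive, cached)            # 524799 vs 1024 → roughly 500x fewer positions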

    The Solution: Cache K and V!

    Looking at the attention formula again:

  • # At each step:
    Q = query for NEW token only (1 token)
    K = keys for ALL tokens (past + new)
    V = values for ALL tokens (past + new)
    
    out = softmax(Q @ K.T) @ V

    Key insight: K and V overlap with the previous step!

  • Step 1: K₁ = keys for ["The", "cat"]
    Step 2: K₂ = keys for ["The", "cat", "sat"]
    #           └─ Already computed! ─┘  └─ new ─┘
    
    # Cache K₁, only compute K for "sat"!
    K₂ = concat(K₁, K_new)

    Implementation:

  • class AttentionWithCache(nn.Module):
        def forward(self, x, kv_cache=None):
            B, T, C = x.shape
    
            # Compute Q, K, V for current input
            q, k, v = self.qkv(x).split(C, dim=-1)  # assumes qkv = nn.Linear(C, 3 * C)
    
            if kv_cache is not None:
                # Use cached K, V
                k_cache, v_cache = kv_cache
                k = torch.cat([k_cache, k], dim=1)  # Append new K
                v = torch.cat([v_cache, v], dim=1)  # Append new V
    
            # Attention as normal
            scores = q @ k.transpose(-2, -1) / k.size(-1) ** 0.5
            attn = F.softmax(scores, dim=-1)
            out = attn @ v
    
            # Return output AND updated cache
            return out, (k, v)

    Usage:

  • kv_cache = None
    tokens = list(prompt_tokens)           # start from the prompt
    new_token = prompt_tokens              # first call is the "prefill" over the whole prompt
    
    for _ in range(max_new_tokens):
        # After prefill, only the NEW token is processed; past K/V come from the cache
        logits, kv_cache = model(new_token, kv_cache=kv_cache)
        next_token = sample(logits)        # sampling strategies are covered below
        tokens.append(next_token)
        new_token = next_token             # feed back a single token
    
    # Complexity: O(N) instead of O(N²)! 🎉

    The memory trade-off:

  • # Memory: store K and V for all past tokens
    # seq_len × n_layers × n_heads × head_dim × 2 (K and V) × 2 bytes (bf16)
    cache_size = seq_len * n_layers * n_heads * head_dim * 2 * 2
    
    # Example (nanochat d20: seq_len 1024, 20 layers, 10 heads, head_dim 128):
    cache_size = 1024 * 20 * 10 * 128 * 2 * 2
               = 104,857,600 bytes ≈ 100 MiB per sample
    
    # Speed:
    # 100x faster generation!
    
    # Worth it! ✅
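
    The same arithmetic as a quick throwaway helper (a sketch; the function name and signature are made up for illustration):

  • def kv_cache_bytes(seq_len, n_layers, n_heads, head_dim, bytes_per_elem=2):
        # Per layer: K and V, each of shape (seq_len, n_heads, head_dim), stored in bf16 (2 bytes)
        return seq_len * n_layers * n_heads * head_dim * 2 * bytes_per_elem
    
    print(kv_cache_bytes(1024, 20, 10, 128) / 2**20)  # 100.0 → ~100 MiB per sample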

    Sampling Strategies

    Greedy decoding (argmax) is boring:

  • logits = [2.0, 1.0, 0.5, ...]
    next_token = argmax(logits)  # Always picks index 0
    
    # Output: "The cat sat on the mat. The cat sat on the mat. ..."
    # Repetitive! ❌

    1. Temperature Sampling

  • def sample_with_temperature(logits, temperature=1.0):
        # Divide by temperature
        logits = logits / temperature
    
        # Softmax → probabilities
        probs = F.softmax(logits, dim=-1)
    
        # Sample
        return torch.multinomial(probs, num_samples=1)

    The effect of temperature:

  • logits = [2.0, 1.0, 0.5]
    
    T = 0.1 (confident):
      probs = [1.00, 0.00, 0.00]  # Almost deterministic
    
    T = 1.0 (balanced):
      probs = [0.63, 0.23, 0.14]  # Normal
    
    T = 2.0 (creative):
      probs = [0.48, 0.29, 0.23]  # More random
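
    These probabilities are just softmax(logits / T); a quick snippet to reproduce the numbers above:

  • import torch
    import torch.nn.functional as F
    
    logits = torch.tensor([2.0, 1.0, 0.5])
    for T in (0.1, 1.0, 2.0):
        print(T, F.softmax(logits / T, dim=-1))
    # 0.1 → [1.00, 0.00, 0.00], 1.0 → [0.63, 0.23, 0.14], 2.0 → [0.48, 0.29, 0.23]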

    2. Top-k Sampling

  • def top_k_sampling(logits, k=50):
        # Keep only top k
        values, indices = torch.topk(logits, k)
    
        # Zero out others
        logits_filtered = torch.full_like(logits, float('-inf'))
        logits_filtered.scatter_(1, indices, values)
    
        # Sample from top k only
        probs = F.softmax(logits_filtered, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    Why is this needed?

  • Vocabulary: 50,000 tokens
    
    Without top-k:
    - Might sample rare/gibberish token
    - "The cat <rare_token_49872>"
    
    With top-k (k=50):
    - Only sample from top 50 likely tokens
    - "The cat sat"

    3. Top-p (Nucleus) Sampling

  • def top_p_sampling(logits, p=0.9):
        # Sort by probability
        probs = F.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    
        # Cumulative probability
        cumsum = torch.cumsum(sorted_probs, dim=-1)
    
        # Find nucleus (where cumsum > p)
        nucleus_size = (cumsum > p).nonzero()[0].item() + 1
    
        # Sample from the nucleus (renormalized)
        nucleus_probs = sorted_probs[:nucleus_size]
        nucleus_probs = nucleus_probs / nucleus_probs.sum()
        idx = torch.multinomial(nucleus_probs, num_samples=1)
        return sorted_indices[idx]

    Dynamic cutoff:

  • Probs: [0.5, 0.3, 0.15, 0.04, 0.01, ...]
    
    Top-p (p=0.9):
      0.5 → cumsum=0.5
      0.3 → cumsum=0.8
      0.15 → cumsum=0.95 > 0.9 ← Stop!
    
      Nucleus = [0.5, 0.3, 0.15] (3 tokens)
    
    # Adaptive! Sometimes 3 tokens, sometimes 10

    Combined strategy:

  • # nanochat defaults:
    next_token = sample(
        logits,
        temperature=1.0,
        top_k=50,
        top_p=0.95
    )
    
    # 1. Temperature scaling
    # 2. Keep top 50 tokens
    # 3. Within top 50, keep nucleus (p=0.95)
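
    A minimal sketch of what such a combined sample() could look like (a simplified stand-in written for this post, not nanochat's actual engine code):

  • import torch
    import torch.nn.functional as F
    
    def sample(logits, temperature=1.0, top_k=50, top_p=0.95):
        # logits: (vocab_size,) for a single position
        logits = logits / temperature
    
        # Top-k: keep only the k largest logits (topk returns them sorted descending)
        values, indices = torch.topk(logits, top_k)
        probs = F.softmax(values, dim=-1)
    
        # Top-p: within the top-k, keep the smallest prefix whose cumulative prob reaches p
        cumsum = torch.cumsum(probs, dim=-1)
        nucleus_size = int((cumsum < top_p).sum().item()) + 1
        nucleus_probs = probs[:nucleus_size] / probs[:nucleus_size].sum()
    
        # Sample within the nucleus and map back to a vocabulary id
        idx = torch.multinomial(nucleus_probs, num_samples=1)
        return indices[idx]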

    Fine-tuning: Base → Chat

    A pretrained model only does completion:

  • Input: "The capital of France is"
    Output: "Paris, which is located in northern France and has..."

    To turn it into a chat model:

  • Input: <|user_start|>"What is the capital of France?"<|user_end|>
    Output: <|assistant_start|>"The capital of France is Paris."<|assistant_end|>

    Chat Format

  • SPECIAL_TOKENS = [
        "<|bos|>",              # Beginning of sequence
        "<|user_start|>",       # User message start
        "<|user_end|>",         # User message end
        "<|assistant_start|>",  # Assistant message start
        "<|assistant_end|>",    # Assistant message end
    ]
    
    # Example conversation:
    text = """
    <|bos|>
    <|user_start|>Hello!<|user_end|>
    <|assistant_start|>Hi there! How can I help?<|assistant_end|>
    <|user_start|>What's 2+2?<|user_end|>
    <|assistant_start|>2+2 equals 4.<|assistant_end|>
    """

    Loss Masking

    Critical: Only train on assistant responses!

  • text = "<|user_start|>Hello<|user_end|><|assistant_start|>Hi!<|assistant_end|>"
    tokens = [50297, 15496, 50298, 50299, 13347, 50300]
    
    # Loss mask (1 = compute loss at this position):
    mask = [0, 0, 0, 1, 1, 1]
    #       └─ Don't train ─┘ └─ Train ─┘
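
    Putting the special tokens and the mask together, here is a simplified sketch of how a conversation could be rendered into training data (the tokenizer methods encode / encode_special are assumptions for illustration, not nanochat's exact API):

  • def render_conversation(messages, tokenizer):
        # messages: [{"role": "user", "content": "..."}, {"role": "assistant", ...}, ...]
        # Returns flat token ids plus a 0/1 loss mask (1 = compute loss at this position)
        ids  = [tokenizer.encode_special("<|bos|>")]
        mask = [0]
        for msg in messages:
            body  = [tokenizer.encode_special(f"<|{msg['role']}_start|>")]
            body += tokenizer.encode(msg["content"])
            body += [tokenizer.encode_special(f"<|{msg['role']}_end|>")]
    
            train = 1 if msg["role"] == "assistant" else 0
            ids  += body
            mask += [train] * len(body)
        return ids, mask

    Each user turn contributes only 0s to the mask and each assistant turn only 1s, matching the example above.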

    Why?

  • If we train on user messages:
      Model learns to predict user input
      → Useless! ❌
    
    If we train on assistant messages:
      Model learns to generate good responses
      → Useful! ✅

    Implementation:

  • def compute_loss_with_mask(logits, targets, mask):
        # logits: (B*T, vocab), targets: (B*T,), mask: (B*T,) of 0/1
        # Compute loss per position
        loss = F.cross_entropy(logits, targets, reduction='none')
    
        # Apply mask
        loss = loss * mask
    
        # Average over assistant tokens only
        return loss.sum() / mask.sum()
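
    For shape intuition, a tiny usage example with dummy tensors (positions flattened into one dimension; the numbers are made up):

  • import torch
    
    B_T, V = 6, 50304                                  # 6 flattened positions, made-up vocab size
    logits  = torch.randn(B_T, V)                      # model predictions
    targets = torch.randint(0, V, (B_T,))              # next-token targets
    mask    = torch.tensor([0., 0., 0., 1., 1., 1.])   # train only on the last 3 positions
    
    loss = compute_loss_with_mask(logits, targets, mask)  # averages loss over the 3 masked-in positions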

    SFT Training

  • # 1. Load chat dataset (SmolTalk)
    dataset = load_smoltalk()
    # 10K conversations, high quality
    
    # 2. Fine-tune with LOWER learning rate
    optimizer = AdamW(model.parameters(), lr=3e-5)  # 20x lower!
    
    # 3. Short training
    for step in range(500):  # vs 4681 for pretraining
        x, y, mask = next(dataloader)
    
        logits, _ = model(x)
        loss = compute_loss_with_mask(logits, y, mask)
    
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    Why a lower learning rate?

  • High LR:
      Model forgets pretraining knowledge
      "What's the capital of France?" → "I don't know"
      (Catastrophic forgetting!)
    
    Low LR (3e-5):
      Model retains knowledge ✅
      Adds conversational ability ✅
      "The capital of France is Paris" still works!

    REINFORCE: Learning from Rewards

    SFT learns from demonstrations, RL learns from trial and error!

  • # SFT:
    "If user asks X, respond Y"
    (Supervised, needs examples)
    
    # RL:
    "Try different responses, keep what works!"
    (Reinforcement, learns from rewards)

    REINFORCE Algorithm

  • for step in range(num_iterations):
        # 1. Generate response
        prompt = sample_problem()
        response, log_probs = model.generate_with_logprobs(prompt)
    
        # 2. Compute reward
        reward = reward_function(prompt, response)
    
        # 3. Policy gradient loss
        loss = -log_probs.sum() * reward
    
        # 4. Update model
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
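
    generate_with_logprobs above isn't shown; here is a minimal sketch of what it has to do, written as a free function and reusing the (logits, kv_cache) interface from earlier (the name and signature are illustrative, not nanochat's exact API):

  • import torch
    import torch.nn.functional as F
    
    def generate_with_logprobs(model, prompt_ids, max_new_tokens=256, temperature=1.0):
        x = torch.tensor([prompt_ids])                        # (1, T)
        logits, kv_cache = model(x, kv_cache=None)            # prefill the prompt
    
        tokens, log_probs = [], []
        for _ in range(max_new_tokens):
            probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)      # (1, 1)
    
            # Record log π(token | context) for the policy-gradient update
            log_probs.append(torch.log(probs.gather(-1, next_token)))
            tokens.append(int(next_token))
            # (a real implementation would also stop at <|assistant_end|>)
    
            logits, kv_cache = model(next_token, kv_cache=kv_cache)   # decode with the KV cache
    
        return tokens, torch.stack(log_probs).squeeze()

    Note that the log-probs here keep their computation graph so that loss.backward() in the loop above can reach the model's weights; real implementations usually recompute them in a separate batched forward pass instead of backpropagating through generation.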

    Intuition:

  • # If reward is high (good response):
    loss = -log_prob * high_reward
    → Gradient INCREASES probability of this response
    
    # If reward is low/zero (bad response):
    loss = -log_prob * low_reward
    → Gradient barely reinforces it (with a baseline or a negative reward, it actively DECREASES the probability)

    Reward Function (GSM8K Math)

  • def compute_reward(problem, response):
        # Extract answer from response
        predicted = extract_answer(response)
        correct = problem['answer']
    
        # Correctness reward
        reward = 1.0 if predicted == correct else 0.0
    
        # Format bonuses
        if "<|assistant_end|>" in response:
            reward += 0.1  # Proper ending
    
        if len(response) < 500:
            reward += 0.1  # Not too verbose
    
        return reward
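
    extract_answer does most of the work here. Since GSM8K answers are plain numbers, a simple (assumed) implementation just grabs the last number mentioned in the response:

  • import re
    
    def extract_answer(response):
        # GSM8K final answers are numbers; take the last number that appears (an assumption)
        numbers = re.findall(r"-?\d[\d,]*(?:\.\d+)?", response)
        if not numbers:
            return None
        return numbers[-1].replace(",", "")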

    Example:

  • Problem: "If John has 5 apples and buys 3 more, how many does he have?"
    
    Response 1: "John has apples."
    → Reward: 0.0 (wrong answer)
    
    Response 2: "5 + 3 = 8, so John has 8 apples.<|assistant_end|>"
    → Reward: 1.0 + 0.1 + 0.1 = 1.2 (correct + format!)
    
    Model learns to generate Response 2 style! ✅

    RL Training Loop

  • # Use GSM8K dataset (grade school math)
    problems = load_gsm8k()
    
    optimizer = AdamW(model.parameters(), lr=1e-5)  # Even lower!
    
    for step in range(200):  # Short RL phase
        batch = sample(problems, batch_size=16)
    
        total_loss = 0
        rewards = []
        for problem in batch:
            # Generate
            prompt = format_prompt(problem)
            response, log_probs = model.generate_with_logprobs(prompt)
    
            # Reward
            reward = compute_reward(problem, response)
            rewards.append(reward)
    
            # Loss (policy gradient)
            loss = -log_probs.sum() * reward
            total_loss += loss
    
        # Update
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
        print(f"Step {step}, Avg reward: {sum(rewards) / len(rewards):.2f}")

    Results:

  • Before RL:
      Accuracy on GSM8K: 20%
    
    After RL (200 steps):
      Accuracy on GSM8K: 35%
    
    RL improves math reasoning! ✅

    Caution: RL is unstable!

  • Too much RL:
      Model forgets language → "jhsdf ksjdf"
    
    Too little RL:
      No improvement → same as SFT
    
    nanochat sweet spot: 200 steps

    Key Takeaways

    KV Cache = 100x faster generation
    - Cache past K, V
    - O(N²) → O(N)
    - Memory trade-off worth it

    Sampling = controlled randomness
    - Temperature: creativity control
    - Top-k: limit to best k
    - Top-p: dynamic nucleus

    SFT = Base → Chat
    - Special tokens for structure
    - Loss masking (train on assistant only)
    - Low LR (avoid catastrophic forgetting)

    RL = Learning from rewards
    - REINFORCE algorithm
    - Reward functions
    - Improves reasoning

    Complete pipeline:
    - Pretraining: 4h, $100
    - SFT: 1h
    - RL: 30min
    - Total: ~6h, $100

    Next Steps

    Part 9, "Evaluation and Deployment", will cover:
    - MMLU (knowledge test)
    - GSM8K (math reasoning)
    - HumanEval (code generation)
    - CORE metric (aggregate)
    - Web serving (FastAPI + WebSocket)

    Final step: evaluate and deploy the model! 🚀

    ---

    📘 Notes

    This post is based on Andrej Karpathy's open-source nanochat project and was written to analyze and explain its code structure and training process. The original code is distributed under the MIT License. All code examples have been simplified for educational purposes.

    References:
    - [REINFORCE paper](https://link.springer.com/article/10.1007/BF00992696)
    - [KV Cache implementation (nanochat engine.py)](https://github.com/karpathy/nanochat/blob/main/nanochat/engine.py)

    Tags: #Inference #KVCache #Sampling #FineTuning #REINFORCE #RL #nanochat