[Analyzing Nanochat] 8. Building a Conversational AI
Inference and Fine-tuning: Turning the Model into a Conversational AI
**Series**: Learning to Build an LLM with nanochat - Part 8/9
Training vs Inference
We've trained the model, so now it's time to use it!
Training (parallel processing):
Input: "The cat sat on the"
Target: "cat sat on the mat"
→ Process all tokens at once
→ Compute loss
→ Update weights
Inference (sequential generation):
Input: "The cat"
→ Generate "sat" (1 token)
→ Generate "on" (1 token)
→ Generate "the" (1 token)
→ Autoregressive, one at a time
The problem: naive inference is O(N²)!
KV Cache: 100x Faster Generation
The O(N²) Problem
# Naive generation:
tokens = ["The", "cat"] # Start
Step 1: Process ["The", "cat"] → predict "sat"
Step 2: Process ["The", "cat", "sat"] → predict "on"
Step 3: Process ["The", "cat", "sat", "on"] → predict "the"
...
# Computation:
# 2 + 3 + 4 + ... + N = O(N²) tokens processed! ❌
The Solution: Cache K and V!
Looking at the attention computation again:
# At each step:
Q = query for NEW token only (1 token)
K = keys for ALL tokens (past + new)
V = values for ALL tokens (past + new)
out = softmax(Q @ K.T) @ V
Key insight: K and V overlap with the previous step!
Step 1: K₁ = keys for ["The", "cat"]
Step 2: K₂ = keys for ["The", "cat", "sat"]
# └─ Already computed! ─┘ └─ new ─┘
# Cache K₁, only compute K for "sat"!
K₂ = concat(K₁, K_new)
Implementation:
class AttentionWithCache(nn.Module):
    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        # Compute Q, K, V for the current input
        q, k, v = self.qkv(x).split([...], dim=-1)
        if kv_cache is not None:
            # Reuse cached K, V
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=1)  # Append new K
            v = torch.cat([v_cache, v], dim=1)  # Append new V
        # Attention as usual
        scores = q @ k.transpose(-2, -1) / sqrt(head_dim)
        attn = F.softmax(scores, dim=-1)
        out = attn @ v
        # Return the output AND the updated cache
        return out, (k, v)
Usage:
kv_cache = None
for _ in range(max_new_tokens):
    # Only process the NEW token!
    logits, kv_cache = model(new_token, kv_cache=kv_cache)
    new_token = sample(logits)
    tokens.append(new_token)
# Complexity: O(N) instead of O(N²)! 🎉
The memory trade-off:
# Memory:
# Store K, V for all past tokens
cache_size = seq_len * n_layers * n_heads * head_dim * 2 (K+V) * 2 bytes
# Example (nanochat d20):
cache_size = 1024 * 20 * 10 * 128 * 2 * 2
           ≈ 105 MB per sample
# Speed:
# 100x faster generation!
# Worth it! ✅
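To see where that number comes from, here is a tiny helper that redoes the arithmetic. It is only a sketch: the parameter names mirror the formula above (assuming 2-byte bf16/fp16 values), not nanochat's actual config fields:
def kv_cache_bytes(seq_len, n_layers, n_heads, head_dim, dtype_bytes=2):
    # 2 tensors (K and V) of shape [seq_len, n_heads, head_dim] per layer
    return seq_len * n_layers * n_heads * head_dim * 2 * dtype_bytes

print(kv_cache_bytes(1024, 20, 10, 128) / 1e6)  # ≈ 104.9 MB per sample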
Sampling Strategies
Greedy decoding (argmax) is boring:
logits = [2.0, 1.0, 0.5, ...]
next_token = argmax(logits) # Always picks index 0
# Output: "The cat sat on the mat. The cat sat on the mat. ..."
# Repetitive! ❌
1. Temperature Sampling
def sample_with_temperature(logits, temperature=1.0):
    # Divide by temperature
    logits = logits / temperature
    # Softmax → probabilities
    probs = F.softmax(logits, dim=-1)
    # Sample
    return torch.multinomial(probs, num_samples=1)
The effect of temperature:
logits = [2.0, 1.0, 0.5]
T = 0.1 (confident):
probs ≈ [1.00, 0.00, 0.00]  # Almost deterministic
T = 1.0 (balanced):
probs ≈ [0.63, 0.23, 0.14]  # Normal
T = 2.0 (creative):
probs ≈ [0.48, 0.29, 0.23]  # More random
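These probabilities are easy to verify yourself; the snippet below is just a sanity check for the table above, not nanochat code:
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5])
for T in (0.1, 1.0, 2.0):
    probs = F.softmax(logits / T, dim=-1)
    print(T, [round(p, 2) for p in probs.tolist()])
# 0.1 → [1.0, 0.0, 0.0], 1.0 → [0.63, 0.23, 0.14], 2.0 → [0.48, 0.29, 0.23]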
2. Top-k Sampling
def top_k_sampling(logits, k=50):
    # Keep only the top k logits
    values, indices = torch.topk(logits, k)
    # Mask out the rest with -inf
    logits_filtered = torch.full_like(logits, float('-inf'))
    logits_filtered.scatter_(1, indices, values)
    # Sample from the top k only
    probs = F.softmax(logits_filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)
Why is this needed?
Vocabulary: 50,000 tokens
Without top-k:
- Might sample rare/gibberish token
- "The cat <rare_token_49872>" ❌
With top-k (k=50):
- Only sample from top 50 likely tokens
- "The cat sat" ✅3. Top-p (Nucleus) Sampling¶
def top_p_sampling(logits, p=0.9):
    # Sort by probability
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    # Cumulative probability
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    # Find the nucleus (where cumsum exceeds p)
    nucleus_size = (cumsum > p).nonzero()[0].item() + 1
    # Sample from the nucleus
    nucleus_indices = sorted_indices[:nucleus_size]
    # ...
Dynamic cutoff:
Probs: [0.5, 0.3, 0.15, 0.04, 0.01, ...]
Top-p (p=0.9):
0.5  → cumsum = 0.50
0.3  → cumsum = 0.80
0.15 → cumsum = 0.95 > 0.9 ← Stop!
Nucleus = [0.5, 0.3, 0.15] (3 tokens)
# Adaptive! Sometimes 3 tokens, sometimes 10
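The part elided above ("# ...") can be written as a filter over the probabilities. This is one possible sketch (it keeps the token that crosses p, matching the example), not nanochat's exact implementation:
import torch
import torch.nn.functional as F

def top_p_filter(logits, p=0.9):
    # Sort token probabilities in descending order
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    # Cumulative probability *before* each token
    cum_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    # Drop tokens that start outside the nucleus (keep the one that crosses p)
    sorted_probs[cum_before > p] = 0.0
    # Scatter back to vocabulary order and renormalize
    filtered = torch.zeros_like(probs).scatter(-1, sorted_indices, sorted_probs)
    return filtered / filtered.sum(dim=-1, keepdim=True)

probs = top_p_filter(torch.tensor([2.0, 1.0, 0.5, -1.0]), p=0.9)
next_token = torch.multinomial(probs, num_samples=1)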
Combined strategy:
# nanochat defaults:
next_token = sample(
    logits,
    temperature=1.0,
    top_k=50,
    top_p=0.95
)
# 1. Temperature scaling
# 2. Keep top 50 tokens
# 3. Within top 50, keep nucleus (p=0.95)
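Putting the pieces together, a combined sample() with this signature might look like the sketch below. It reuses top_p_filter (and the torch / F imports) from the sketch above; it is an illustration of the three steps, not nanochat's actual engine API:
def sample(logits, temperature=1.0, top_k=50, top_p=0.95):
    # 1. Temperature scaling
    logits = logits / temperature
    # 2. Keep only the top-k logits, set the rest to -inf
    values, indices = torch.topk(logits, min(top_k, logits.size(-1)))
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(-1, indices, values)
    # 3. Nucleus filtering within the top k, then sample
    probs = top_p_filter(filtered, p=top_p)
    return torch.multinomial(probs, num_samples=1)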
Fine-tuning: Base → Chat
A pretrained model only does completion:
Input: "The capital of France is"
Output: "Paris, which is located in northern France and has..."
To turn it into a chat model, we want:
Input: <|user|>"What is the capital of France?"<|end|>
Output: <|assistant|>"The capital of France is Paris."<|end|>
Chat Format
SPECIAL_TOKENS = [
    "<|bos|>",              # Beginning of sequence
    "<|user_start|>",       # User message start
    "<|user_end|>",         # User message end
    "<|assistant_start|>",  # Assistant message start
    "<|assistant_end|>",    # Assistant message end
]
# Example conversation:
text = """
<|bos|>
<|user_start|>Hello!<|user_end|>
<|assistant_start|>Hi there! How can I help?<|assistant_end|>
<|user_start|>What's 2+2?<|user_end|>
<|assistant_start|>2+2 equals 4.<|assistant_end|>
"""
Loss Masking
Critical: Only train on assistant responses!
text = "<|user|>Hello<|end|><|assistant|>Hi!<|end|>"
tokens = [50297, 15496, 50298, 50299, 13347, 50300]
# Loss mask:
mask = [0, 0, 0, 1, 1, 1]
# └─ Don't train ─┘ └─ Train ─┘왜?
If we train on user messages:
Model learns to predict user input
→ Useless! ❌
If we train on assistant messages:
Model learns to generate good responses
→ Useful! ✅
Implementation:
def compute_loss_with_mask(logits, targets, mask):
    # Compute loss per position
    loss = F.cross_entropy(logits, targets, reduction='none')
    # Apply mask
    loss = loss * mask
    # Average over assistant tokens only
    return loss.sum() / mask.sum()
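Where do x, y, and mask come from in the first place? Here is a simplified sketch of turning a conversation into token IDs plus a loss mask; the message format and the tokenizer.encode call are illustrative assumptions, not nanochat's exact API:
def render_conversation(messages, tokenizer):
    # messages: [{"role": "user", "content": "Hello!"}, {"role": "assistant", ...}, ...]
    ids, mask = [], []
    for msg in messages:
        text = f"<|{msg['role']}_start|>{msg['content']}<|{msg['role']}_end|>"
        toks = tokenizer.encode(text)  # hypothetical encode call
        ids.extend(toks)
        # Loss is computed only on assistant tokens
        mask.extend([1 if msg["role"] == "assistant" else 0] * len(toks))
    return ids, mask
x and y are then the usual shift-by-one pair built from ids, with the mask aligned to y.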
SFT Training
# 1. Load the chat dataset (SmolTalk)
dataset = load_smoltalk()
# 10K conversations, high quality
# 2. Fine-tune with a LOWER learning rate
optimizer = AdamW(model.parameters(), lr=3e-5)  # 20x lower!
# 3. Short training run
for step in range(500):  # vs 4681 steps for pretraining
    x, y, mask = next(dataloader)
    logits, _ = model(x)
    loss = compute_loss_with_mask(logits, y, mask)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
Why a lower LR?
High LR:
Model forgets pretraining knowledge
→ "What's the capital of France?" → "I don't know" ❌
(Catastrophic forgetting!)
Low LR (3e-5):
Model retains knowledge ✅
Adds conversational ability ✅
→ "The capital of France is Paris" still works!REINFORCE: Learning from Rewards¶
SFT learns from demonstrations, RL learns from trial and error!
# SFT:
"If user asks X, respond Y"
(Supervised, needs examples)
# RL:
"Try different responses, keep what works!"
(Reinforcement, learns from rewards)
REINFORCE Algorithm
for step in range(num_iterations):
    # 1. Generate a response
    prompt = sample_problem()
    response, log_probs = model.generate_with_logprobs(prompt)
    # 2. Compute the reward
    reward = reward_function(prompt, response)
    # 3. Policy gradient loss
    loss = -log_probs.sum() * reward
    # 4. Update the model
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
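generate_with_logprobs is used above but never spelled out. Conceptually it samples tokens one at a time while keeping the log-probabilities of the sampled tokens in the autograd graph, so loss.backward() can flow through them. Below is a simplified sketch (no KV cache, END_TOKEN_ID is an assumed stop-token id), not nanochat's engine code:
import torch

END_TOKEN_ID = 50300  # illustrative; e.g. the <|assistant_end|> id from the masking example

def generate_with_logprobs(model, prompt_ids, max_new_tokens=64):
    ids = list(prompt_ids)
    log_probs = []
    for _ in range(max_new_tokens):
        logits, _ = model(torch.tensor([ids]))           # full re-run each step, for clarity
        dist = torch.distributions.Categorical(logits=logits[0, -1])
        tok = dist.sample()
        log_probs.append(dist.log_prob(tok))             # differentiable w.r.t. model parameters
        ids.append(tok.item())
        if tok.item() == END_TOKEN_ID:
            break
    return ids, torch.stack(log_probs)
Re-running the whole prefix every step keeps the sketch short; a real engine would reuse the KV cache from earlier in this post.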
Intuition:
# If the reward is high (good response):
loss = -log_prob * high_reward
→ The gradient INCREASES the probability of this response
# If the reward is low (bad response):
loss = -log_prob * low_reward
→ The response is barely reinforced (with reward = 0, not at all); with negative rewards,
  or once a baseline is subtracted, its probability is actively DECREASED
Reward Function (GSM8K Math)
def compute_reward(problem, response):
    # Extract the model's answer from the response
    predicted = extract_answer(response)
    correct = problem['answer']
    # Correctness reward
    reward = 1.0 if predicted == correct else 0.0
    # Format bonuses
    if "<|end|>" in response:
        reward += 0.1  # Proper ending
    if len(response) < 500:
        reward += 0.1  # Not too verbose
    return reward
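extract_answer is left abstract above. For GSM8K-style answers, one common heuristic is simply to take the last number that appears in the response; a hypothetical sketch, not nanochat's actual parser:
import re

def extract_answer(response):
    # Take the last integer or decimal, e.g. "... so John has 8 apples." → "8"
    numbers = re.findall(r"-?\d+(?:\.\d+)?", response.replace(",", ""))
    return numbers[-1] if numbers else None
GSM8K reference answers are numeric, so comparing the extracted string against problem['answer'] is enough for this sketch.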
Problem: "If John has 5 apples and buys 3 more, how many does he have?"
Response 1: "John has apples."
→ Reward: 0.0 (wrong answer)
Response 2: "5 + 3 = 8, so John has 8 apples.<|end|>"
→ Reward: 1.0 + 0.1 + 0.1 = 1.2 (correct + format!)
Model learns to generate Response 2 style! ✅RL Training Loop¶
# Use the GSM8K dataset (grade school math)
problems = load_gsm8k()
optimizer = AdamW(model.parameters(), lr=1e-5)  # Even lower!
for step in range(200):  # Short RL phase
    batch = sample(problems, batch_size=16)
    total_loss, rewards = 0, []
    for problem in batch:
        # Generate
        prompt = format_prompt(problem)
        response, log_probs = model.generate_with_logprobs(prompt)
        # Reward
        reward = compute_reward(problem, response)
        rewards.append(reward)
        # Loss
        loss = -log_probs.sum() * reward
        total_loss += loss
    # Update
    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"Step {step}, avg reward: {sum(rewards) / len(rewards):.2f}")
Results:
Before RL:
Accuracy on GSM8K: 20%
After RL (200 steps):
Accuracy on GSM8K: 35%
RL improves math reasoning! ✅
Caution: RL is unstable!
Too much RL:
Model forgets language → "jhsdf ksjdf" ❌
Too little RL:
No improvement → same as SFT
nanochat's sweet spot: 200 steps
Key Takeaways
✅ KV Cache = 100x faster generation
- Cache past K, V
- O(N²) → O(N)
- Memory trade-off worth it
✅ Sampling = controlled randomness
- Temperature: creativity control
- Top-k: limit to best k
- Top-p: dynamic nucleus
✅ SFT = Base → Chat
- Special tokens for structure
- Loss masking (train on assistant only)
- Low LR (avoid catastrophic forgetting)
✅ RL = Learning from rewards
- REINFORCE algorithm
- Reward functions
- Improves reasoning
✅ Complete pipeline:
- Pretraining: 4h, $100
- SFT: 1h
- RL: 30min
- Total: ~6h, $100
Next Steps
Part 9, "Evaluation and Deployment", will cover:
- MMLU (knowledge test)
- GSM8K (math reasoning)
- HumanEval (code generation)
- CORE metric (aggregate)
- Web serving (FastAPI + WebSocket)
The final step: evaluating and deploying the model! 🚀
---
📘 Notes
This post is based on Andrej Karpathy's open-source nanochat project and was written to analyze and explain its code structure and training process. The original code is distributed under the MIT License. All code examples have been simplified for educational purposes.
References:
- [REINFORCE paper](https://link.springer.com/article/10.1007/BF00992696)
- [KV Cache implementation](https://github.com/karpathy/nanochat/blob/main/nanochat/engine.py)
Tags: #Inference #KVCache #Sampling #FineTuning #REINFORCE #RL #nanochat