[Analyzing nanochat] 7. 3x Faster Training Optimization
3x Faster Training: Mixed Precision and Compiler Optimizations
**Series**: Building an LLM with nanochat - Part 7/9
Speed Is Money¶
Training time = Money
Without optimization:
- d20 (370M params)
- 8×H100
- 12 hours
- Cost: $300
With ALL optimizations:
- Same model
- Same hardware
- 4 hours (3x faster!)
- Cost: $100 (3x cheaper!)
Savings: $200 AND 8 hours
Today we'll cover the four techniques that deliver this 3x speedup.
1. Mixed Precision: bfloat16¶
float32 vs bfloat16¶
float32 (4 bytes):
[sign: 1 bit][exponent: 8 bits][mantissa: 23 bits]
Range: 10^-38 to 10^38
Precision: ~7 decimal digits
bfloat16 (2 bytes):
[sign: 1 bit][exponent: 8 bits][mantissa: 7 bits]
Range: 10^-38 to 10^38 (SAME!)
Precision: ~2 decimal digits (LESS!)
Key insight: the range is the same, only the precision is lower!
import torch
# float32: stores 3.1415927 (accurate to ~7 digits)
x32 = torch.tensor(3.14159265, dtype=torch.float32)
# bfloat16: stores 3.140625 (slight precision loss)
x16 = torch.tensor(3.14159265, dtype=torch.bfloat16)
# But the range is the same:
huge = torch.tensor(1e30, dtype=torch.bfloat16)   # representable in both
tiny = torch.tensor(1e-30, dtype=torch.bfloat16)  # representable in both
Why bfloat16 Instead of float16?¶
# float16's problem: its normal range is only ~6e-5 to 65504
# Training values easily fall outside it:
grad = torch.tensor(1e-8, dtype=torch.float16)      # Underflow! → 0 ❌
loss = torch.tensor(100000., dtype=torch.float16)   # Overflow! → inf ❌
# bfloat16 solves this: range ~1e-38 to ~3e38 (same exponent range as float32)
grad = torch.tensor(1e-8, dtype=torch.bfloat16)     # OK ✅
loss = torch.tensor(100000., dtype=torch.bfloat16)  # OK ✅
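If you want to verify these ranges on your own machine, torch.finfo reports each dtype's limits; a quick sanity-check sketch (not nanochat code):
import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    # float16 tops out at 65504; bfloat16's max (~3.4e38) matches float32's exponent range
    print(f"{dtype}: smallest normal = {info.tiny:.1e}, max = {info.max:.1e}, eps = {info.eps:.1e}")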
Autocast: Automatic Precision Management¶
# Manual (tedious):
model = model.to(torch.bfloat16)
x = x.to(torch.bfloat16)
# What about gradients? Optimizer states? ❌
# Automatic (easy):
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
    logits, loss = model(x, y)
# PyTorch handles everything! ✅
The magic of autocast:
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
    # Matrix multiply → bfloat16 (fast!)
    logits = x @ weight
    # Softmax → float32 (accurate!)
    probs = F.softmax(logits, dim=-1)
    # Cross-entropy → float32 (stable!)
    loss = -torch.log(probs[target])
# Autocast selects the best precision per operation!
Operations whitelist/blacklist:
bfloat16 (fast ops):
- matmul, conv
- Linear layers
- Attention
float32 (precision-critical ops):
- softmax, log_softmax
- LayerNorm, RMSNorm
- Loss functions
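You can observe this policy directly by checking output dtypes inside an autocast region; a small sanity-check sketch (shapes are arbitrary, requires a CUDA GPU):
import torch
import torch.nn.functional as F

x = torch.randn(8, 512, device='cuda')   # float32 inputs
w = torch.randn(512, 512, device='cuda')

with torch.amp.autocast('cuda', dtype=torch.bfloat16):
    h = x @ w                     # matmul: autocast runs it in bfloat16
    p = F.softmax(h, dim=-1)      # softmax: autocast promotes it to float32

print(h.dtype, p.dtype)           # torch.bfloat16 torch.float32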
Performance Gains¶
# Benchmark: GPT-2 (1.5B params)
float32:
- Forward: 120 ms
- Backward: 240 ms
- Memory: 6 GB
bfloat16 with autocast:
- Forward: 45 ms (2.7x faster!)
- Backward: 90 ms (2.7x faster!)
- Memory: 3 GB (2x less!)
# Why so fast?
# 1. Tensor Cores (optimized for bf16)
# 2. Half the memory bandwidth
2. Flash Attention 2: Memory Efficiency¶
The fatal problem with standard attention:
Q, K, V: [B, H, T, D]
# Attention matrix:
scores = Q @ K.T # [B, H, T, T] ← O(T²) memory!
# Example: T=2048
scores size = 2048² = 4M values per head
= 4M × 32 heads = 128M values
= 256 MB per layer (bf16)
= 8 GB for 32 layers! 😱
The Flash Attention solution:
# Never materialize the full (T×T) matrix!
# Process it in tiles instead:
block_size = 128
for q_block in range(T // block_size):
    for kv_block in range(T // block_size):
        # Load small tiles into SRAM (fast on-chip memory)
        Q_tile = Q[q_block]   # [block_size, D]
        K_tile = K[kv_block]
        V_tile = V[kv_block]
        # Compute attention for this tile
        scores_tile = Q_tile @ K_tile.T   # [128, 128] only!
        attn_tile = softmax(scores_tile)
        out_tile = attn_tile @ V_tile
        # Accumulate with a running (online) softmax; the full matrix is never stored!
Memory reduction:
# Standard:
Memory = O(T²) = 2048² = 4M values
# Flash Attention:
Memory = O(block_size²) = 128² = 16K values
# Reduction: 4M / 16K = 256x less! 🎉
PyTorch usage:
# Old (manual):
scores = Q @ K.transpose(-2, -1) / (D ** 0.5)
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
out = attn @ V
# New (Flash Attention built in):
out = F.scaled_dot_product_attention(
    Q, K, V,
    is_causal=True  # automatic causal masking!
)
# PyTorch auto-selects Flash Attention 2!
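To see the memory difference concretely, you can compare peak GPU memory for a naive implementation against F.scaled_dot_product_attention. A rough sketch (the shapes are hypothetical and exact numbers depend on your GPU and PyTorch version):
import torch
import torch.nn.functional as F

B, H, T, D = 1, 32, 2048, 64  # hypothetical shapes
q = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

def naive_attention(q, k, v):
    # Materializes the full [B, H, T, T] score matrix
    scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    mask = torch.tril(torch.ones(T, T, device=q.device, dtype=torch.bool))
    scores = scores.masked_fill(~mask, float('-inf'))
    return F.softmax(scores, dim=-1) @ v

def flash_attention(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

for name, fn in [("naive", naive_attention), ("sdpa", flash_attention)]:
    torch.cuda.reset_peak_memory_stats()
    out = fn(q, k, v)
    torch.cuda.synchronize()
    peak_mib = torch.cuda.max_memory_allocated() / 2**20
    print(f"{name}: peak memory ≈ {peak_mib:.0f} MiB")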
3. torch.compile: A Free 20% Speedup¶
PyTorch 2.0's killer feature!
# Before:
model = GPT(config)
# After:
model = torch.compile(model, mode='max-autotune')
# That's it! 20-30% speedup! 🚀
Eager vs Compiled¶
Eager mode (the default):
# Python execution:
out = model.linear1(x) # Python call
out = F.relu(out) # Python call
out = model.linear2(out) # Python call
# Each operation:
# 1. Python interpreter
# 2. Launch GPU kernel
# 3. Return to Python
# 4. Repeat
# Python overhead! ❌
Compiled mode:
model = torch.compile(model)
# First run: Compile!
# - Trace operations
# - Fuse kernels
# - Generate optimized code
# Subsequent runs:
# - Execute fused kernels directly
# - No Python overhead!
# - 20-30% faster! ✅
Kernel Fusion¶
# Eager (3 separate kernels):
x = linear(x) # Kernel 1
x = relu(x) # Kernel 2
x = linear2(x) # Kernel 3
# Each kernel:
# - Load from HBM
# - Compute
# - Store to HBM
# Compiled (1 fused kernel):
x = fused_linear_relu_linear(x)
# - Load once
# - Compute all
# - Store once
# → Save memory bandwidth! 🎉
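Kernel fusion is easy to observe on a small chain of pointwise ops. A rough micro-benchmark sketch (the function and timings are illustrative only; results vary by GPU and PyTorch version, and a CUDA GPU is required):
import time
import torch

def gelu_chain(x):
    # Several pointwise ops that eager mode launches as separate kernels
    return torch.nn.functional.gelu(x * 1.702).tanh() + x

x = torch.randn(4096, 4096, device='cuda')
compiled = torch.compile(gelu_chain)  # torch.compile fuses the pointwise chain into far fewer kernels

def bench(fn, iters=100):
    fn(x)                         # warmup (triggers compilation for the compiled version)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1e3  # ms per call

print(f"eager:    {bench(gelu_chain):.2f} ms")
print(f"compiled: {bench(compiled):.2f} ms")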
Compilation modes:
# 'default': Fast compile, good speed
model = torch.compile(model)
# 'reduce-overhead': Min Python overhead
model = torch.compile(model, mode='reduce-overhead')
# 'max-autotune': Slow compile, max speed
model = torch.compile(model, mode='max-autotune')
# nanochat uses this!
Benchmark:
# GPT-2 (1.5B) on A100
Eager:
- Step time: 420 ms
Compiled (max-autotune):
- First step: 35 seconds (compiling...)
- Step time: 320 ms (1.3x faster!)
# Worth it for long training!
4. Model FLOPs Utilization (MFU)¶
Measuring how efficiently training uses the hardware:
def calculate_mfu(model, tokens_per_sec, gpu_name='H100'):
    # 1. FLOPs per token (forward + backward ≈ 6 × params)
    n_params = sum(p.numel() for p in model.parameters())
    flops_per_token = 6 * n_params
    # 2. Achieved FLOPs per second
    actual_flops = flops_per_token * tokens_per_sec
    # 3. Hardware peak FLOPs per second (dense bf16)
    peak_flops = {
        'H100': 989e12,  # 989 TFLOPS (bf16)
        'A100': 312e12,  # 312 TFLOPS (bf16)
    }[gpu_name]
    # 4. MFU
    mfu = actual_flops / peak_flops * 100
    return mfu
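A sketch of how this might be hooked into the training loop: derive tokens/sec from the per-GPU batch shape and the measured step time, then call calculate_mfu. Here model, optimizer, x, and y are assumed to exist already, and since peak_flops above is for a single GPU, per-GPU throughput is what gets passed in; the numbers are hypothetical.
import time
import torch

B, T = 32, 2048                       # hypothetical per-GPU batch size and sequence length

t0 = time.perf_counter()
logits, loss = model(x, y)            # x is [B, T]
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
torch.cuda.synchronize()              # wait for GPU work to finish before timing
dt = time.perf_counter() - t0

tokens_per_sec = B * T / dt           # per-GPU throughput
print(f"MFU: {calculate_mfu(model, tokens_per_sec):.1f}%")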
Example:
# nanochat d20 on 8×H100:
n_params = 370M
tokens_per_sec = 2.5M (aggregate across 8 GPUs)
flops_per_token = 6 * 370M = 2.22 GFLOPs
actual_flops = 2.22G * 2.5M = 5.55 PFLOP/s
peak_flops = 989T * 8 = 7.912 PFLOP/s
MFU = 5.55 / 7.912 ≈ 70%! # Excellent! 🎉
MFU interpretation:
< 20%: Poor (bottlenecked)
20-40%: OK (typical unoptimized)
40-60%: Good (well-optimized)
> 60%: Excellent! (production-grade)
nanochat: 70% ← among the best!
Why is 100% MFU impossible?
Even perfect code can't reach 100% because:
1. Memory bandwidth:
- Need to load weights from HBM
- Takes time
2. Non-compute ops:
- Softmax, LayerNorm
- Don't use Tensor Cores
3. Data loading:
- CPU → GPU transfer
- Preprocessing
4. Kernel launch overhead:
- Small operations
- GPU scheduling
Transformers typically: 40-60% MFU
nanochat achieves: 70% MFU!
Profiling: Finding the Bottlenecks¶
Finding out where the time goes:
from torch.profiler import profile, ProfilerActivity
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
) as prof:
    # Training steps
    for _ in range(10):
        logits, loss = model(x, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
# Print results
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=10
))
Output:
Name CPU time CUDA time Memory
---------------------------------------------------------
aten::matmul 5.2ms 45.3ms 2.1 GB
aten::scaled_dot_product_... 1.1ms 23.1ms 0.8 GB
aten::add_ 0.3ms 2.1ms 0.1 GB
...
# Insights:
# - matmul takes 45ms (most time)
# - Attention takes 23ms (optimized with Flash!)
# - Memory usage: 2.1 GB (fits in GPU)
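The same profile can also be exported as a timeline and inspected visually; export_chrome_trace is part of the standard torch.profiler API:
# Export a timeline of the profiled steps
prof.export_chrome_trace("trace.json")
# Load trace.json in chrome://tracing or https://ui.perfetto.dev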
Combined Effect¶
Stacking all the optimizations together:
# Baseline (eager, fp32):
model = GPT(config).cuda()
# Step time: 840 ms
# Memory: 80 GB
# MFU: 15%
# + bfloat16:
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
    logits, loss = model(x, y)
# Step time: 480 ms (1.75x faster)
# Memory: 40 GB (2x less)
# MFU: 25%
# + Flash Attention:
# (use F.scaled_dot_product_attention)
# Step time: 320 ms (1.5x faster)
# Memory: 25 GB (1.6x less)
# MFU: 45%
# + torch.compile:
model = torch.compile(model, mode='max-autotune')
# Step time: 240 ms (1.3x faster)
# Memory: 25 GB (same)
# MFU: 70%
# Total speedup: 840 / 240 = 3.5x! 🚀
Real-world impact:
nanochat d20 training:
Without optimizations:
- 14 hours
- $350
With ALL optimizations:
- 4 hours (3.5x faster!)
- $100 (3.5x cheaper!)
Savings: $250 AND 10 hours!
Implementation Checklist¶
# ✅ 1. Use bfloat16 autocast
ctx = torch.amp.autocast('cuda', dtype=torch.bfloat16)
with ctx:
    logits, loss = model(x, y)
# ✅ 2. Use Flash Attention
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# ✅ 3. Compile model
model = torch.compile(model, mode='max-autotune')
# ✅ 4. Monitor MFU
mfu = calculate_mfu(model, tokens_per_sec)
print(f"MFU: {mfu:.1f}%")
# Target: > 40% for good, > 60% for excellent
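Wired together, a minimal training step covering all four checklist items might look like the sketch below (model, optimizer, train_loader, and calculate_mfu are assumed to already exist; this mirrors the checklist, not nanochat's exact training script):
import time
import torch

model = torch.compile(model, mode='max-autotune')                 # 3. compile once, up front
autocast_ctx = torch.amp.autocast('cuda', dtype=torch.bfloat16)   # 1. bf16 autocast

for step, (x, y) in enumerate(train_loader):
    t0 = time.perf_counter()
    x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
    with autocast_ctx:
        # 2. the model's attention layers call F.scaled_dot_product_attention internally
        logits, loss = model(x, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    dt = time.perf_counter() - t0
    if step % 100 == 0:               # 4. monitor MFU from per-GPU token throughput
        print(f"step {step}: loss {loss.item():.3f}, "
              f"MFU {calculate_mfu(model, x.numel() / dt):.1f}%")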
Key Takeaways¶
✅ bfloat16 = 2x speedup
- Same range as fp32
- Half memory
- Tensor Core acceleration
✅ Flash Attention = 3-6x for long sequences
- O(T²) → O(T) memory
- Tiling + recomputation
- Built into PyTorch
✅ torch.compile = 20-30% free speedup
- Kernel fusion
- One-time compilation cost
- mode='max-autotune' for best speed
✅ MFU = efficiency metric
- Actual FLOPs / Peak FLOPs
- Target: 40-60%
- nanochat: 70%!
✅ Combined = 3.5x total speedup
- $350 → $100
- 14h → 4h
Next Steps¶
Part 8, "Inference and Fine-tuning", will cover:
- 100x faster generation with a KV cache
- Sampling strategies
- SFT (Supervised Fine-Tuning)
- REINFORCE (RL from human feedback)
Training is done; now it's time to put the model to use! 🚀
---
📘 Note¶
This post is based on Andrej Karpathy's open-source nanochat project and was written to analyze and explain its code structure and training process. The original code is distributed under the MIT License. All code examples have been simplified for educational purposes.
References:
- [Flash Attention 2 논문](https://arxiv.org/abs/2307.08691)
- [torch.compile 가이드](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)
Tags: #Optimization #bfloat16 #FlashAttention #torchcompile #MFU #Performance