[Nanochat Analysis] 7. 3x Faster Training Optimization

Read time: 2 minutes

3x Faster Training: Mixed Precision and Compiler Optimizations

**Series**: Building an LLM with nanochat - Part 7/9

Speed Is Money

  • Training time = Money
    
    Without optimization:
    - d20 (370M params)
    - 8×H100
    - 12 hours
    - Cost: $300
    
    With ALL optimizations:
    - Same model
    - Same hardware
    - 4 hours (3x faster!)
    - Cost: $100 (3x cheaper!)
    
    Savings: $200 AND 8 hours

    Today we'll cover the four techniques that deliver this 3x speedup.

    1. Mixed Precision: bfloat16

    float32 vs bfloat16

  • float32 (4 bytes):
    [sign: 1 bit][exponent: 8 bits][mantissa: 23 bits]
    Range: 10^-38 to 10^38
    Precision: ~7 decimal digits
    
    bfloat16 (2 bytes):
    [sign: 1 bit][exponent: 8 bits][mantissa: 7 bits]
    Range: 10^-38 to 10^38 (SAME!)
    Precision: ~2 decimal digits (LESS!)

    Key insight: the range is the same; only the precision is lower!

  • import torch
    
    # float32:
    x = torch.tensor(3.14159265, dtype=torch.float32)
    # Stores: ~3.1415927 (accurate to ~7 digits)
    
    # bfloat16:
    x = torch.tensor(3.14159265, dtype=torch.bfloat16)
    # Stores: 3.140625 (slight precision loss)
    
    # But the range is the same:
    huge = torch.tensor(1e30, dtype=torch.bfloat16)   # Both can store!
    tiny = torch.tensor(1e-30, dtype=torch.bfloat16)  # Both can store!

    Why bfloat16 instead of float16?

  • # float16's problem:
    Range: ~6e-5 to 65504 (normal numbers)
    
    # Training values:
    grad = 1e-8    # Underflow! → 0 ❌ (below float16's smallest value)
    loss = 100000  # Overflow! → inf ❌ (above 65504)
    
    # bfloat16 solves this:
    Range: 10^-38 to 10^38 (same as float32!)
    grad = 1e-8    # OK ✅
    loss = 100000  # OK ✅
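
    You can verify this range difference yourself; a minimal sketch (runs on CPU, no GPU required):

  • import torch
    
    big = torch.tensor(100000.0)
    small = torch.tensor(1e-8)
    
    print(big.to(torch.float16))     # inf (overflow: float16 max is 65504)
    print(small.to(torch.float16))   # 0.  (underflow: below float16's range)
    
    print(big.to(torch.bfloat16))    # finite, ≈ 100000 (only rounded)
    print(small.to(torch.bfloat16))  # finite, ≈ 1e-08  (only rounded)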

    Autocast: automatic precision management

  • # Manual (tedious):
    model = model.to(torch.bfloat16)
    x = x.to(torch.bfloat16)
    # What about gradients? Optimizer states? ❌
    
    # Automatic (easy):
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, loss = model(x, y)
    
    # PyTorch handles everything! ✅

    The magic of autocast:

  • with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        # Matrix multiply → bfloat16 (fast!)
        logits = x @ weight
    
        # Softmax → float32 (accurate!)
        probs = F.softmax(logits, dim=-1)
    
        # Cross-entropy → float32 (stable!)
        loss = -torch.log(probs[target])
    
    # Auto selects best precision per operation!

    Operations whitelist/blacklist:

  • bfloat16 (fast ops):
    - matmul, conv
    - Linear layers
    - Attention
    
    float32 (precision-critical ops):
    - softmax, log_softmax
    - LayerNorm, RMSNorm
    - Loss functions
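
    A minimal sketch (assuming a CUDA GPU) showing autocast picking the dtype per operation:

  • import torch
    import torch.nn.functional as F
    
    x = torch.randn(4, 8, device='cuda')
    w = torch.randn(8, 8, device='cuda')
    
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        h = x @ w                 # matmul runs in bfloat16
        p = F.softmax(h, dim=-1)  # softmax is promoted to float32
    
    print(h.dtype)  # torch.bfloat16
    print(p.dtype)  # torch.float32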

    Performance gains

  • # Benchmark: GPT-2 (1.5B params)
    
    float32:
    - Forward: 120 ms
    - Backward: 240 ms
    - Memory: 6 GB
    
    bfloat16 with autocast:
    - Forward: 45 ms (2.7x faster!)
    - Backward: 90 ms (2.7x faster!)
    - Memory: 3 GB (2x less!)
    
    # Why so fast?
    # 1. Tensor Cores (optimized for bf16)
    # 2. Half memory bandwidth

    2. Flash Attention 2: Memory Efficiency

    The fatal flaw of standard attention:

  • Q, K, V: [B, H, T, D]
    
    # Attention matrix:
    scores = Q @ K.T  # [B, H, T, T] ← O(T²) memory!
    
    # Example: T=2048
    scores size = 2048² = 4M values per head
                = 4M × 32 heads = 128M values
                = 256 MB per layer (bf16)
                = 8 GB for 32 layers! 😱
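
    The same arithmetic as a quick check:

  • T, heads, layers, bytes_per_val = 2048, 32, 32, 2      # bf16 = 2 bytes
    per_layer = T * T * heads * bytes_per_val               # attention scores
    print(per_layer / 2**20, "MB per layer")                # 256 MB
    print(per_layer * layers / 2**30, "GB for all layers")  # 8 GB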

    The Flash Attention solution:

  • # Never materialize the full (T×T) matrix!
    
    # Process in tiles (simplified pseudocode; see the runnable sketch below):
    block_size = 128
    
    for q_start in range(0, T, block_size):
        Q_tile = Q[q_start:q_start + block_size]      # [block_size, D]
        for kv_start in range(0, T, block_size):
            # Load small tiles into SRAM (fast on-chip memory)
            K_tile = K[kv_start:kv_start + block_size]
            V_tile = V[kv_start:kv_start + block_size]
    
            # Compute attention for this tile
            scores_tile = Q_tile @ K_tile.T           # [128, 128] only!
            attn_tile = softmax(scores_tile)          # real kernel: online softmax
            out_tile = attn_tile @ V_tile
    
            # Accumulate with running softmax statistics
            # (no full matrix is ever stored!)
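
    Here is a runnable (but deliberately slow, un-fused) sketch of the same idea for a single head, including the online-softmax bookkeeping that the real kernel does on-chip. `tiled_attention` is an illustrative name, not a nanochat or PyTorch function:

  • import math
    import torch
    
    def tiled_attention(Q, K, V, block=128):
        T, D = Q.shape
        scale = 1.0 / math.sqrt(D)
        out = torch.zeros_like(Q)                    # running numerator
        row_max = torch.full((T, 1), float('-inf'))  # running max per query row
        row_sum = torch.zeros(T, 1)                  # running softmax denominator
    
        for j in range(0, T, block):
            Kj, Vj = K[j:j+block], V[j:j+block]
            S = (Q @ Kj.T) * scale                   # [T, block] score tile
            new_max = torch.maximum(row_max, S.max(dim=-1, keepdim=True).values)
            correction = torch.exp(row_max - new_max)  # rescale old accumulators
            P = torch.exp(S - new_max)
            row_sum = row_sum * correction + P.sum(dim=-1, keepdim=True)
            out = out * correction + P @ Vj
            row_max = new_max
    
        return out / row_sum
    
    # Matches standard (non-causal) attention up to floating-point error:
    Q, K, V = torch.randn(3, 256, 64).unbind(0)
    ref = torch.softmax((Q @ K.T) / math.sqrt(64), dim=-1) @ V
    print(torch.allclose(tiled_attention(Q, K, V), ref, atol=1e-4))  # True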

    Memory reduction:

  • # Standard (materialized score matrix):
    Memory = O(T²) = 2048² ≈ 4M values per head
    
    # Flash Attention (on-chip working set):
    Memory = O(block_size²) = 128² = 16K values
    
    # Reduction: 4M / 16K = 256x less! 🎉

    PyTorch usage:

  • # Old (manual):
    scores = Q @ K.transpose(-2, -1) / math.sqrt(D)
    scores = scores.masked_fill(mask == 0, float('-inf'))
    attn = F.softmax(scores, dim=-1)
    out = attn @ V
    
    # New (fused, Flash Attention built-in):
    out = F.scaled_dot_product_attention(
        Q, K, V,
        is_causal=True  # automatic causal masking!
    )
    
    # On supported GPUs and dtypes, PyTorch dispatches this to Flash Attention 2!
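
    A quick sanity check that the fused call matches the manual version (a minimal sketch; on CPU it falls back to a non-fused backend but gives the same result):

  • import math
    import torch
    import torch.nn.functional as F
    
    B, H, T, D = 2, 4, 128, 64
    Q, K, V = torch.randn(3, B, H, T, D).unbind(0)
    
    # Manual causal attention
    scores = (Q @ K.transpose(-2, -1)) / math.sqrt(D)
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
    scores = scores.masked_fill(~causal, float('-inf'))
    manual = F.softmax(scores, dim=-1) @ V
    
    # Fused version
    fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
    
    print(torch.allclose(manual, fused, atol=1e-5))  # True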

    3. torch.compile: 20% Performance for Free

    The killer feature of PyTorch 2.0!

  • # Before:
    model = GPT(config)
    
    # After:
    model = torch.compile(model, mode='max-autotune')
    
    # That's it! 20-30% speedup! 🚀

    Eager vs Compiled

    Eager mode (the default):

  • # Python execution:
    out = model.linear1(x)      # Python call
    out = F.relu(out)           # Python call
    out = model.linear2(out)    # Python call
    
    # Each operation:
    # 1. Python interpreter
    # 2. Launch GPU kernel
    # 3. Return to Python
    # 4. Repeat
    
    # Python overhead! ❌

    Compiled mode:

  • model = torch.compile(model)
    
    # First run: Compile!
    # - Trace operations
    # - Fuse kernels
    # - Generate optimized code
    
    # Subsequent runs:
    # - Execute fused kernels directly
    # - No Python overhead!
    # - 20-30% faster! ✅

    Kernel Fusion

  • # Eager (3 separate kernels):
    x = linear(x)   # Kernel 1
    x = relu(x)     # Kernel 2
    x = linear2(x)  # Kernel 3
    
    # Each kernel:
    # - Load from HBM
    # - Compute
    # - Store to HBM
    
    # Compiled (1 fused kernel):
    x = fused_linear_relu_linear(x)
    
    # - Load once
    # - Compute all
    # - Store once
    # → Save memory bandwidth! 🎉
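
    A minimal sketch (assumes PyTorch 2.x and a CUDA GPU) timing an elementwise chain eagerly vs. compiled; the gap you see depends on the ops and hardware:

  • import time
    import torch
    
    def chain(x):
        # Three elementwise ops the compiler can fuse into one kernel
        return torch.relu(x * 2.0 + 1.0)
    
    compiled_chain = torch.compile(chain)
    
    x = torch.randn(8192, 8192, device='cuda')
    
    def bench(fn, iters=100):
        fn(x)                        # warm-up (triggers compilation once)
        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(iters):
            fn(x)
        torch.cuda.synchronize()
        return (time.time() - t0) / iters * 1000  # ms per call
    
    print(f"eager:    {bench(chain):.3f} ms")
    print(f"compiled: {bench(compiled_chain):.3f} ms")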

    Compilation modes:

  • # 'default': Fast compile, good speed
    model = torch.compile(model)
    
    # 'reduce-overhead': Min Python overhead
    model = torch.compile(model, mode='reduce-overhead')
    
    # 'max-autotune': Slow compile, max speed
    model = torch.compile(model, mode='max-autotune')
    # nanochat uses this!

    Benchmark:

  • # GPT-2 (1.5B) on A100
    
    Eager:
    - Step time: 420 ms
    
    Compiled (max-autotune):
    - First step: 35 seconds (compiling...)
    - Step time: 320 ms (1.3x faster!)
    
    # Worth it for long training!

    4. Model FLOPs Utilization (MFU)

    Measuring how efficiently training uses the hardware:

  • def calculate_mfu(model, tokens_per_sec, gpu_name='H100', num_gpus=1):
        # 1. FLOPs per token (forward + backward ≈ 6 × params)
        n_params = sum(p.numel() for p in model.parameters())
        flops_per_token = 6 * n_params
    
        # 2. Achieved FLOPs per second
        actual_flops = flops_per_token * tokens_per_sec
    
        # 3. Peak hardware FLOPs per second (dense bf16)
        peak_flops = {
            'H100': 989e12,  # 989 TFLOPS (bf16)
            'A100': 312e12,  # 312 TFLOPS (bf16)
        }[gpu_name] * num_gpus
    
        # 4. MFU as a percentage
        mfu = actual_flops / peak_flops * 100
        return mfu

    Example:

  • # nanochat d20 on 8×H100:
    n_params = 370M
    tokens_per_sec = 2.5M
    
    flops_per_token = 6 * 370M = 2.22 GFLOPs per token
    actual_flops = 2.22G * 2.5M = 5.55 PFLOPS
    
    peak_flops = 989T * 8 = 7.912 PFLOPS
    
    MFU = 5.55 / 7.912 = 70%!  # Excellent! 🎉
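
    The same arithmetic as a few lines of Python (equivalently, pass num_gpus=8 to calculate_mfu above):

  • n_params = 370e6
    tokens_per_sec = 2.5e6
    flops_per_token = 6 * n_params             # ≈ 2.22e9 FLOPs per token
    actual = flops_per_token * tokens_per_sec  # ≈ 5.55e15 FLOPs/s
    peak = 989e12 * 8                          # 8×H100, dense bf16
    print(f"MFU: {actual / peak * 100:.0f}%")  # 70%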

    MFU interpretation:

  • < 20%: Poor (bottlenecked)
    20-40%: OK (typical unoptimized)
    40-60%: Good (well-optimized)
    > 60%: Excellent! (production-grade)
    
    nanochat: 70% ← Among best!

    Why is 100% MFU impossible?

  • Even perfect code can't reach 100% because:
    
    1. Memory bandwidth:
       - Need to load weights from HBM
       - Takes time
    
    2. Non-compute ops:
       - Softmax, LayerNorm
       - Don't use Tensor Cores
    
    3. Data loading:
       - CPU → GPU transfer
       - Preprocessing
    
    4. Kernel launch overhead:
       - Small operations
       - GPU scheduling
    
    Transformers typically: 40-60% MFU
    nanochat achieves: 70% MFU!

    Profiling: Finding the Bottleneck

    Find out where the time actually goes:

  • from torch.profiler import profile, ProfilerActivity
    
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
    ) as prof:
        # Training step
        for _ in range(10):
            logits, loss = model(x, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    
    # Print results
    print(prof.key_averages().table(
        sort_by="cuda_time_total",
        row_limit=10
    ))

    Output:

  • Name                         CPU time  CUDA time  Memory
    ---------------------------------------------------------
    aten::matmul                   5.2ms    45.3ms   2.1 GB
    aten::scaled_dot_product_...  1.1ms    23.1ms   0.8 GB
    aten::add_                     0.3ms     2.1ms   0.1 GB
    ...
    
    # Insights:
    # - matmul takes 45ms (most time)
    # - Attention takes 23ms (optimized with Flash!)
    # - Memory usage: 2.1 GB (fits in GPU)

    Combined Effect

    Putting all the optimizations together:

  • # Baseline (eager, fp32):
    model = GPT(config).cuda()
    # Step time: 840 ms
    # Memory: 80 GB
    # MFU: 15%
    
    # + bfloat16:
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, loss = model(x, y)
    # Step time: 480 ms (1.75x faster)
    # Memory: 40 GB (2x less)
    # MFU: 25%
    
    # + Flash Attention:
    # (use F.scaled_dot_product_attention)
    # Step time: 320 ms (1.5x faster)
    # Memory: 25 GB (1.6x less)
    # MFU: 45%
    
    # + torch.compile:
    model = torch.compile(model, mode='max-autotune')
    # Step time: 240 ms (1.3x faster)
    # Memory: 25 GB (same)
    # MFU: 70%
    
    # Total speedup: 840 / 240 = 3.5x! 🚀

    Real-world impact:

  • nanochat d20 training:
    
    Without optimizations:
    - 14 hours
    - $350
    
    With ALL optimizations:
    - 4 hours (3.5x faster!)
    - $100 (3.5x cheaper!)
    
    Savings: $250 AND 10 hours!

    Implementation Checklist

  • # ✅ 1. Use bfloat16 autocast
    ctx = torch.amp.autocast('cuda', dtype=torch.bfloat16)
    
    with ctx:
        logits, loss = model(x, y)
    
    # ✅ 2. Use Flash Attention
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    
    # ✅ 3. Compile model
    model = torch.compile(model, mode='max-autotune')
    
    # ✅ 4. Monitor MFU
    mfu = calculate_mfu(model, tokens_per_sec)
    print(f"MFU: {mfu:.1f}%")
    
    # Target: > 40% for good, > 60% for excellent

    Key Takeaways

    bfloat16 = 2x speedup
    - Same range as fp32
    - Half memory
    - Tensor Core acceleration

    Flash Attention = 3-6x for long sequences
    - O(T²) → O(T) memory
    - Tiling + recomputation
    - Built into PyTorch

    torch.compile = 20-30% free speedup
    - Kernel fusion
    - One-time compilation cost
    - Mode='max-autotune' for best

    MFU = efficiency metric
    - Actual FLOPs / Peak FLOPs
    - Target: 40-60%
    - nanochat: 70%!

    Combined = 3.5x total speedup
    - $350 → $100
    - 14h → 4h

    Next Steps

    Part 8, "Inference and Fine-tuning", will cover:
    - KV cache로 100배 빠른 generation
    - Sampling strategies
    - SFT (Supervised Fine-Tuning)
    - REINFORCE (RL from human feedback)

    Training is done; now it's time to put the model to work! 🚀

    ---

    📘 Notes

    This post is based on Andrej Karpathy's open-source nanochat project and was written to analyze and explain its code structure and training process. The original code is distributed under the MIT License. All code examples are simplified for educational purposes.

    References:
    - [Flash Attention 2 paper](https://arxiv.org/abs/2307.08691)
    - [torch.compile guide](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)

    Tags: #Optimization #bfloat16 #FlashAttention #torchcompile #MFU #Performance