Stable Diffusion 3 & FLUX: Complete Guide to MMDiT Architecture
From U-Net to Transformer. A deep dive into the MMDiT architecture that treats text and image equally, plus Rectified Flow and Guidance Distillation.

TL;DR
- MMDiT (Multimodal DiT): Processes text and image jointly in a single Transformer
- Rectified Flow Adoption: Straight-line paths instead of DDPM for faster generation
- FLUX Evolution: Guidance Distillation enables CFG-free 4-8 step generation
- Core Innovation: Bidirectional attention between text and image for better prompt following
1. Why Abandon U-Net?
U-Net's Limitations
Stable Diffusion 1.x/2.x was U-Net-based:
Text Encoder (CLIP) → Cross-Attention → U-Net → Image
Problems:
- One-way information flow: Only text → image, no image → text feedback (sketched in the code below)
- Cross-attention bottleneck: Text information injected only at specific layers
- Scaling limits: U-Net shows diminishing returns with increased model size
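To make the one-way flow concrete, here is a minimal sketch of a U-Net-style text-to-image cross-attention layer (names and shapes are illustrative, not SD 1.x source code): image tokens query the text embedding, but the text stream itself is never updated.
```python
import torch.nn as nn

class TextToImageCrossAttention(nn.Module):
    """Illustrative single-head cross-attention: image queries text."""
    def __init__(self, dim):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)  # queries from image features
        self.to_k = nn.Linear(dim, dim)  # keys from text embedding
        self.to_v = nn.Linear(dim, dim)  # values from text embedding
        self.scale = dim ** -0.5

    def forward(self, img, txt):
        # img: (B, N_img, dim), txt: (B, N_txt, dim)
        q, k, v = self.to_q(img), self.to_k(txt), self.to_v(txt)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        return img + attn @ v  # only the image stream is refined; txt is read-only
```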
The Rise of DiT
DiT (Diffusion Transformer) demonstrated:
- Transformers follow scaling laws
- Larger models consistently improve FID
- But DiT-style models still injected text via cross-attention
MMDiT: True Multimodal
SD3's MMDiT treats text and image as equal sequences:
[Text Tokens] + [Image Tokens] → Joint Transformer → [Text'] + [Image']
Bidirectional attention lets text see image, and image see text.
2. MMDiT Architecture Details
Input Processing
Image Input:
- VAE Encoder extracts latent: (H, W, 3) → (h, w, 16)
- Patchify: (h, w, 16) → (N_img, D) (see the sketch below)
- Add position embedding
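A minimal patchify sketch (the helper name and example sizes are assumptions; SD3-class models use small patches, e.g. 2×2, on the latent):
```python
import torch
import torch.nn as nn

def patchify(latent, patch_size=2):
    """(B, C, h, w) latent → (B, N_img, C*p*p) patch tokens."""
    B, C, h, w = latent.shape
    p = patch_size
    x = latent.reshape(B, C, h // p, p, w // p, p)
    x = x.permute(0, 2, 4, 1, 3, 5)               # (B, h/p, w/p, C, p, p)
    return x.reshape(B, (h // p) * (w // p), C * p * p)

# Project patch tokens to the Transformer width D (illustrative values)
latent = torch.randn(1, 16, 64, 64)               # 512x512 image → 64x64 latent
tokens = nn.Linear(16 * 2 * 2, 1536)(patchify(latent))   # (1, 1024, 1536)
```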
Text Input:
- Three text encoders:
- CLIP-L (OpenAI)
- CLIP-G (OpenCLIP)
- T5-XXL (Google)
- Combine pooled + sequence embeddings (how the pooled vector is used is sketched below)
- Project the sequence to (N_txt, D) shape
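The sequence embeddings become the text tokens for joint attention, while the pooled vectors are typically fused with the timestep embedding that drives AdaLN (see below). A rough, assumed sketch of that fusion; module names and dimensions are illustrative:
```python
import torch.nn as nn

class ConditioningEmbedder(nn.Module):
    """Assumed sketch: fuse sinusoidal timestep features with the pooled text
    vector into the single conditioning vector consumed by AdaLayerNorm."""
    def __init__(self, pooled_dim=2048, time_dim=256, dim=1536):
        super().__init__()
        self.time_mlp = nn.Sequential(nn.Linear(time_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.text_mlp = nn.Sequential(nn.Linear(pooled_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, t_feats, pooled_text):
        # t_feats: (B, time_dim) sinusoidal timestep features
        # pooled_text: (B, pooled_dim) concatenated CLIP-L + CLIP-G pooled vectors
        return self.time_mlp(t_feats) + self.text_mlp(pooled_text)
```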
Joint Attention Block
The core is the MM-DiT Block:
```python
import torch
import torch.nn as nn

class MMDiTBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Separate weights per modality, shared attention
        # (FeedForward is a standard MLP; definition omitted)
        self.norm1_img = AdaLayerNorm(dim)
        self.norm1_txt = AdaLayerNorm(dim)
        self.attn = JointAttention(dim)
        self.norm2_img = AdaLayerNorm(dim)
        self.norm2_txt = AdaLayerNorm(dim)
        self.ff_img = FeedForward(dim)
        self.ff_txt = FeedForward(dim)

    def forward(self, img, txt, timestep):
        # Separate normalization per modality
        img_norm = self.norm1_img(img, timestep)
        txt_norm = self.norm1_txt(txt, timestep)
        # Joint attention (the key!)
        img_attn, txt_attn = self.attn(img_norm, txt_norm)
        img = img + img_attn
        txt = txt + txt_attn
        # Separate feedforward per modality
        img = img + self.ff_img(self.norm2_img(img, timestep))
        txt = txt + self.ff_txt(self.norm2_txt(txt, timestep))
        return img, txt
```
Joint Attention Mechanism
```python
class JointAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # One set of projections shared by both modalities (single-head for clarity)
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def forward(self, img, txt):
        # Concatenate image and text into one sequence
        x = torch.cat([img, txt], dim=1)
        # Compute Q, K, V over the joint sequence
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)
        # Self-attention (all tokens attend to each other)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = attn @ v
        # Split back into the two streams
        img_out, txt_out = out.split([img.shape[1], txt.shape[1]], dim=1)
        return img_out, txt_out
```
Key Points:
- Image tokens attend to text tokens
- Text tokens attend to image tokens
- Bidirectional information flow for better text-image alignment
AdaLN (Adaptive Layer Normalization)
How timestep information is injected:
```python
class AdaLayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.proj = nn.Linear(dim, dim * 2)

    def forward(self, x, timestep_emb):
        # Predict scale, shift from the timestep embedding: (B, dim) each
        scale, shift = self.proj(timestep_emb).chunk(2, dim=-1)
        # Adaptive normalization, broadcasting over the token dimension
        x = self.norm(x)
        x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return x
```
3. Rectified Flow in SD3
SD3 uses Rectified Flow instead of DDPM.
Why Rectified Flow?
| Property | DDPM | Rectified Flow |
|---|---|---|
| Trajectory | Curved (SDE) | Straight (ODE) |
| Required Steps | 20-50 | 4-10 |
| Training Target | Noise prediction | Velocity field |
| Distillation | Difficult | Easy |
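Because the target trajectory is a straight line, sampling reduces to integrating the predicted velocity with a plain Euler ODE solver. A minimal sketch, assuming the model signature used in the loss below and the convention that t=1 is pure noise and t=0 is data:
```python
import torch

@torch.no_grad()
def euler_sample(model, text_emb, shape, num_steps=10):
    """Integrate dx/dt = v(x, t) from t=1 (noise) down to t=0 (clean latent)."""
    x = torch.randn(shape)                          # start from pure noise
    ts = torch.linspace(1.0, 0.0, num_steps + 1)
    for i in range(num_steps):
        t, t_next = ts[i], ts[i + 1]
        v = model(x, t.expand(shape[0]), text_emb)  # predicted velocity z - x0
        x = x + (t_next - t) * v                    # straight-line Euler step
    return x
```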
SD3's Flow Formulation
```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, x0, text_emb):
    # Sample time uniformly in (0, 1)
    t = torch.rand(x0.shape[0], device=x0.device)
    t_ = t.view(-1, 1, 1, 1)          # broadcast over (B, C, H, W)
    # Sample noise
    z = torch.randn_like(x0)
    # Linear interpolation between data (t=0) and noise (t=1)
    x_t = (1 - t_) * x0 + t_ * z
    # Target velocity of the straight path
    v_target = z - x0
    # Predict velocity
    v_pred = model(x_t, t, text_emb)
    return F.mse_loss(v_pred, v_target)
```
Logit-Normal Sampling
SD3's distinctive choice: sample timesteps from a logit-normal distribution instead of uniformly.
```python
def logit_normal_sample(batch_size, m=0.0, s=1.0):
    """Focus sampling density on middle timesteps."""
    u = torch.randn(batch_size) * s + m
    t = torch.sigmoid(u)  # transform to (0, 1)
    return t
```
Reason: middle timesteps are the most important for learning, so they are sampled more often.
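A quick sanity check of that claim, reusing the sampler above: with the default m=0, s=1, most timesteps land in the middle of (0, 1) rather than near the endpoints.
```python
t = logit_normal_sample(100_000)
middle = ((t > 0.25) & (t < 0.75)).float().mean().item()
print(f"fraction of samples in (0.25, 0.75): {middle:.2f}")  # ≈ 0.73, vs 0.50 for uniform
```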
4. FLUX: SD3's Evolution
FLUX was created by Black Forest Labs (founded by core SD3 developers).
FLUX vs SD3 Comparison
| Property | SD3 | FLUX |
|---|---|---|
| Developer | Stability AI | Black Forest Labs |
| Architecture | MMDiT | MMDiT (improved) |
| Guidance | CFG required | Distilled (CFG-free possible) |
| Model Size | 2B, 8B | 12B |
| Minimum Steps | 20-30 | 4 (schnell) |
FLUX Variants
- FLUX.1-pro: Highest quality, API only
- FLUX.1-dev: Research/development, open weights
- FLUX.1-schnell: 1-4 step generation, fastest
Guidance Distillation
The key technique behind FLUX's distilled, CFG-free variants:
Traditional CFG (Classifier-Free Guidance):
```python
# Requires two forward passes at inference time
pred_uncond = model(x_t, t, null_text)
pred_cond = model(x_t, t, text)
pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
```
After Guidance Distillation:
```python
# A single forward pass achieves the CFG effect
pred = model(x_t, t, text)  # guidance internalized into the weights
```
Training method:
```python
def guidance_distillation_loss(student, teacher, x_t, t, text, null_text, cfg_scale):
    # Teacher: apply full CFG (two forward passes)
    with torch.no_grad():
        pred_uncond = teacher(x_t, t, null_text)
        pred_cond = teacher(x_t, t, text)
        target = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
    # Student: a single forward pass must match the guided teacher
    pred = student(x_t, t, text)
    return F.mse_loss(pred, target)
```
5. Text Encoder Strategy
SD3's Triple Text Encoder
SD3 uses three text encoders:
```python
class TripleTextEncoder:
    """Schematic pseudocode: model identifiers and call signatures are simplified.
    In practice each encoder is loaded via from_pretrained with its own tokenizer."""
    def __init__(self):
        self.clip_l = CLIPTextModel("openai/clip-vit-large")
        self.clip_g = CLIPTextModel("laion/CLIP-ViT-bigG")
        self.t5 = T5EncoderModel("google/t5-v1_1-xxl")

    def encode(self, text):
        # CLIP embeddings: pooled vector + per-token sequence
        clip_l_pooled, clip_l_seq = self.clip_l(text)
        clip_g_pooled, clip_g_seq = self.clip_g(text)
        # T5 embedding: sequence only (no pooled vector)
        t5_seq = self.t5(text)
        # Pooled vectors: global conditioning
        pooled = torch.cat([clip_l_pooled, clip_g_pooled], dim=-1)
        # Sequence embeddings: the text tokens fed to joint attention
        # (in the real model the CLIP sequences are padded to T5's width before this concat)
        seq = torch.cat([clip_l_seq, clip_g_seq, t5_seq], dim=1)
        return pooled, seq
```
Why three encoders?
| Encoder | Strength | Token Limit |
|---|---|---|
| CLIP-L | General visual concepts | 77 |
| CLIP-G | Larger capacity | 77 |
| T5-XXL | Long text, complex relationships | 512 |
Adding T5 significantly improved understanding of long prompts and complex relationships between objects.
FLUX's Text Processing
FLUX simplifies:
- CLIP-L + T5-XXL combination
- More reliance on T5 (leveraging longer context)
6. VAE Improvements
SD3's 16-Channel VAE
Previous SD 1.x/2.x: 4-channel latent
SD3/FLUX: 16-channel latent
```python
# SD 1.x/2.x
latent = vae.encode(image)  # (B, 4, H/8, W/8)

# SD3/FLUX
latent = vae.encode(image)  # (B, 16, H/8, W/8)
```
Advantages:
- More information preserved
- Better fine detail reconstruction
- Improved text rendering quality
Disadvantages:
- Increased memory usage
- Higher computation
7. Practical Usage Examples
Using SD3 with Diffusers
```python
from diffusers import StableDiffusion3Pipeline
import torch

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

image = pipe(
    prompt="A cat holding a sign that says 'Hello World'",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
```
Using FLUX with Diffusers
```python
from diffusers import FluxPipeline
import torch

# FLUX.1-schnell (fast variant)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")

image = pipe(
    prompt="A cat holding a sign that says 'Hello World'",
    num_inference_steps=4,   # only 4 steps!
    guidance_scale=0.0,      # no CFG needed
).images[0]
```
Memory Optimization
```python
# CPU offload
pipe.enable_model_cpu_offload()

# Attention slicing
pipe.enable_attention_slicing()

# VAE tiling (for high-resolution images)
pipe.enable_vae_tiling()
```
8. Performance Comparison
Text Rendering Capability
SD3/FLUX's biggest improvement: Accurate text rendering in images
| Model | "Hello World" Accuracy |
|---|---|
| SD 1.5 | ~10% |
| SD 2.1 | ~20% |
| SDXL | ~40% |
| SD3 | ~80% |
| FLUX | ~90% |
Prompt Following
Complex prompt test: "A red cube on top of a blue sphere, with a green pyramid to the left"
| Model | Accuracy |
|---|---|
| SDXL | Medium |
| SD3 | High |
| FLUX | Very High |
Generation Speed (A100)
| Model | Steps | Time |
|---|---|---|
| SDXL | 30 | ~3s |
| SD3 | 28 | ~4s |
| FLUX-dev | 20 | ~5s |
| FLUX-schnell | 4 | ~1s |
9. Limitations and Considerations
Memory Requirements
| Model | VRAM (fp16) |
|---|---|
| SDXL | ~8GB |
| SD3-medium | ~12GB |
| FLUX-dev | ~24GB |
| FLUX-schnell | ~12GB |
Licensing
- SD3: Stability AI Community License (commercial restrictions)
- FLUX.1-dev: Research/development (commercial restrictions)
- FLUX.1-schnell: Apache 2.0 (commercial use allowed)
Known Issues
- Human anatomy: Still errors with fingers, etc.
- Text consistency: Occasional errors with long text
- Style diversity: May be biased toward certain styles
Conclusion
| Property | SD 1.x/2.x | SDXL | SD3 | FLUX |
|---|---|---|---|---|
| Architecture | U-Net | U-Net | MMDiT | MMDiT |
| Text-Image | Cross-attn | Cross-attn | Joint-attn | Joint-attn |
| Flow | DDPM | DDPM | Rectified | Rectified |
| Text Rendering | Poor | Fair | Good | Excellent |
| Minimum Steps | 20+ | 20+ | 20+ | 4 |
SD3 and FLUX demonstrate the paradigm shift from U-Net to Transformer, from DDPM to Rectified Flow. MMDiT's bidirectional attention significantly improves text-image alignment, and Rectified Flow enables fast generation.
References
- Esser, P., et al. "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" (SD3 Paper, 2024)
- Black Forest Labs. "FLUX.1 Technical Report" (2024)
- Peebles, W. & Xie, S. "Scalable Diffusion Models with Transformers" (DiT, 2023)
- Liu, X., et al. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow" (2023)