Stable Diffusion 3 & FLUX: Complete Guide to MMDiT Architecture

From U-Net to Transformer. A deep dive into MMDiT architecture treating text and image equally, plus Rectified Flow and Guidance Distillation.

TL;DR

  • MMDiT (Multimodal DiT): Processes text and image jointly in a single Transformer
  • Rectified Flow Adoption: Straight-line paths instead of DDPM for faster generation
  • FLUX Evolution: Guidance Distillation removes the need for CFG; the schnell variant generates in as few as 1-4 steps
  • Core Innovation: Bidirectional attention between text and image for better prompt following

1. Why Abandon U-Net?

U-Net's Limitations

Stable Diffusion 1.x/2.x was U-Net-based:

python
Text Encoder (CLIP) → Cross-Attention → U-Net → Image

Problems:

  • One-way information flow: Only text → image, no image → text feedback
  • Cross-attention bottleneck: Text information injected only at specific layers
  • Scaling limits: U-Net shows diminishing returns with increased model size

The Rise of DiT

DiT (Diffusion Transformer) demonstrated:

  • Transformers follow scaling laws
  • Larger models consistently improve FID
  • But the original DiT was class-conditional, and early text-to-image DiT variants still injected text via cross-attention

MMDiT: True Multimodal

SD3's MMDiT treats text and image as equal sequences:

python
[Text Tokens] + [Image Tokens] → Joint Transformer → [Text'] + [Image']

Bidirectional attention lets text see image, and image see text.

2. MMDiT Architecture Details

Input Processing

Image Input:

  1. VAE Encoder extracts latent: (H, W, 3) → (h, w, 16)
  2. Patchify: (h, w, 16) → (N_img, D) (see the sketch after this list)
  3. Add position embedding
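
A minimal patchify sketch. The PatchEmbed name, the patch size of 2, and the hidden width of 1536 are illustrative assumptions; a strided convolution is one common way to implement this step:

python
import torch.nn as nn

class PatchEmbed(nn.Module):
    def __init__(self, in_channels=16, dim=1536, patch_size=2):
        super().__init__()
        # A strided conv patchifies and projects to the model width in one step
        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, latent):                # (B, 16, h, w)
        x = self.proj(latent)                 # (B, D, h/p, w/p)
        x = x.flatten(2).transpose(1, 2)      # (B, N_img, D) with N_img = (h/p) * (w/p)
        return x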

Text Input:

  1. Three text encoders:

     - CLIP-L (OpenAI)

     - CLIP-G (OpenCLIP)

     - T5-XXL (Google)

  2. Combine pooled + sequence embeddings (see the sketch after this list)
  3. Transform to (N_txt, D) shape
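
As referenced above, here is a minimal sketch of how the pooled text embedding is folded into the conditioning signal. The ConditioningEmbed name and the dimensions (256-d sinusoidal timestep features, 2048-d pooled CLIP vector, 1536-d model width) are assumptions for illustration:

python
import torch.nn as nn

class ConditioningEmbed(nn.Module):
    def __init__(self, time_dim=256, pooled_dim=2048, dim=1536):
        super().__init__()
        self.time_proj = nn.Sequential(nn.Linear(time_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.text_proj = nn.Sequential(nn.Linear(pooled_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep_emb, pooled_text):
        # The summed vector is what the AdaLN layers below consume as "timestep_emb"
        return self.time_proj(timestep_emb) + self.text_proj(pooled_text)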

Joint Attention Block

The core is the MM-DiT Block:

python
import torch.nn as nn

class MMDiTBlock(nn.Module):
    """One joint block: separate weights per modality, shared attention."""
    def __init__(self, dim):
        super().__init__()
        # Image and text streams keep their own norms and MLPs
        self.norm1_img = AdaLayerNorm(dim)
        self.norm1_txt = AdaLayerNorm(dim)
        self.attn = JointAttention(dim)     # defined below

        self.norm2_img = AdaLayerNorm(dim)
        self.norm2_txt = AdaLayerNorm(dim)
        self.ff_img = FeedForward(dim)      # standard two-layer MLP (definition omitted)
        self.ff_txt = FeedForward(dim)

    def forward(self, img, txt, timestep):
        # Separate, timestep-conditioned normalization per modality
        img_norm = self.norm1_img(img, timestep)
        txt_norm = self.norm1_txt(txt, timestep)

        # Joint attention over the concatenated sequences (the key!)
        img_attn, txt_attn = self.attn(img_norm, txt_norm)

        img = img + img_attn
        txt = txt + txt_attn

        # Separate feedforward per modality
        img = img + self.ff_img(self.norm2_img(img, timestep))
        txt = txt + self.ff_txt(self.norm2_txt(txt, timestep))

        return img, txt

Joint Attention Mechanism

python
import torch
import torch.nn as nn

class JointAttention(nn.Module):
    """Single-head shown for clarity; each modality has its own QKV projection."""
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** -0.5
        self.qkv_img = nn.Linear(dim, dim * 3)
        self.qkv_txt = nn.Linear(dim, dim * 3)

    def forward(self, img, txt):
        # Project each stream with its own weights, then concatenate the sequences
        q_img, k_img, v_img = self.qkv_img(img).chunk(3, dim=-1)
        q_txt, k_txt, v_txt = self.qkv_txt(txt).chunk(3, dim=-1)

        q = torch.cat([q_img, q_txt], dim=1)
        k = torch.cat([k_img, k_txt], dim=1)
        v = torch.cat([v_img, v_txt], dim=1)

        # Self-attention over the joint sequence (all tokens attend to each other)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = attn @ v

        # Split back into image and text streams
        img_out, txt_out = out.split([img.shape[1], txt.shape[1]], dim=1)

        return img_out, txt_out

Key Points:

  • Image tokens attend to text tokens
  • Text tokens attend to image tokens
  • Bidirectional information flow for better text-image alignment

AdaLN (Adaptive Layer Normalization)

How timestep information is injected:

python
import torch.nn as nn

class AdaLayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.proj = nn.Linear(dim, dim * 2)

    def forward(self, x, timestep_emb):
        # Predict per-channel scale and shift from the conditioning vector
        # (the timestep embedding, typically combined with the pooled text embedding)
        scale, shift = self.proj(timestep_emb).chunk(2, dim=-1)

        # Adaptive normalization, broadcast over the token dimension
        x = self.norm(x)
        x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

        return x

3. Rectified Flow in SD3

SD3 uses Rectified Flow instead of DDPM.

Why Rectified Flow?

| Property | DDPM | Rectified Flow |
|---|---|---|
| Trajectory | Curved (SDE) | Straight (ODE) |
| Required Steps | 20-50 | 4-10 |
| Training Target | Noise prediction | Velocity field |
| Distillation | Difficult | Easy |

SD3's Flow Formulation

python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, x0, text_emb):
    # Sample time uniformly in (0, 1)
    t = torch.rand(x0.shape[0], device=x0.device)
    t_ = t.view(-1, 1, 1, 1)  # reshape for broadcasting over (C, H, W)

    # Sample noise
    z = torch.randn_like(x0)

    # Linear interpolation between data (t=0) and noise (t=1)
    x_t = (1 - t_) * x0 + t_ * z

    # Target velocity: d x_t / dt = z - x0
    v_target = z - x0

    # Predict velocity
    v_pred = model(x_t, t, text_emb)

    return F.mse_loss(v_pred, v_target)
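
Because the model learns the velocity of a straight path, inference is just numerical integration of an ODE. A minimal Euler sampler sketch (the euler_sample helper is hypothetical), assuming the model predicts v = z - x0 and integrating from t = 1 (noise) to t = 0 (data):

python
import torch

@torch.no_grad()
def euler_sample(model, text_emb, shape, num_steps=10, device="cuda"):
    x = torch.randn(shape, device=device)               # pure noise at t = 1
    ts = torch.linspace(1.0, 0.0, num_steps + 1, device=device)
    for i in range(num_steps):
        t, t_next = ts[i], ts[i + 1]
        v = model(x, t.expand(shape[0]), text_emb)       # predicted velocity
        x = x + (t_next - t) * v                         # step along the straight path
    return x                                             # approximate sample at t = 0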

Logit-Normal Sampling

Instead of sampling training timesteps uniformly, SD3 draws them from a logit-normal distribution:

python
def logit_normal_sample(batch_size, m=0.0, s=1.0):
    """Focus more on middle timesteps."""
    u = torch.randn(batch_size) * s + m
    t = torch.sigmoid(u)  # Transform to (0, 1)
    return t

Reason: the velocity target is hardest to predict at intermediate timesteps (near t=0 or t=1 it is almost trivial), so sampling them more often focuses training where it matters.

4. FLUX: SD3's Evolution

FLUX was created by Black Forest Labs, a company founded by former Stability AI researchers who developed SD3.

FLUX vs SD3 Comparison

| Property | SD3 | FLUX |
|---|---|---|
| Developer | Stability AI | Black Forest Labs |
| Architecture | MMDiT | MMDiT (improved) |
| Guidance | CFG required | Distilled (CFG-free possible) |
| Model Size | 2B, 8B | 12B |
| Minimum Steps | 20-30 | 4 (schnell) |

FLUX Variants

  1. FLUX.1-pro: Highest quality, API only
  2. FLUX.1-dev: Research/development, open weights
  3. FLUX.1-schnell: 1-4 step generation, fastest

Guidance Distillation

The key technique behind FLUX's distilled variants:

Traditional CFG (Classifier-Free Guidance):

python
# Requires 2x computation at inference
pred_uncond = model(x_t, t, null_text)
pred_cond = model(x_t, t, text)
pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

After Guidance Distillation:

python
# Single forward pass achieves CFG effect
pred = model(x_t, t, text)  # Guidance internalized

Training method:

python
import torch
import torch.nn.functional as F

def guidance_distillation_loss(student, teacher, x_t, t, text, null_text, cfg_scale):
    # Teacher: apply classic two-pass CFG
    with torch.no_grad():
        pred_uncond = teacher(x_t, t, null_text)
        pred_cond = teacher(x_t, t, text)
        target = pred_uncond + cfg_scale * (pred_cond - pred_uncond)

    # Student: single forward pass
    # (FLUX.1-dev additionally conditions the student on cfg_scale via an embedding,
    #  so one distilled model can mimic a range of guidance strengths)
    pred = student(x_t, t, text)

    return F.mse_loss(pred, target)

5. Text Encoder Strategy

SD3's Triple Text Encoder

SD3 uses three text encoders:

python
import torch
import torch.nn.functional as F
from transformers import CLIPTextModelWithProjection, T5EncoderModel

class TripleTextEncoder:
    def __init__(self):
        # Hugging Face checkpoints; outputs below are shown schematically
        # (tokenization and exact output objects are omitted for brevity)
        self.clip_l = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_g = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
        self.t5 = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")

    def encode(self, text):
        # CLIP embeddings: pooled vector + token sequence
        clip_l_pooled, clip_l_seq = self.clip_l(text)   # (B, 768),  (B, 77, 768)
        clip_g_pooled, clip_g_seq = self.clip_g(text)   # (B, 1280), (B, 77, 1280)

        # T5 embedding (sequence only)
        t5_seq = self.t5(text)                          # (B, 512, 4096)

        # Pooled: concatenated along channels, used for AdaLN conditioning
        pooled = torch.cat([clip_l_pooled, clip_g_pooled], dim=-1)              # (B, 2048)

        # Sequence: CLIP features are concatenated channel-wise, zero-padded to the
        # T5 width, then joined with the T5 tokens along the sequence dimension
        clip_seq = torch.cat([clip_l_seq, clip_g_seq], dim=-1)                  # (B, 77, 2048)
        clip_seq = F.pad(clip_seq, (0, t5_seq.shape[-1] - clip_seq.shape[-1]))  # (B, 77, 4096)
        seq = torch.cat([clip_seq, t5_seq], dim=1)                              # (B, 77 + 512, 4096)

        return pooled, seq

Why three encoders?

| Encoder | Strength | Token Limit |
|---|---|---|
| CLIP-L | General visual concepts | 77 |
| CLIP-G | Larger capacity | 77 |
| T5-XXL | Long text, complex relationships | 512 |

Adding T5 significantly improved understanding of long prompts and complex relationships.

FLUX's Text Processing

FLUX simplifies:

  • CLIP-L + T5-XXL combination
  • Heavier reliance on T5, leveraging its longer context (see the sketch below)
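
A schematic sketch of FLUX-style text conditioning; the encode_flux helper is hypothetical, and tokenization plus the exact encoder output objects are omitted. CLIP-L contributes only its pooled vector for modulation-style conditioning, while T5 supplies the full token sequence for joint attention:

python
def encode_flux(clip_l, t5, clip_tokens, t5_tokens):
    # CLIP-L: pooled vector only -> AdaLN-style conditioning
    pooled = clip_l(clip_tokens).pooler_output      # (B, 768)
    # T5-XXL: full sequence -> text tokens for the joint-attention stream
    seq = t5(t5_tokens).last_hidden_state           # (B, 512, 4096)
    return pooled, seq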

6. VAE Improvements

SD3's 16-Channel VAE

Previous SD 1.x/2.x: 4-channel latent

SD3/FLUX: 16-channel latent

python
# SD 1.x/2.x
latent = vae.encode(image)  # (B, 4, H/8, W/8)

# SD3/FLUX
latent = vae.encode(image)  # (B, 16, H/8, W/8)
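
Back-of-the-envelope arithmetic for a 1024×1024 input (both generations downsample spatially by 8, so the capacity difference comes entirely from the channel count):

python
h = w = 1024 // 8        # 128x128 spatial latent in both cases
sd1_values = h * w * 4   # SD 1.x/2.x:  65,536 latent values
sd3_values = h * w * 16  # SD3/FLUX:   262,144 latent values (4x the capacity)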

Advantages:

  • More information preserved
  • Better fine detail reconstruction
  • Improved text rendering quality

Disadvantages:

  • Increased memory usage
  • Higher computation

7. Practical Usage Examples

Using SD3 with Diffusers

python
from diffusers import StableDiffusion3Pipeline
import torch

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium",
    torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

image = pipe(
    prompt="A cat holding a sign that says 'Hello World'",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]

Using FLUX with Diffusers

python
from diffusers import FluxPipeline
import torch

# FLUX.1-schnell (fast version)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")

image = pipe(
    prompt="A cat holding a sign that says 'Hello World'",
    num_inference_steps=4,  # Only 4 steps!
    guidance_scale=0.0,     # No CFG needed
).images[0]
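
For FLUX.1-dev the same pipeline applies; note that guidance_scale here is fed to the guidance-distilled model as a conditioning embedding rather than triggering classic two-pass CFG:

python
# FLUX.1-dev (guidance-distilled)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="A cat holding a sign that says 'Hello World'",
    num_inference_steps=20,
    guidance_scale=3.5,   # embedded guidance, not two-pass CFG
).images[0]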

Memory Optimization

python
# CPU offload
pipe.enable_model_cpu_offload()

# Attention slicing
pipe.enable_attention_slicing()

# VAE tiling (for high resolution)
pipe.enable_vae_tiling()

8. Performance Comparison

Text Rendering Capability

SD3/FLUX's biggest improvement: Accurate text rendering in images

Model"Hello World" Accuracy
SD 1.5~10%
SD 2.1~20%
SDXL~40%
SD3~80%
FLUX~90%

Prompt Following

Complex prompt test: "A red cube on top of a blue sphere, with a green pyramid to the left"

| Model | Accuracy |
|---|---|
| SDXL | Medium |
| SD3 | High |
| FLUX | Very High |

Generation Speed (A100)

| Model | Steps | Time |
|---|---|---|
| SDXL | 30 | ~3s |
| SD3 | 28 | ~4s |
| FLUX-dev | 20 | ~5s |
| FLUX-schnell | 4 | ~1s |

9. Limitations and Considerations

Memory Requirements

| Model | VRAM (fp16) |
|---|---|
| SDXL | ~8GB |
| SD3-medium | ~12GB |
| FLUX-dev | ~24GB |
| FLUX-schnell | ~12GB |

Licensing

  • SD3: Stability AI Community License (commercial restrictions)
  • FLUX.1-dev: Research/development (commercial restrictions)
  • FLUX.1-schnell: Apache 2.0 (commercial use allowed)

Known Issues

  1. Human anatomy: Still errors with fingers, etc.
  2. Text consistency: Occasional errors with long text
  3. Style diversity: May be biased toward certain styles

Conclusion

| Property | SD 1.x/2.x | SDXL | SD3 | FLUX |
|---|---|---|---|---|
| Architecture | U-Net | U-Net | MMDiT | MMDiT |
| Text-Image | Cross-attn | Cross-attn | Joint-attn | Joint-attn |
| Flow | DDPM | DDPM | Rectified | Rectified |
| Text Rendering | Poor | Fair | Good | Excellent |
| Minimum Steps | 20+ | 20+ | 20+ | 4 |

SD3 and FLUX demonstrate the paradigm shift from U-Net to Transformer, from DDPM to Rectified Flow. MMDiT's bidirectional attention significantly improves text-image alignment, and Rectified Flow enables fast generation.

References

  1. Esser, P., et al. "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" (SD3 Paper, 2024)
  2. Black Forest Labs. "FLUX.1 Technical Report" (2024)
  3. Peebles, W. & Xie, S. "Scalable Diffusion Models with Transformers" (DiT, 2023)
  4. Liu, X., et al. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow" (2023)