Apr 3, 2026

Mixture of Experts - Scaling Without the Compute Tax

How Mixture of Experts lets language models have trillions of parameters while only using a fraction per token. From gating networks to Switch Transformers to Mixtral - with interactive visualizations and PyTorch code.


There is a simple recipe for making language models better: make them bigger. More parameters, more data, more compute. GPT-2 had 1.5 billion parameters. GPT-3 had 175 billion. GPT-4 likely exceeded a trillion. Every generation, the numbers went up, and so did the capabilities.

But this recipe has a problem. A fatal one.

If you double the number of parameters in a dense model, you approximately double the compute cost per token. Every single token - from the word "the" to a complex multi-step reasoning chain - flows through every single parameter. A 175B-parameter model must touch all 175 billion parameters (roughly 350 billion FLOPs) just to process the word "the." That feels wasteful.

What if the model could choose which parameters to use, depending on the input? What if a trillion-parameter model could route different tokens through different subsets of its parameters, using only a fraction of them for each token?

This is the central idea behind Mixture of Experts (MoE) - and it is one of the most important architectural innovations in modern AI.

The Scaling Dilemma

Let's quantify the problem. Training compute scales roughly as:

$$C \approx 6ND$$

where $C$ is total training floating-point operations, $N$ is the number of parameters, and $D$ is the number of training tokens (Kaplan et al., 2020). Inference cost per token is roughly $2N$ FLOPs for the forward pass, so it is also proportional to $N$.

For a dense model, every token activates every parameter. The relationship is linear and inescapable:

| Model | Total Params | Active Params/Token | FLOPs/Token |
|---|---|---|---|
| GPT-2 | 1.5B | 1.5B | ~3B |
| GPT-3 | 175B | 175B | ~350B |
| Dense 1T | 1T | 1T | ~2T |

That last row is the wall. Training a 1-trillion-parameter dense model would require approximately $6 \times 10^{12} \times D$ FLOPs. At current hardware prices, this is economically impractical for most organizations. And at inference time, every generated token would cost roughly two trillion FLOPs.
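The two rules of thumb used above (about $2N$ forward FLOPs per token, and $C \approx 6ND$ for training) turn into a back-of-the-envelope calculator in a few lines. A minimal Python sketch; the function names are illustrative, and the ~300B-token figure is the training-set size reported in the GPT-3 paper:

```python
def forward_flops_per_token(n_params: float) -> float:
    """Dense forward pass: about one multiply-add (2 FLOPs) per parameter per token."""
    return 2 * n_params


def training_flops(n_params: float, n_tokens: float) -> float:
    """C ~ 6ND: forward (~2N) plus backward (~4N) FLOPs per token, over D tokens."""
    return 6 * n_params * n_tokens


# Reproduce the FLOPs/Token column of the table above
for name, n in [("GPT-2", 1.5e9), ("GPT-3", 175e9), ("Dense 1T", 1e12)]:
    print(f"{name}: ~{forward_flops_per_token(n):.1e} FLOPs per token")

# GPT-3 was trained on roughly 300B tokens
print(f"GPT-3 training: ~{training_flops(175e9, 300e9):.2e} FLOPs")
```

The same arithmetic shows why the 1T row is a wall: every extra parameter is paid for on every token, in training and at inference alike.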

We need a way to decouple model capacity (total parameters = knowledge storage) from compute cost (active parameters per token).

The MoE Idea: Conditional Computation

The insight dates back to 1991, when Jacobs, Jordan, Nowlan, and Hinton introduced the original Mixture of Experts model in "Adaptive Mixtures of Local Experts." The idea is elegant:

Instead of one massive network that processes everything, use a collection of smaller specialist networks (experts) and a learned gating function that decides which experts to consult for each input.

Think of it like a hospital. A hospital has hundreds of specialists - cardiologists, neurologists, dermatologists, radiologists. When a patient arrives, they don't see every doctor. A triage nurse (the router) examines the patient and directs them to the relevant specialists. The hospital has vast total expertise, but each patient only uses a small fraction of it.

In neural network terms:

  • Experts = independent feed-forward networks (FFNs), each with their own parameters
  • Router/Gate = a small network that assigns each input token to the top-K experts
  • Sparsity = only K out of N experts activate per token (typically K=1 or K=2)

Dense vs Sparse: The Key Trade-off

In a dense model, the feed-forward network (FFN) in each transformer layer has two weight matrices, an up-projection and a down-projection, for a total of $2 \cdot d_{\text{model}} \cdot d_{\text{ff}}$ parameters, and every token flows through all of them.

In a sparse MoE model, we replace that single FFN with $N$ independent expert FFNs, but each token only activates $K$ of them. Counting two FLOPs per multiply-add and ignoring the small router:

$$\text{Dense FLOPs per token} \approx 4 \cdot d_{\text{model}} \cdot d_{\text{ff}}$$

$$\text{MoE FLOPs per token} \approx K \cdot 4 \cdot d_{\text{model}} \cdot d_{\text{ff}}^{\text{expert}}$$

If each expert is the same size as the original dense FFN ($d_{\text{ff}}^{\text{expert}} = d_{\text{ff}}$), the FFN parameters increase by $N\times$, but the FFN compute only increases by $K\times$ (where $K \ll N$). If instead we shrink each expert to $d_{\text{ff}}^{\text{expert}} = d_{\text{ff}} / K$ to keep compute constant, we get $N/K$ times more FFN parameters at no extra compute (the full $N\times$ with top-1 routing).

This is the fundamental value proposition of MoE: scale parameters independently of compute.
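The parameter and compute ratios can be checked numerically. A small Python sketch that counts both weight matrices of each FFN (so $4 \cdot d_{\text{model}} \cdot d_{\text{ff}}$ FLOPs per expert) and ignores biases and the router; the layer sizes and the helper name are illustrative:

```python
def ffn_cost(d_model: int, d_ff: int, num_experts: int = 1, top_k: int = 1):
    """Return (total params, forward FLOPs per token) for a dense FFN or an MoE FFN layer."""
    params_per_expert = 2 * d_model * d_ff        # up-projection + down-projection
    flops_per_expert = 2 * params_per_expert      # 2 FLOPs per multiply-add
    return num_experts * params_per_expert, top_k * flops_per_expert


d_model, d_ff = 4096, 16384
dense_p, dense_f = ffn_cost(d_model, d_ff)
# 8 experts, top-2, each expert as large as the dense FFN
same_p, same_f = ffn_cost(d_model, d_ff, num_experts=8, top_k=2)
# 8 experts, top-2, each expert shrunk to d_ff / K so compute matches the dense FFN
shrunk_p, shrunk_f = ffn_cost(d_model, d_ff // 2, num_experts=8, top_k=2)

print(f"Full-size experts: {same_p / dense_p:.0f}x params, {same_f / dense_f:.0f}x FLOPs")
print(f"Shrunk experts:    {shrunk_p / dense_p:.0f}x params, {shrunk_f / dense_f:.0f}x FLOPs")
```

With $N = 8$ and $K = 2$, full-size experts give 8x the parameters for 2x the FLOPs, while experts of width $d_{\text{ff}}/2$ give 4x the parameters ($N/K$) at exactly the dense FLOP count.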

Gating Networks: The Router

The router is the brain of the MoE system. It receives a token representation $x \in \mathbb{R}^{d_{\text{model}}}$ and produces a probability distribution over all $N$ experts.

Softmax Gating

The simplest gating mechanism is a learned linear projection followed by softmax:

$$G(x) = \text{Softmax}(W_g \cdot x)$$

where $W_g \in \mathbb{R}^{N \times d_{\text{model}}}$ is the gating weight matrix. This produces a vector $G(x) \in \mathbb{R}^N$ where each entry represents the probability of routing to that expert.

Top-K Routing

We don't use all experts. Instead, we select the top-K experts by gating probability and zero out the rest:

$$G_{\text{sparse}}(x) = \text{TopK}\big(\text{Softmax}(W_g \cdot x),\; K\big)$$

The final MoE output for a single token is:

$$y = \sum_{i \in \text{TopK}} G_{\text{sparse}}(x)_i \cdot E_i(x)$$

where $E_i(x)$ is the output of expert $i$ applied to input $x$. The gating weights $G_{\text{sparse}}(x)_i$ serve as the mixing coefficients - they determine how much each selected expert contributes to the final output.
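A minimal PyTorch sketch of these two equations for a single token. The function names and the plain `nn.Linear` layers standing in for the experts are illustrative assumptions. Note that, exactly as written, $G_{\text{sparse}}$ keeps the raw softmax probabilities, so the $K$ surviving weights sum to less than 1 (many implementations renormalize them):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def sparse_gate(x: torch.Tensor, w_g: torch.Tensor, k: int) -> torch.Tensor:
    """G_sparse(x) = TopK(Softmax(W_g x), K) for one token x of shape (d_model,)."""
    probs = F.softmax(w_g @ x, dim=-1)        # (N,)
    top_vals, top_idx = torch.topk(probs, k)
    gates = torch.zeros_like(probs)
    gates[top_idx] = top_vals                 # keep the K largest, zero the other N - K
    return gates


def moe_output(x, w_g, experts, k):
    """y = sum over the selected experts of G_sparse(x)_i * E_i(x)."""
    gates = sparse_gate(x, w_g, k)
    selected = torch.nonzero(gates).flatten().tolist()
    # Only the K selected experts are ever evaluated
    return sum(gates[i] * experts[i](x) for i in selected)


torch.manual_seed(0)
d_model, num_experts = 16, 8
w_g = torch.randn(num_experts, d_model)
experts = [nn.Linear(d_model, d_model) for _ in range(num_experts)]
x = torch.randn(d_model)
y = moe_output(x, w_g, experts, k=2)
```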

Noisy Top-K Gating (Shazeer et al., 2017)

A critical improvement from the landmark 2017 paper "Outrageously Large Neural Networks" (Shazeer et al.) was adding tunable Gaussian noise to the gating logits before the top-K selection. This encourages exploration during training and helps prevent the router from collapsing to always select the same experts:

$$H(x) = W_g \cdot x + \epsilon \odot \text{Softplus}(W_{\text{noise}} \cdot x)$$

$$G(x) = \text{Softmax}\big(\text{KeepTopK}(H(x),\; K)\big)$$

where $\epsilon \sim \mathcal{N}(0, I)$ is standard Gaussian noise, $W_{\text{noise}}$ is a learned weight matrix whose Softplus output sets a per-expert noise scale, and KeepTopK keeps the $K$ largest entries of $H(x)$ and sets the rest to $-\infty$, so they receive exactly zero weight after the softmax (taking the top $K$ before the softmax also means the surviving weights sum to 1). The noise injection is only applied during training, not inference.
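A minimal PyTorch sketch of a noisy top-K gate, following the ordering above (noise, then KeepTopK on the logits, then softmax). The class and layer names are illustrative, not taken from the paper's code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisyTopKGate(nn.Module):
    """G(x) = Softmax(KeepTopK(H(x), K)), with H(x) = W_g x + eps * Softplus(W_noise x)."""

    def __init__(self, d_model: int, num_experts: int, k: int):
        super().__init__()
        self.k = k
        self.w_g = nn.Linear(d_model, num_experts, bias=False)
        self.w_noise = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (num_tokens, d_model)
        h = self.w_g(x)
        if self.training:
            # eps ~ N(0, I), scaled per expert by a learned, input-dependent Softplus term
            h = h + torch.randn_like(h) * F.softplus(self.w_noise(x))
        top_vals, top_idx = h.topk(self.k, dim=-1)
        # KeepTopK: every logit outside the top K becomes -inf, i.e. weight 0 after softmax
        masked = torch.full_like(h, float("-inf")).scatter(-1, top_idx, top_vals)
        return F.softmax(masked, dim=-1)


torch.manual_seed(0)
gate = NoisyTopKGate(d_model=16, num_experts=8, k=2)
x = torch.randn(4, 16)
train_gates = gate(x)   # noisy: routing can differ from call to call
gate.eval()
eval_gates = gate(x)    # no noise at inference, so routing is deterministic
```

Because the noise scale is itself learned, the router can shrink the noise for experts where exploration no longer helps.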

Why Top-K and Not Full Softmax?

Using the full softmax distribution (all experts weighted) would make the model dense again - every expert would process every token, just with different weights. The sparsity comes precisely from the hard Top-K selection, which zeros out $N - K$ experts entirely.

This raises a subtle question: can you backpropagate through a hard Top-K selection? Yes - the gradients flow through the selected experts normally (their gating weights are differentiable through the softmax), and the non-selected experts receive zero gradient for that token. Over many tokens in a batch, each expert sees enough gradient signal to learn.
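This is easy to check with autograd. A small PyTorch sketch with illustrative sizes and plain linear layers as stand-in experts: the unselected experts never enter the computation graph, so their `.grad` stays `None` (equivalent to a zero gradient), while the router still receives a gradient through the softmax values of the selected experts.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, num_experts, k = 8, 4, 2
router = nn.Linear(d_model, num_experts, bias=False)
experts = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_experts))

x = torch.randn(d_model)
probs = F.softmax(router(x), dim=-1)
top_vals, top_idx = probs.topk(k)
selected = set(top_idx.tolist())

# Only the selected experts are evaluated, weighted by their (differentiable) gate values
y = sum(top_vals[j] * experts[int(top_idx[j])](x) for j in range(k))
y.sum().backward()

for i, expert in enumerate(experts):
    status = "selected" if i in selected else "skipped"
    print(f"expert {i} ({status}): has grad = {expert.weight.grad is not None}")
print("router grad norm:", router.weight.grad.norm().item())
```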

The Load Balancing Problem

Here is where MoE gets interesting - and difficult. Left to its own devices, the router tends to collapse: it learns to send almost all tokens to one or two experts, leaving the rest idle. This is called expert collapse, and it is the central challenge of training MoE models.

Why Does Collapse Happen?

It's a positive feedback loop. Early in training, one expert might be slightly better than the others by random chance. The router sends more tokens to it. With more tokens, it gets more gradient updates and improves faster. The router notices it's even better now and sends even more tokens. Without intervention, a handful of experts can end up handling the large majority of tokens, and the rest become essentially dead parameters.

The Load Balancing Loss

The fix is an auxiliary loss that penalizes uneven routing. The most common formulation today comes from the Switch Transformer (Fedus et al., 2021), which simplified earlier balancing losses from Shazeer et al. (2017) and GShard:

$$\mathcal{L}_{\text{aux}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i$$

where:

  • $N$ = number of experts
  • $f_i$ = fraction of tokens actually routed to expert $i$ in the batch
  • $P_i$ = average router probability assigned to expert $i$ across the batch
  • $\alpha$ = a small coefficient (Switch Transformer used $\alpha = 10^{-2}$)

The loss is designed to be minimized under uniform routing, $f_i = P_i = \frac{1}{N}$ for all experts (perfect balance), where with top-1 routing it equals $\alpha$. The intuition: if an expert is receiving too many tokens ($f_i$ high) and the router is assigning it high probability ($P_i$ high), the product $f_i \cdot P_i$ is large. Because $f_i$ comes from a hard Top-K choice, it carries no gradient; the push-back flows through $P_i$, nudging the router to lower the probability it gives overloaded experts.
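A minimal PyTorch sketch of this loss for top-1 routing (as in Switch Transformer), evaluated on a perfectly balanced batch and on a fully collapsed one. The function name and toy batch are illustrative:

```python
import torch


def switch_aux_loss(probs: torch.Tensor, expert_idx: torch.Tensor, alpha: float = 0.01):
    """probs: (T, N) router probabilities; expert_idx: (T,) top-1 expert per token."""
    num_tokens, num_experts = probs.shape
    f = torch.bincount(expert_idx, minlength=num_experts).float() / num_tokens  # f_i
    P = probs.mean(dim=0)                                                       # P_i
    return alpha * num_experts * (f * P).sum()


T, N = 8, 4
# Perfect balance: uniform probabilities, tokens spread evenly over the experts
balanced = switch_aux_loss(torch.full((T, N), 1.0 / N), torch.arange(T) % N)
# Collapse: every token goes to expert 0 with probability 1
collapsed_probs = torch.zeros(T, N)
collapsed_probs[:, 0] = 1.0
collapsed = switch_aux_loss(collapsed_probs, torch.zeros(T, dtype=torch.long))
print(f"balanced: {balanced.item():.4f}, collapsed: {collapsed.item():.4f}")
```

The balanced batch gives exactly $\alpha$, while total collapse gives $\alpha \cdot N$, so the penalty grows with the number of experts that are being starved.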

Expert Capacity and Token Dropping

Even with the balancing loss, imbalances still occur. To handle this at the implementation level, each expert has a capacity, set by a capacity factor $C$, that limits how many tokens it can process:

$$\text{Expert Capacity} = C \cdot \frac{\text{tokens in batch}}{N}$$

where $C$ is typically 1.0 to 1.5. If more tokens are routed to an expert than its capacity allows, the overflow tokens are dropped - they skip the MoE layer entirely and pass through via the residual connection.

This sounds alarming, but in practice it works because:

  1. The load balancing loss keeps drops rare
  2. Dropped tokens still have the residual connection (the MoE layer is additive)
  3. The capacity factor $C > 1$ provides headroom
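A minimal PyTorch sketch of capacity-limited top-1 routing. It assumes first-come, first-served assignment in batch order (one common choice, not the only one), and shows that a dropped token's output is just its residual input:

```python
import math

import torch


def apply_capacity(expert_idx: torch.Tensor, num_experts: int, capacity_factor: float = 1.25):
    """Top-1 routing with a per-expert capacity. Returns (keep mask, capacity).
    Tokens are served first-come, first-served in batch order; the rest are dropped."""
    num_tokens = expert_idx.numel()
    capacity = math.floor(capacity_factor * num_tokens / num_experts)
    keep = torch.zeros(num_tokens, dtype=torch.bool)
    for e in range(num_experts):
        positions = (expert_idx == e).nonzero(as_tuple=True)[0]
        keep[positions[:capacity]] = True
    return keep, capacity


# 8 tokens, 4 experts, C = 1.0 -> capacity 2; expert 0 is oversubscribed
expert_idx = torch.tensor([0, 0, 0, 0, 1, 2, 3, 3])
keep, capacity = apply_capacity(expert_idx, num_experts=4, capacity_factor=1.0)

# Dropped tokens only keep the residual path: y = x + keep * moe(x)
x = torch.randn(8, 16)
moe_out = torch.randn(8, 16)  # stand-in for the expert outputs
y = x + keep.unsqueeze(-1) * moe_out
```

Tokens 2 and 3 overflow expert 0's two slots, so their outputs equal their inputs; raising the capacity factor to 2.0 in this example would keep every token.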

The Expert Networks

Each expert is typically a standard feed-forward network (FFN), identical in architecture to the FFN in a normal transformer layer:

$$E_i(x) = W_2^{(i)} \cdot \sigma(W_1^{(i)} \cdot x)$$

where $W_1^{(i)} \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$, $W_2^{(i)} \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$, and $\sigma$ is a non-linearity such as ReLU or GELU. Modern architectures often use the gated SwiGLU variant instead, which adds a third weight matrix.

Shared vs. Independent Experts

Some architectures include a shared expert that processes every token, combined with specialized sparse experts:

$$y = E_{\text{shared}}(x) + \sum_{i \in \text{TopK}} G(x)_i \cdot E_i(x)$$

The shared expert captures common patterns (basic syntax, frequent collocations), while the sparse experts specialize. DeepSeek-MoE (2024) uses this approach with 2 shared experts and 64 routed experts, arguing it reduces redundancy among the routed experts.

What Do Experts Specialize In?

This is a fascinating research question. Analysis of trained MoE models reveals that experts develop distinct specializations, though not always in ways humans would predict:

  • Linguistic domains: Some experts handle punctuation and formatting; others handle content words
  • Positional patterns: Certain experts preferentially activate for tokens at specific sequence positions
  • Semantic categories: In multilingual models, experts sometimes specialize by language
  • Frequency-based: Common tokens tend to route to a small set of "generalist" experts, while rare tokens activate different specialists

However, the specialization is soft - experts have overlapping capabilities, and the routing is statistical rather than categorical.

The Switch Transformer: Simplified Routing

In 2021, Fedus et al. published the Switch Transformer, which made a bold simplification: use Top-1 routing instead of Top-2.

Why Top-1 Works

This was counterintuitive. Surely using two experts per token gives more capacity? But Switch Transformer showed three advantages of Top-1:

1. Halved compute per token. With $K=1$, each token runs through one expert FFN instead of two, halving the expert FLOPs per token for the same expert size. And because each token is routed to only one expert, the capacity of each expert (tokens per batch) can be at least halved as well.

2. Simpler routing. No need to compute weighted combinations of multiple expert outputs. The output is simply $G(x)_{i^*} \cdot E_{i^*}(x)$ where $i^* = \arg\max_i G(x)_i$.

3. Reduced communication. In distributed settings, each token only needs to be sent to one device (where its chosen expert resides) instead of two.

Switch Transformer Architecture

The Switch Transformer replaces the FFN in every other transformer layer with a Switch layer:

Standard Transformer Layer:
  x → LayerNorm → Multi-Head Attention → + (residual) → LayerNorm → FFN → + (residual)

Switch Transformer Layer:
  x → LayerNorm → Multi-Head Attention → + (residual) → LayerNorm → Switch (MoE) → + (residual)

The attention mechanism is completely unchanged. MoE only modifies the FFN component. This is a crucial design point: the attention layers remain dense because every token potentially needs to attend to every other token, but the FFN computation (which processes each token independently) can be safely sparsified.

Scaling Results

The Switch Transformer demonstrated remarkable scaling properties:

| Model | Expert Count | Total Params | Active Params | Speedup vs Dense |
|---|---|---|---|---|
| Switch-Base | 128 | 7.4B | 223M | 7x |
| Switch-Large | 128 | 26B | 783M | 7x |
| Switch-XXL | 128 | 395B | 3.1B | 4x |

The "speedup" here means: to reach the same loss as a FLOP-matched dense T5 baseline, the Switch model needs that many times fewer training steps. For example, Switch-Base, with 7.4 billion total parameters but only about 223 million used per token, reaches the quality of its dense counterpart (T5-Base) in roughly one-seventh of the training time.

Selective Precision

One practical contribution of Switch Transformer was showing that MoE layers are more sensitive to numerical precision than dense layers. The paper found that casting the router computation to float32 (while keeping expert computation in bfloat16) stabilized training significantly. This selective precision approach has become standard practice.

```python
import torch
import torch.nn.functional as F


# Selective precision in Switch-style top-1 routing
def switch_route(x: torch.Tensor, gate_weights: torch.Tensor):
    """x: (num_tokens, d_model); gate_weights: (d_model, num_experts)."""
    # Router computation in float32 for stability
    logits = torch.matmul(
        x.float(),             # cast input to fp32
        gate_weights.float(),  # cast gate weights to fp32
    )
    probs = F.softmax(logits, dim=-1)
    top1_idx = probs.argmax(dim=-1)
    top1_weight = probs.gather(-1, top1_idx.unsqueeze(-1))
    # Cast the gate weight back to the input dtype (e.g. bfloat16) for expert computation
    return top1_idx, top1_weight.to(x.dtype)
```

GShard and Expert Parallelism

When you have 128 or more experts, their weights typically no longer fit on a single GPU. This introduces a new parallelism dimension: expert parallelism.

The Parallelism Landscape

Modern large model training uses multiple types of parallelism simultaneously:

Data Parallelism (DP): Replicate the entire model across devices, split the data batch. Each device processes a different subset of tokens. Gradients are synchronized via all-reduce.

Tensor Parallelism (TP): Split individual weight matrices across devices. Each device holds a slice of every layer. Requires communication within each layer's forward pass.

Pipeline Parallelism (PP): Split layers across devices. Device 1 runs layers 1-10, device 2 runs layers 11-20. Tokens flow through the pipeline.

Expert Parallelism (EP): Each device holds a subset of experts. Tokens are dynamically routed to the device hosting their assigned expert.

Expert parallelism is unique because the communication pattern is data-dependent. In DP, TP, and PP, you know in advance which data goes where. In EP, the routing decision happens at runtime, so you need an all-to-all communication operation.

All-to-All Communication

In expert parallelism, the all-to-all operation works as follows:

  1. Each device has a batch of tokens and a subset of experts
  2. The router on each device determines which expert each local token needs
  3. Tokens are sent to the device hosting their assigned expert (all-to-all dispatch)
  4. Each device runs its local experts on all received tokens
  5. Results are sent back to the originating devices (all-to-all combine)
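```python
import torch
import torch.distributed as dist


def expert_parallel_forward(tokens, router, local_experts, num_experts):
    """Top-1 expert-parallel forward pass (forward only: plain torch.distributed
    collectives do not backpropagate). Needs an initialized process group whose
    backend supports all-to-all, e.g. NCCL. tokens: (num_local_tokens, d_model).
    Assumes rank r hosts experts [r * E_local, (r + 1) * E_local)."""
    world_size, rank = dist.get_world_size(), dist.get_rank()
    experts_per_rank = num_experts // world_size

    # Step 1: route tokens and find the rank hosting each chosen expert
    top1_expert = router(tokens).argmax(dim=-1)  # (T,)
    dest_rank = top1_expert // experts_per_rank

    # Step 2: all-to-all dispatch; sort so each destination's tokens are contiguous
    order = dest_rank.argsort()
    send_counts = torch.bincount(dest_rank, minlength=world_size)
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts)  # exchange bucket sizes first
    send_splits, recv_splits = send_counts.tolist(), recv_counts.tolist()
    recv_tokens = tokens.new_empty(sum(recv_splits), tokens.shape[-1])
    dist.all_to_all_single(recv_tokens, tokens[order], recv_splits, send_splits)
    recv_expert = top1_expert.new_empty(sum(recv_splits))
    dist.all_to_all_single(recv_expert, top1_expert[order], recv_splits, send_splits)

    # Step 3: run local experts on everything this rank received
    local_id = recv_expert - rank * experts_per_rank
    outputs = torch.zeros_like(recv_tokens)
    for i, expert in enumerate(local_experts):
        mask = local_id == i
        if mask.any():
            outputs[mask] = expert(recv_tokens[mask])

    # Step 4: all-to-all combine (splits reversed), then undo the sort
    combined = torch.empty_like(tokens)
    dist.all_to_all_single(combined, outputs, send_splits, recv_splits)
    result = torch.empty_like(combined)
    result[order] = combined
    return result
```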
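The bookkeeping of the four steps (who sends which tokens where, and how results find their way back) can be checked without a cluster. A single-process Python sketch where each "rank" is just a list entry and the all-to-all is ordinary list shuffling; the function name and the contiguous expert-to-rank layout are illustrative assumptions:

```python
import torch


def simulate_expert_parallel(tokens_per_rank, experts_per_token, experts, experts_per_rank):
    """Single-process stand-in for expert parallelism with top-1 routing.
    tokens_per_rank[r]: (T_r, d_model); experts_per_token[r]: (T_r,) expert ids.
    Rank r is assumed to host experts[r * experts_per_rank : (r + 1) * experts_per_rank]."""
    world_size = len(tokens_per_rank)

    # Dispatch: inbox[dst] collects (src rank, positions on src, tokens, expert ids)
    inbox = [[] for _ in range(world_size)]
    for src in range(world_size):
        dest = experts_per_token[src] // experts_per_rank
        for dst in range(world_size):
            pos = (dest == dst).nonzero(as_tuple=True)[0]
            inbox[dst].append((src, pos, tokens_per_rank[src][pos], experts_per_token[src][pos]))

    # Each destination runs its local experts, then results go back to the source rank
    results = [torch.zeros_like(t) for t in tokens_per_rank]
    for dst in range(world_size):
        for src, pos, toks, ids in inbox[dst]:
            for e in ids.unique().tolist():
                sel = ids == e
                results[src][pos[sel]] = experts[e](toks[sel])
    return results


torch.manual_seed(0)
d_model, num_experts, world_size = 8, 4, 2
experts = [torch.nn.Linear(d_model, d_model) for _ in range(num_experts)]
tokens = [torch.randn(5, d_model), torch.randn(3, d_model)]
expert_ids = [torch.randint(0, num_experts, (5,)), torch.randint(0, num_experts, (3,))]
with torch.no_grad():
    outputs = simulate_expert_parallel(tokens, expert_ids, experts, num_experts // world_size)
```

Every token ends up back on its original rank, in its original position, holding its chosen expert's output, which is exactly the invariant the real dispatch and combine steps must preserve.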

GShard (Lepikhin et al., 2020)

GShard was one of the first systems to scale MoE to this size: its largest model, a 600-billion-parameter multilingual translation Transformer, was trained on 2048 TPU v3 cores. Key innovations included:

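Random routing for the second expert. In top-2 routing, the first expert is always the one with the highest gate value, but the token is only dispatched to the second-ranked expert with a probability proportional to that expert's gate weight. Tokens whose router strongly prefers a single expert therefore often skip the second one, which saves capacity and improves load balance.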
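A PyTorch sketch of this rule. The exact form used here (normalize the top-2 gates, then dispatch to the second expert if $2 \cdot g_2$ exceeds a uniform random draw) is our reading of the GShard algorithm and should be treated as an assumption:

```python
import torch


def gshard_random_top2(probs: torch.Tensor):
    """probs: (T, N) router probabilities. Returns (first expert, second expert,
    use_second mask). The second expert is used with probability 2 * g2 / (g1 + g2)."""
    top2_vals, top2_idx = probs.topk(2, dim=-1)
    g1, g2 = top2_vals[:, 0], top2_vals[:, 1]
    g2_norm = g2 / (g1 + g2)  # renormalized second gate, at most 0.5
    # Coin flip per token: always dispatched when g1 == g2, never when g2 == 0
    use_second = torch.rand_like(g2) < 2 * g2_norm
    return top2_idx[:, 0], top2_idx[:, 1], use_second


torch.manual_seed(0)
probs = torch.softmax(torch.randn(1000, 8), dim=-1)
first, second, use_second = gshard_random_top2(probs)
print(f"tokens that also use their second expert: {use_second.float().mean().item():.1%}")
```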

Local group dispatching. GShard splits the tokens of a batch into groups that are routed independently, each with its own fractional share of every expert's capacity. Routing decisions and capacity checks therefore stay local to a group instead of requiring coordination across the whole global batch.

Compiler-based optimization. GShard expressed the MoE computation as XLA operations and let the compiler handle partitioning and communication insertion, rather than manually coding the distributed logic.

Mixtral and Modern MoE

In December 2023, Mistral AI released Mixtral 8x7B, which brought MoE to the open-weight frontier and demonstrated that an MoE model could match dense models with several times more active parameters.

Mixtral 8x7B Architecture

Mixtral's naming is easy to misread: it has 8 experts on a backbone shaped like Mistral 7B, but only the FFN blocks are replicated, so the total model is not 56B parameters:

  • Shared parameters (embeddings, attention layers, norms, routers): ~1.6B
  • Expert parameters (8 expert FFNs in each of the 32 layers): ~45.1B
  • Total parameters: ~46.7B
  • Active parameters per token: ~12.9B (shared + 2 active experts per layer)

The model uses top-2 routing with 8 experts per MoE layer, replacing the FFN in every transformer layer (32 layers total). Each expert FFN uses SwiGLU activation:

$$E_i(x) = W_2^{(i)} \cdot \big(\text{SiLU}(W_1^{(i)} \cdot x) \odot W_3^{(i)} \cdot x\big)$$

where $\odot$ is element-wise multiplication and SiLU is the Sigmoid Linear Unit, $\text{SiLU}(z) = z \cdot \sigma(z)$.

Why 8x7B is Not 56B

This is a common source of confusion. Here's the accounting:

Per MoE layer:
  - Router: d_model × 8 parameters (tiny)
  - 8 experts: 8 × (3 × d_model × d_ff) (3 weight matrices per SwiGLU expert)
  - Active: 2 × (3 × d_model × d_ff) (only the top-2 experts run)

Per attention layer:
  - Q, K, V, O projections (always active, same as dense)
  - Layer norms (always active)

Result:
  - Total param count ≈ 46.7B (all experts exist)
  - Active param count ≈ 12.9B (only 2 of 8 experts compute per token)
  - Inference FLOPs ≈ equivalent to a ~13B dense model
  - Quality ≈ competitive with much larger dense models (see the benchmarks below)
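The accounting can be reproduced bottom-up. The Python sketch below assumes the shape of the publicly released Mixtral 8x7B configuration as we understand it (d_model 4096, FFN width 14336, 32 layers, 32 query heads and 8 key/value heads of size 128, a 32,000-token vocabulary, untied input and output embeddings); norms and routers are included but barely register:

```python
def mixtral_param_counts(d_model=4096, d_ff=14336, n_layers=32, n_heads=32,
                         n_kv_heads=8, vocab=32000, n_experts=8, top_k=2):
    """Approximate weight counts for a Mixtral-8x7B-shaped model: (shared, total, active)."""
    head_dim = d_model // n_heads
    attn = 2 * d_model * d_model + 2 * d_model * n_kv_heads * head_dim  # Q, O + K, V (GQA)
    expert = 3 * d_model * d_ff            # SwiGLU: three matrices per expert
    router = d_model * n_experts
    norms = 2 * d_model                    # two norm weight vectors per layer
    embed = 2 * vocab * d_model + d_model  # input embedding, LM head, final norm

    shared = n_layers * (attn + router + norms) + embed
    total = shared + n_layers * n_experts * expert
    active = shared + n_layers * top_k * expert
    return shared, total, active


shared, total, active = mixtral_param_counts()
print(f"shared: {shared / 1e9:.2f}B, total: {total / 1e9:.2f}B, active: {active / 1e9:.2f}B")
```

This recovers the ~46.7B total and ~12.9B active figures, and shows that the non-expert part is only about 1.6B parameters: almost all of the model lives in the experts.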

The magic: you get the knowledge capacity of a 47B model with the per-token compute of a 13B model (though not its memory footprint). The extra parameters store specialized knowledge that is conditionally accessed.

Mixtral Performance

On standard benchmarks, Mixtral 8x7B matched or exceeded LLaMA 2 70B (a dense model with 5.4x more active parameters) on most tasks:

  • MMLU: Mixtral 8x7B (70.6) vs LLaMA 2 70B (68.9)
  • HellaSwag: Mixtral (86.7) vs LLaMA 2 (87.3) - roughly tied
  • ARC-Challenge: Mixtral (66.4) vs LLaMA 2 (64.6)
  • Code (HumanEval): Mixtral (40.2) vs LLaMA 2 (29.9) - significant gap

This is the promise of MoE: dense-model quality at a fraction of the compute.

Expert Specialization in Mixtral

Analysis of Mixtral's routing patterns reveals interesting specialization. Unlike earlier theoretical expectations of clean domain separation, the pattern is more nuanced:

  • No expert is specialized in a single domain. Each expert handles a broad mix of inputs.
  • Specialization is syntactic, not semantic. Experts tend to specialize in token types (punctuation, operators, content words) rather than topics (science vs. law).
  • Consecutive tokens are often routed to the same experts. The Mixtral authors report this locality as more pronounced in deeper layers than in the first layer.
  • The first token of each sentence often routes differently from subsequent tokens, suggesting experts learn position-dependent patterns.

DeepSeek-MoE: Fine-Grained Experts

DeepSeek (2024) pushed MoE further with two innovations: fine-grained expert segmentation and shared experts.

Fine-Grained Experts

Instead of 8 large experts, DeepSeek-MoE uses 64 small experts with top-6 routing. The idea: smaller experts can specialize more narrowly, reducing redundancy. If each expert is smaller, you need to activate more of them (higher K), but the total activated parameters can still be less.

$$\text{Standard: } 8 \text{ experts, } K=2 \implies 25\% \text{ activated}$$

$$\text{DeepSeek: } 64 \text{ experts, } K=6 \implies 9.4\% \text{ activated}$$

A smaller fraction of the expert pool is activated per token, but there are far more distinct combinations of experts a token can be routed to. DeepSeek-MoE 16B (2.8B active) matched the performance of a dense 7B model - a 2.5x efficiency gain in active parameters.
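The claim about more specialization paths can be made concrete by counting the distinct expert subsets a token can be routed to, which the DeepSeekMoE paper frames as combinatorial flexibility. A small Python sketch using `math.comb`, ignoring the shared experts:

```python
from math import comb


def routing_stats(num_experts: int, k: int):
    """Fraction of routed experts active per token, and number of distinct expert subsets."""
    return k / num_experts, comb(num_experts, k)


for name, n, k in [("Standard", 8, 2), ("DeepSeek fine-grained", 64, 6)]:
    frac, combos = routing_stats(n, k)
    print(f"{name}: {frac:.1%} of routed experts active, {combos:,} possible expert sets")
```

Going from 8 experts with top-2 to 64 experts with top-6 shrinks the active fraction from 25% to about 9.4%, while the number of possible expert combinations grows from 28 to roughly 75 million.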

Shared Experts

DeepSeek added 2 shared experts that process every token, regardless of routing:

$$y = \sum_{s=1}^{N_s} E_s^{\text{shared}}(x) + \sum_{i \in \text{TopK}} G(x)_i \cdot E_i^{\text{routed}}(x)$$

The shared experts capture universal patterns (basic language structure), freeing the routed experts to focus purely on specialized knowledge. This reduced expert redundancy and improved performance.
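The shared-plus-routed formula above can be sketched as a small PyTorch module. This is an illustration under stated assumptions (a plain SiLU MLP per expert, softmax over the routed experts followed by top-K without renormalization, a Python loop over experts), not DeepSeek's implementation; the class and helper names are hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def small_ffn(d_model: int, d_expert: int) -> nn.Module:
    return nn.Sequential(nn.Linear(d_model, d_expert), nn.SiLU(), nn.Linear(d_expert, d_model))


class SharedRoutedMoE(nn.Module):
    """y = sum_s E_s^shared(x) + sum_{i in TopK} G(x)_i * E_i^routed(x), x: (T, d_model)."""

    def __init__(self, d_model, d_expert, n_shared=2, n_routed=64, k=6):
        super().__init__()
        self.shared = nn.ModuleList(small_ffn(d_model, d_expert) for _ in range(n_shared))
        self.routed = nn.ModuleList(small_ffn(d_model, d_expert) for _ in range(n_routed))
        self.gate = nn.Linear(d_model, n_routed, bias=False)
        self.k = k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shared experts: every token, no gating
        y = sum(expert(x) for expert in self.shared)
        probs = F.softmax(self.gate(x), dim=-1)
        top_vals, top_idx = probs.topk(self.k, dim=-1)  # (T, k)
        # Routed experts: each one only sees the tokens that selected it
        for e in top_idx.unique().tolist():
            rows, slots = (top_idx == e).nonzero(as_tuple=True)
            contribution = top_vals[rows, slots].unsqueeze(-1) * self.routed[e](x[rows])
            y = y.index_add(0, rows, contribution)
        return y


torch.manual_seed(0)
moe = SharedRoutedMoE(d_model=16, d_expert=8)
x = torch.randn(5, 16)
y = moe(x)
```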

Training Challenges

Training MoE models is significantly harder than training dense models. Here are the major challenges and their solutions.

1. Training Instability

MoE models are notoriously unstable during training. The router creates a discrete, data-dependent computation graph that changes every step. Small perturbations in the router weights can cause large shifts in which experts see which tokens.

Solutions:

  • Router z-loss (Zoph et al., 2022): An additional penalty on the magnitude of router logits before softmax, preventing them from becoming too large and causing numerical issues:

$$\mathcal{L}_z = \frac{1}{B} \sum_{x} \left(\log \sum_{i=1}^{N} e^{z_i(x)}\right)^2$$

where $B$ is the number of tokens in the batch and $z_i(x)$ are the raw router logits. This keeps the logits in a numerically stable range.

  • Selective precision (mentioned above): float32 for router, bfloat16 for experts
  • Lower learning rate for router: Some practitioners use a 10x smaller learning rate for the gating weights compared to expert weights
  • Gradient clipping: Aggressive gradient clipping (max norm 0.3-1.0) prevents individual steps from destabilizing routing
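The router z-loss is a one-liner with `torch.logsumexp`, which evaluates $\log \sum_i e^{z_i}$ without overflowing for large logits. A minimal sketch (the function name is illustrative):

```python
import torch


def router_z_loss(logits: torch.Tensor) -> torch.Tensor:
    """logits: (B, N) raw router logits. Mean over tokens of (logsumexp_i z_i)^2."""
    return torch.logsumexp(logits.float(), dim=-1).pow(2).mean()


small = router_z_loss(torch.zeros(4, 8))          # (log 8)^2, about 4.32
large = router_z_loss(torch.full((4, 8), 10.0))   # (10 + log 8)^2, about 145.9
```

Shifting every logit up by 10 leaves the softmax unchanged but inflates the z-loss, which is exactly the drift toward large logits the penalty is meant to discourage. In practice it is added to the total loss with a small coefficient, such as 1e-3.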

2. Expert Collapse

As described earlier, experts collapse when the router learns to always select the same few experts.

Solutions:

  • Auxiliary load balancing loss (see above)
  • Random routing noise during training (Shazeer et al., 2017)
  • Expert dropout: Randomly disable experts during training to force the router to use alternatives
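  • Batch-prioritized routing (Riquelme et al., 2021): When an expert's capacity is exceeded, fill its slots in order of routing confidence rather than sequence position, so the tokens the router is least sure about are the first to be dropped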
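A minimal PyTorch sketch of batch-prioritized routing for top-1 routing. Using the top-1 probability as the priority score is an assumption made for this sketch:

```python
import math

import torch


def batch_prioritized_keep(probs: torch.Tensor, capacity_factor: float = 1.0) -> torch.Tensor:
    """Top-1 routing with capacity, filling expert slots in order of router confidence.
    probs: (T, N) router probabilities. Returns a bool keep-mask of shape (T,)."""
    num_tokens, num_experts = probs.shape
    capacity = math.floor(capacity_factor * num_tokens / num_experts)
    priority, expert_idx = probs.max(dim=-1)
    keep = torch.zeros(num_tokens, dtype=torch.bool)
    used = [0] * num_experts
    # Most confident tokens claim slots first; the least confident are dropped first
    for t in priority.argsort(descending=True).tolist():
        e = int(expert_idx[t])
        if used[e] < capacity:
            used[e] += 1
            keep[t] = True
    return keep


# Four tokens all prefer expert 0, which only has room for two (capacity = 4 / 2 = 2)
probs = torch.tensor([[0.6, 0.4], [0.9, 0.1], [0.7, 0.3], [0.8, 0.2]])
keep = batch_prioritized_keep(probs)
```

Tokens 1 and 3 (priorities 0.9 and 0.8) keep their slots; plain first-come, first-served assignment would have kept tokens 0 and 1 instead.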

3. Memory Overhead

While MoE reduces compute, it doesn't reduce memory. All expert parameters must be stored, even though most are idle for any given token:

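```python
# Memory comparison (weights only, float16 = 2 bytes per parameter)
dense_13b_memory_gb = 12.9e9 * 2 / 1e9  # ~25.8 GB: dense model with the same active size
moe_47b_memory_gb = 46.7e9 * 2 / 1e9    # ~93.4 GB: every expert must be resident

# Compute comparison (forward FLOPs per token ~ 2 x active params)
dense_47b_flops = 2 * 46.7e9   # dense model with the same total size
moe_active_flops = 2 * 12.9e9  # MoE: only shared weights + 2 experts per layer run

# The paradox: the MoE computes like a ~13B dense model but is stored like a ~47B one
print(f"MoE memory vs 13B dense: {moe_47b_memory_gb / dense_13b_memory_gb:.1f}x")
print(f"MoE FLOPs vs 47B dense:  {moe_active_flops / dense_47b_flops:.2f}x")
```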

This is the MoE memory paradox: you save FLOPs but pay in memory. Solutions include expert offloading (keeping inactive experts on CPU or disk), expert pruning (removing underutilized experts after training), and quantization (since most of an MoE model's memory sits in expert weights, compressing them pays off directly, and expert weights have been reported to tolerate aggressive quantization).

4. Fine-tuning Difficulty

MoE models can be harder to fine-tune than dense models because:

  • The router may not adapt well to a new domain's token distribution
  • Fine-tuning data is typically much smaller, so many experts see very few examples
  • Load balancing dynamics change with the new data distribution

Solutions:

  • Freeze the router during fine-tuning and only update expert weights
  • Use a higher auxiliary loss coefficient during fine-tuning
  • Fine-tune only the top-K most utilized experts for the target domain
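The first and third strategies come down to controlling which parameters receive updates. A minimal PyTorch sketch of freezing the router, using a hypothetical `TinyMoE` module whose attribute names (`router`, `experts`) are illustrative rather than any library's API:

```python
import torch
import torch.nn as nn


class TinyMoE(nn.Module):
    """Hypothetical stand-in for an MoE model with a `router` and `experts`."""

    def __init__(self, d_model: int = 16, num_experts: int = 4):
        super().__init__()
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_experts))


def freeze_routers(model: nn.Module) -> list:
    """Disable gradients for every parameter whose name contains 'router'."""
    for name, param in model.named_parameters():
        if "router" in name:
            param.requires_grad_(False)
    # Hand only the still-trainable (expert) parameters to the optimizer
    return [p for p in model.parameters() if p.requires_grad]


model = TinyMoE()
optimizer = torch.optim.AdamW(freeze_routers(model), lr=1e-5)
```

Matching on the parameter name lets the same helper freeze the router in every layer of a deeper model, such as the `router` attribute of the `MoELayer` built later in this post.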

PyTorch Implementation

Let's build a complete MoE layer from scratch. We'll implement the router, experts, top-K selection, and load balancing loss.

The Expert FFN

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExpertFFN(nn.Module):
    """A single expert: standard FFN with SwiGLU activation."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: W2 * (SiLU(W1 * x) ⊙ W3 * x)
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```

The Router

```python
class TopKRouter(nn.Module):
    """Learned router with top-K selection and load balancing."""

    def __init__(
        self,
        d_model: int,
        num_experts: int,
        top_k: int = 2,
        noise_std: float = 0.1,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.noise_std = noise_std
        self.gate = nn.Linear(d_model, num_experts, bias=False)

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            top_k_weights: (batch * seq_len, top_k) - renormalized gate weights
            top_k_indices: (batch * seq_len, top_k) - expert indices
            aux_loss: scalar - load balancing loss
        """
        batch, seq_len, d_model = x.shape
        x_flat = x.reshape(-1, d_model)  # (B*S, d_model)

        # Gate logits in float32 for stability; casting the weight too keeps this
        # working when the rest of the model runs in bfloat16
        logits = F.linear(x_flat.float(), self.gate.weight.float())  # (B*S, E)

        # Fixed-scale Gaussian noise during training for exploration
        if self.training and self.noise_std > 0:
            logits = logits + torch.randn_like(logits) * self.noise_std

        # Softmax to get routing probabilities
        probs = F.softmax(logits, dim=-1)  # (B*S, E)

        # Top-K selection, then renormalize the selected weights to sum to 1
        top_k_weights, top_k_indices = torch.topk(probs, self.top_k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
        top_k_weights = top_k_weights.to(x.dtype)  # back to the input dtype

        # ---- Load balancing auxiliary loss ----
        num_tokens = x_flat.shape[0]
        # f_i: fraction of tokens that picked expert i in any of their top-K slots.
        # Each row of expert_mask is 0/1 and sums to top_k, so sum_i f_i = top_k.
        expert_mask = F.one_hot(top_k_indices, self.num_experts).float()  # (B*S, K, E)
        expert_mask = expert_mask.sum(dim=1)                               # (B*S, E)
        f = expert_mask.sum(dim=0) / num_tokens                            # (E,)
        # P_i: average probability assigned to each expert
        P = probs.mean(dim=0)                                              # (E,)
        # Auxiliary loss: N * sum(f_i * P_i); about top_k when routing is balanced
        aux_loss = self.num_experts * (f * P).sum()

        return top_k_weights, top_k_indices, aux_loss
```

The Full MoE Layer

```python
class MoELayer(nn.Module):
    """
    Mixture of Experts layer that replaces a standard FFN
    in a Transformer block.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2,
        capacity_factor: float = 1.25,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor

        # Create expert FFNs
        self.experts = nn.ModuleList(
            [ExpertFFN(d_model, d_ff) for _ in range(num_experts)]
        )
        # Router
        self.router = TopKRouter(d_model, num_experts, top_k=top_k)

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            output: (batch, seq_len, d_model)
            aux_loss: scalar
        """
        batch, seq_len, d_model = x.shape
        num_tokens = batch * seq_len
        x_flat = x.reshape(num_tokens, d_model)

        # Route tokens to experts
        weights, indices, aux_loss = self.router(x)
        # weights, indices: (num_tokens, top_k)

        # Each token makes top_k assignments, so capacity is scaled by top_k
        capacity = int(
            self.capacity_factor * self.top_k * num_tokens / self.num_experts
        )

        output = torch.zeros_like(x_flat)
        for expert_idx, expert in enumerate(self.experts):
            # (token, slot) pairs routed to this expert across all top-K slots
            token_indices, slot_indices = (indices == expert_idx).nonzero(as_tuple=True)
            if token_indices.numel() == 0:
                continue

            # Apply the capacity constraint once per expert; overflow tokens are
            # dropped and only reach the next layer through the residual connection
            token_indices = token_indices[:capacity]
            slot_indices = slot_indices[:capacity]

            expert_output = expert(x_flat[token_indices])                      # (n, d_model)
            gate_weights = weights[token_indices, slot_indices].unsqueeze(-1)  # (n, 1)
            # A token picks a given expert at most once, so token_indices are unique
            output.index_add_(0, token_indices, gate_weights * expert_output)

        return output.view(batch, seq_len, d_model), aux_loss
```

Putting It in a Transformer Block

```python
class MoETransformerBlock(nn.Module):
    """A single transformer block with MoE replacing the FFN."""

    def __init__(
        self,
        d_model: int = 768,
        n_heads: int = 12,
        d_ff: int = 2048,
        num_experts: int = 8,
        top_k: int = 2,
    ):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.moe = MoELayer(
            d_model,
            d_ff,
            num_experts=num_experts,
            top_k=top_k,
        )

    def forward(
        self, x: torch.Tensor, mask: torch.Tensor = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self-attention with residual
        h = self.ln1(x)
        h, _ = self.attn(h, h, h, attn_mask=mask)
        x = x + h

        # MoE FFN with residual
        h = self.ln2(x)
        h, aux_loss = self.moe(h)
        x = x + h

        return x, aux_loss
```

Training Loop Integration

The key difference from a standard training loop is incorporating the auxiliary loss:

```python
def train_step(model, batch, optimizer, aux_loss_coeff=0.01):
    """Training step for an MoE model.

    Assumes model(tokens) returns (logits, total_aux_loss), where total_aux_loss
    sums the auxiliary losses of all MoE blocks.
    """
    optimizer.zero_grad()
    tokens, targets = batch
    logits, total_aux_loss = model(tokens)

    # Standard language modeling loss
    lm_loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        targets.view(-1),
    )

    # Combined loss: primary task + load balancing
    total_loss = lm_loss + aux_loss_coeff * total_aux_loss
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    return {
        "lm_loss": lm_loss.item(),
        "aux_loss": total_aux_loss.item(),
        "total_loss": total_loss.item(),
    }
```

Testing Our Implementation

```python
# Quick sanity check
d_model = 512
d_ff = 1024
num_experts = 8
top_k = 2
batch_size = 4
seq_len = 32

moe_block = MoETransformerBlock(
    d_model=d_model,
    n_heads=8,
    d_ff=d_ff,
    num_experts=num_experts,
    top_k=top_k,
)

x = torch.randn(batch_size, seq_len, d_model)
output, aux_loss = moe_block(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Aux loss: {aux_loss.item():.4f}")

# Count parameters
total_params = sum(p.numel() for p in moe_block.parameters())
expert_params = sum(
    p.numel() for expert in moe_block.moe.experts for p in expert.parameters()
)
# Active = everything outside the experts + top_k experts' worth of weights
active_params = total_params - expert_params + (expert_params // num_experts) * top_k

print(f"\nTotal parameters: {total_params:,}")
print(f"Expert parameters: {expert_params:,}")
print(f"Active per token: {active_params:,}")
print(f"Sparsity ratio: {1 - active_params / total_params:.1%}")
```

Expected output:

Input shape: torch.Size([4, 32, 512])
Output shape: torch.Size([4, 32, 512])
Aux loss: ~2.0 (the exact value depends on the random initialization)

Total parameters: 13,639,680
Expert parameters: 12,582,912
Active per token: 4,202,496
Sparsity ratio: 69.2%

Efficient MoE: Implementation Tricks

The naive loop-over-experts implementation above works for understanding, but is impractical for real systems. Here are the key optimizations:

Batched Expert Computation

Instead of looping over experts sequentially, group all tokens by their assigned expert and use a single batched matrix multiplication per expert:

```python
import torch


def efficient_moe_forward(x_flat, weights, indices, experts, num_experts, top_k):
    """Sort-based MoE forward: each expert runs once on a contiguous batch.

    x_flat:  (num_tokens, d_model) input tokens
    weights: (num_tokens, top_k) gating weights
    indices: (num_tokens, top_k) selected expert ids
    """
    num_tokens, d_model = x_flat.shape

    # Flatten top-K assignments: each token appears K times
    flat_indices = indices.view(-1)                       # (num_tokens * top_k,)
    flat_weights = weights.view(-1, 1)                    # (num_tokens * top_k, 1)
    flat_tokens = x_flat.repeat_interleave(top_k, dim=0)  # (num_tokens * top_k, d_model)

    # Sort by expert index so each expert's tokens are contiguous in memory
    sort_order = flat_indices.argsort()
    sorted_indices = flat_indices[sort_order]
    sorted_tokens = flat_tokens[sort_order]
    sorted_weights = flat_weights[sort_order]

    # Find boundaries for each expert's batch
    expert_counts = torch.bincount(sorted_indices, minlength=num_experts)
    expert_offsets = torch.cumsum(expert_counts, dim=0)
    expert_offsets = torch.cat(
        [torch.zeros(1, dtype=torch.long, device=x_flat.device), expert_offsets]
    )

    # Run each expert on its contiguous slice (still a Python loop over experts)
    output_sorted = torch.zeros_like(sorted_tokens)
    for i, expert in enumerate(experts):
        start = expert_offsets[i].item()
        end = expert_offsets[i + 1].item()
        if start < end:
            output_sorted[start:end] = expert(sorted_tokens[start:end])

    # Apply gating weights, then undo the sort
    output_sorted = output_sorted * sorted_weights
    unsort_order = sort_order.argsort()
    output_flat = output_sorted[unsort_order]

    # Reduce across the top-K dimension
    output = output_flat.view(num_tokens, top_k, d_model).sum(dim=1)
    return output
```
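To see what the sort, `bincount`, and offset bookkeeping produces, here is a tiny trace on hand-picked routing for three tokens with `top_k = 2` and four experts (the routing values are illustrative, not from a real model):

```python
import torch

# Toy routing result: 3 tokens, top_k = 2, 4 experts (illustrative values).
indices = torch.tensor([[0, 2], [1, 3], [0, 1]])
top_k, num_experts = indices.shape[1], 4

flat_indices = indices.reshape(-1)  # [0, 2, 1, 3, 0, 1]
# stable=True only makes the tie order deterministic for this printout.
sorted_indices, sort_order = torch.sort(flat_indices, stable=True)

expert_counts = torch.bincount(flat_indices, minlength=num_experts)
expert_offsets = torch.cat([torch.zeros(1, dtype=torch.long),
                            torch.cumsum(expert_counts, dim=0)])

# Slot j of the flattened list came from token j // top_k.
source_token = sort_order // top_k
# argsort of a permutation is its inverse, which restores the original order.
unsort_order = sort_order.argsort()

print("sorted experts:", sorted_indices.tolist())
print("source tokens: ", source_token.tolist())
print("expert offsets:", expert_offsets.tolist())
for e in range(num_experts):
    start, end = expert_offsets[e].item(), expert_offsets[e + 1].item()
    print(f"expert {e}: slots {start}:{end} -> tokens {source_token[start:end].tolist()}")
```

Expert 0 receives tokens 0 and 2 as one contiguous slice, expert 1 receives tokens 1 and 2, and so on. The stable sort is only for a reproducible trace: `efficient_moe_forward` stays correct with an unstable `argsort` because it inverts exactly the permutation it applied.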

MegaBlocks: Block-Sparse Matrix Approach

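Libraries such as MegaBlocks (Gale et al., 2023) go a step further and express the expert computation as block-sparse matrix multiplication. Tokens are still permuted so that each expert's assignments are contiguous, but instead of padding every expert to a fixed capacity (or dropping tokens that overflow it), the expert FFNs are computed with block-sparse kernels whose block structure follows however many tokens each expert actually received. The routing decision itself can be pictured as a sparse token-to-expert matrix: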

```
Tokens:  [t0, t1, t2, t3, t4, t5]
Experts: [E0, E1, E2, E3]

Routing:
  t0 → E0, E2
  t1 → E1, E3
  t2 → E0, E1
  t3 → E2, E3
  t4 → E0, E2
  t5 → E1, E3

Sparse dispatch matrix G (non-zero = gating weight):

        E0   E1   E2   E3
  t0 [ 0.6  0    0.4  0   ]
  t1 [ 0    0.5  0    0.5 ]
  t2 [ 0.7  0.3  0    0   ]
  t3 [ 0    0    0.4  0.6 ]
  t4 [ 0.3  0    0.7  0   ]
  t5 [ 0    0.6  0    0.4 ]

output[t] = Σ_e G[t, e] · E_e(x_t)   (zero entries need no expert computation)
```

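Because these kernels accept variable-sized expert batches directly, the approach removes the per-expert Python loop and avoids both padding experts to a fixed capacity and dropping tokens that overflow it.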
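To make the dispatch matrix concrete, here is a minimal sketch that builds it from top-K routing output (using the diagram's routing) and checks that combining expert outputs through it equals looping over each token's selected experts. It is not MegaBlocks' block-sparse kernel: the toy experts are hypothetical random linear maps, and every expert is evaluated densely purely to verify the math.

```python
import torch

# Routing from the diagram: top-2 expert ids and gate weights for t0..t5.
indices = torch.tensor([[0, 2], [1, 3], [0, 1], [2, 3], [0, 2], [1, 3]])
weights = torch.tensor([[0.6, 0.4], [0.5, 0.5], [0.7, 0.3],
                        [0.4, 0.6], [0.3, 0.7], [0.6, 0.4]])
num_tokens, num_experts, d_model = indices.shape[0], 4, 8

# Scatter each token's gate weights into its experts' columns; the rest stay 0.
dispatch = torch.zeros(num_tokens, num_experts).scatter(1, indices, weights)
print(dispatch)

# Toy experts (hypothetical): one random linear map per expert.
torch.manual_seed(0)
x = torch.randn(num_tokens, d_model)
expert_w = torch.randn(num_experts, d_model, d_model)

# Combine: output[t] = sum_e dispatch[t, e] * expert_e(x[t]).
# Evaluated densely for checking only; a sparse kernel skips the zero entries.
expert_out = torch.einsum("td,edh->teh", x, expert_w)  # (tokens, experts, d_model)
output = torch.einsum("te,teh->th", dispatch, expert_out)

# Reference: loop over each token's two selected experts.
reference = torch.stack([
    sum(w * (x[t] @ expert_w[e]) for e, w in zip(indices[t].tolist(), weights[t]))
    for t in range(num_tokens)
])
print(torch.allclose(output, reference, atol=1e-4))  # True
```

Each row of `dispatch` has exactly two non-zero entries summing to 1, which is the sparsity a block-sparse kernel exploits.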

The Broader Landscape

MoE has become a dominant architecture pattern. Here's a summary of the major models:

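| Year | Model | Experts | Top-K | Total Params | Active Params |
|------|-------|---------|-------|--------------|---------------|
| 2017 | Shazeer et al. | 2048 | 2 | 137B | ~1B |
| 2020 | GShard | 2048 | 2 | 600B | ~1.5B |
| 2021 | Switch Transformer | 128 | 1 | 395B | 3.1B |
| 2021 | GLaM | 64 | 2 | 1.2T | 97B |
| 2023 | Mixtral 8x7B | 8 | 2 | 46.7B | 12.9B |
| 2024 | DeepSeek-MoE | 64+2 | 6 | 16.4B | 2.8B |
| 2024 | Mixtral 8x22B | 8 | 2 | 141B | 39B |
| 2024 | DBRX | 16 | 4 | 132B | 36B |
| 2024 | Grok-1 | 8 | 2 | 314B | ~80B |
| 2024 | DeepSeek-V3 | 256+1 | 8 | 671B | 37B |

For the DeepSeek models, "N+M" means N routed experts plus M shared experts that process every token; Top-K counts routed experts only.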

The trend is clear: more experts, finer granularity, and increasingly sophisticated routing.
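To put numbers on that trend, a small sketch computing the fraction of parameters each model activates per token, using a subset of rows with the total and active counts copied from the table (in billions):

```python
# (total params, active params) in billions, copied from the table above.
models = {
    "GLaM": (1200, 97),
    "Mixtral 8x7B": (46.7, 12.9),
    "DeepSeek-MoE": (16.4, 2.8),
    "Mixtral 8x22B": (141, 39),
    "DBRX": (132, 36),
    "DeepSeek-V3": (671, 37),
}

fractions = {name: active / total for name, (total, active) in models.items()}
for name, frac in fractions.items():
    print(f"{name:<14} {frac:6.1%} of parameters active per token")
```

The ratio ranges from roughly 27 to 28% for Mixtral and DBRX down to about 5.5% for DeepSeek-V3, the model with the most (and smallest) experts. Mixtral 8x7B's total is also well below 8 × 7B = 56B because only the FFN layers are replicated per expert; attention weights are shared.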

Open Problems

MoE is far from a solved problem. Active research areas include:

1. Expert Merging. Can we train an MoE model and then merge experts back into a dense model for more efficient inference? Early results show this is possible with some quality loss.

2. Dynamic K. Instead of fixed top-K routing, can the router dynamically choose how many experts each token needs? Simple tokens might need $K=1$; complex tokens might need $K=4$.

3. Hierarchical Routing. Instead of flat routing to a list of experts, route through a tree: first select a cluster of experts, then select within the cluster. This reduces the router's decision space.

4. Multi-Modal Experts. In multi-modal models, should different modalities (text, images, audio) have dedicated experts? Initial evidence suggests that cross-modal expert sharing can be beneficial.

5. Continual Learning. MoE's modular structure makes it a natural fit for continual learning: new knowledge could be added by adding new experts without retraining existing ones. This is largely unexplored.
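One way the Dynamic K question could be approached is threshold-based (top-p) routing: for each token, keep the fewest experts whose router probabilities add up to at least a threshold `p`, using router confidence as a stand-in for token difficulty. This is a hedged sketch; `top_p_route`, `p`, and `max_k` are hypothetical names and values, not something this post defines:

```python
import torch


def top_p_route(router_logits: torch.Tensor, p: float = 0.6, max_k: int = 4):
    """Hypothetical dynamic-K router (top-p style).

    For each token, keep the fewest experts whose router probabilities sum to
    at least p, capped at max_k. Returns a boolean mask of selected experts and
    renormalized gate weights, both shaped (num_tokens, num_experts).
    """
    probs = router_logits.softmax(dim=-1)
    sorted_probs, order = probs.sort(dim=-1, descending=True)
    # Keep an expert while the mass of higher-ranked experts is still below p,
    # so the top-1 expert is always kept.
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    keep_sorted = (mass_before < p).float()
    keep_sorted[:, max_k:] = 0.0
    # Map keep decisions from sorted positions back to expert ids.
    mask = torch.zeros_like(probs).scatter(1, order, keep_sorted)
    gates = probs * mask
    gates = gates / gates.sum(dim=-1, keepdim=True)
    return mask.bool(), gates


logits = torch.tensor([
    [4.0, 0.0, 0.0, 0.0],  # confident router: one expert covers the mass
    [1.0, 1.0, 1.0, 1.0],  # uncertain router: several experts are needed
])
mask, gates = top_p_route(logits, p=0.6)
print(mask.sum(dim=-1).tolist())  # [1, 3]
```

A confident router selects one expert, while a flat distribution spreads the token over three. Because the number of active experts now varies per token, per-token compute becomes variable, and capacity planning and load balancing have to account for it.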

Key Takeaways

  1. MoE decouples model capacity from compute cost. Total parameters determine knowledge storage; active parameters determine inference speed. These can be scaled independently.

  2. The router is the critical component. Everything depends on learning a good token-to-expert assignment. Load balancing is essential to prevent collapse.

  3. MoE modifies only the FFN. Attention layers remain dense. The modification is surgical and composable with other architectural improvements.

  4. The trade-off is memory. MoE saves FLOPs but requires storing all expert parameters in memory. Expert parallelism across multiple devices is the standard solution.

  5. Modern MoE works. Mixtral 8x7B showed that an open-weight MoE model with 12.9B active parameters can match or outperform Llama 2 70B, a dense model with roughly 5x more active parameters. DeepSeek showed that fine-grained experts push this ratio further.
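Takeaways 1 and 4 can be made concrete with the Mixtral 8x7B row of the table. A rough sketch, assuming bf16 weights (2 bytes per parameter) and the approximation of 2 FLOPs per active parameter per token used earlier in this post, and ignoring the KV cache, activations, and attention-score FLOPs:

```python
# Memory vs. compute for Mixtral 8x7B, using the table's parameter counts.
# Assumptions: bf16 weights (2 bytes/param), ~2 FLOPs per active parameter
# per token; KV cache, activations, and attention-score FLOPs are ignored.
BYTES_PER_PARAM = 2
total_params = 46.7e9
active_params = 12.9e9

weight_memory_gb = total_params * BYTES_PER_PARAM / 1e9
flops_per_token = 2 * active_params
dense_equiv_memory_gb = active_params * BYTES_PER_PARAM / 1e9

print(f"Weights in memory: {weight_memory_gb:.1f} GB")
print(f"FLOPs per token:   {flops_per_token / 1e9:.1f} GFLOPs")
print(f"Dense model with the same FLOPs/token: {dense_equiv_memory_gb:.1f} GB of weights")
```

Per-token compute tracks the 12.9B active parameters, while the weight memory tracks all 46.7B, roughly 3.6x more than a dense model with the same per-token cost. That gap is what expert parallelism is meant to absorb.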

The Mixture of Experts architecture represents a fundamental shift in how we think about scaling. Instead of the brute-force approach of making every parameter participate in every computation, MoE introduces a principle from biological neural networks: specialization. Not every neuron fires for every stimulus. Not every expert processes every token. And that selective activation is what makes truly massive models practically deployable.


This post is part of the deep learning fundamentals series. Previous posts cover neural network building blocks, transformers, KV caching, flash attention, language modeling & RNNs, optimizers, and regularization.