The Personal Site of Lalo Morales


Multi-Head Latent Attention and Mixture of Experts (MoE) in Transformers

Multi Head Latent Attention

Introduction

Transformers have revolutionized natural language processing (NLP) and have increasingly found applications in domains such as vision, speech, and reinforcement learning. Their hallmark feature is self-attention, which learns to focus on relevant parts of a sequence to generate contextually rich representations. However, as we push for ever-larger models, we encounter computational bottlenecks and difficulties in training. Two key innovations that address these challenges are:

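  1. Multi-Head Attention – the cornerstone of Transformer architectures, enabling the model to capture multiple “aspects” or “sub-spaces” of relationships between tokens. (Multi-Head Latent Attention, the variant named in the title, additionally compresses keys and values into a smaller latent representation; this post walks through the standard multi-head mechanism it builds on.)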
  2. Mixture of Experts (MoE) – a strategy for modularizing computation across multiple specialized sub-networks (experts), so that each token is routed only to the experts most relevant to it.

In this blog post, we will break down these concepts in detail, discussing:

  • The step-by-step flow of Multi-Head Attention (and why positional encodings matter).
  • How a Mixture of Experts mechanism augments or replaces the typical Feed-Forward Network in a Transformer.
  • The typical tensor shapes and data flows for both components.
  • Practical considerations and glossary terms that unify these ideas.

By the end, you should have a solid conceptual grasp of these two major building blocks and how they synergize in modern large-scale Transformers.


Part I: Multi-Head Latent Attention

1. Why Attention?

Traditional sequence models (like vanilla RNNs) struggle with long-range dependencies because each new step in a sequence processes hidden states that are sequentially updated over time. Attention mechanisms address this by allowing each token to “look back” at all other tokens in a single step. This short-circuits the difficulty of long-distance relationships—if the model needs context from a token far away in the sequence, it can directly attend to that token instead of only passing information forward through multiple time steps.

2. Multi-Head Attention Overview

Multi-head attention is an extension of basic attention. It performs multiple “attention” operations in parallel, each with its own learnable projection of the input. Concretely:

  • You begin with an input tensor $X$ of shape $(B, T, d)$ representing a batch of sequences, each of length $T$, with an embedding dimension $d$.
  • You create three distinct “views” (or transformations) of this input, called the Query (Q), Key (K), and Value (V) vectors. These are typically computed by multiplying $X$ by different learned weight matrices.
  • These Q, K, and V tensors are then split into multiple “heads.” For example, if your embedding dimension $d$ is 512 and you have 8 heads, each head is 64 units wide. This means each head processes a 64-dimensional subspace of Q, K, and V.
  • Each head performs a scaled dot-product attention: it computes $\text{softmax}(Q K^T / \sqrt{d_h}) V$, where $d_h$ is the dimension of each head (64 in our example). Dividing by $\sqrt{d_h}$ helps stabilize gradients and keep values in a manageable range.
  • The results from all heads are then concatenated back together along the last dimension, restoring the shape to $(B, T, d)$. A final learned linear projection is applied to combine the multiple heads into a single representation.
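
To make the per-head computation concrete, here is a minimal NumPy sketch of scaled dot-product attention for a single head. The function names, the random inputs, and the sizes (5 tokens, 64 features) are illustrative assumptions, not taken from any particular library.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v):
    """q, k, v: arrays of shape (T, d_h) for a single head."""
    d_h = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_h)     # (T, T): every query against every key
    weights = softmax(scores, axis=-1)  # each row is a distribution over tokens
    return weights @ v, weights         # (T, d_h), (T, T)

rng = np.random.default_rng(0)
T, d_h = 5, 64
q, k, v = (rng.standard_normal((T, d_h)) for _ in range(3))
out, weights = scaled_dot_product_attention(q, k, v)
print(out.shape, weights.shape)  # (5, 64) (5, 5)
```

Each row of `weights` sums to 1, so every output vector is a weighted average of the value vectors.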

3. Why Multiple Heads?

A single head might focus heavily on one aspect of the sequence relationships—for example, it might lock onto subject-verb agreement or focus on distant tokens related by a certain linguistic pattern. By using multiple heads, the model can learn different “attention patterns” in parallel. One head might capture syntactic structure, another might capture co-references, another might track temporal connections, and so on.

4. Positions Matter: Positional Encodings

In a vanilla attention setup, if you only have the embeddings of tokens, the model lacks direct information about which token is at position 1, 2, 3, etc. Transformers solve this with positional encodings that inject some notion of “where in the sequence” each token is located. Common approaches include:

  • Fixed absolute positional encodings – a sinusoidal pattern, determined by each token's position, added to each token embedding.
  • Learned positional embeddings – a trainable vector for each possible position.
  • Rotary Positional Embeddings (RoPE) – a technique where you multiply queries and keys by a rotation matrix that depends on position.
  • Relative positional encodings – where attention is computed using offsets between positions rather than absolute positions.

In the diagram that references “R(n)P(K)” and “R(n)P(Q),” you may see these as operations that insert positional knowledge into your Key and Query tensors. The reason you often add positional encodings to Keys and Queries (rather than Values) is that the alignment function (i.e., the dot product between Q and K) relies on positional information to properly weigh tokens’ similarities.
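
As a rough illustration of how position can be injected into Queries and Keys, here is a small NumPy sketch of rotary positional embeddings. The pairing of feature `i` with feature `i + d/2` and the base of 10000 are assumptions chosen for this example; real implementations differ in how they pair features. The check at the end shows the property that makes the idea attractive: the query-key score depends only on the offset between the two positions.

```python
import numpy as np

def rope(x, positions, base=10000.0):
    """Rotate feature pairs of x (T, d) by position-dependent angles; d must be even."""
    half = x.shape[-1] // 2
    freqs = base ** (-np.arange(half) / half)      # one rotation frequency per feature pair
    angles = positions[:, None] * freqs[None, :]   # (T, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]              # pair feature i with feature i + half
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

rng = np.random.default_rng(0)
q = rng.standard_normal((1, 8))
k = rng.standard_normal((1, 8))

def score(m, n):
    # Dot product between q placed at position m and k placed at position n.
    return (rope(q, np.array([m])) @ rope(k, np.array([n])).T).item()

# Both pairs of positions are 2 apart, so the scores agree.
print(np.isclose(score(3, 1), score(10, 8)))
```

Because each pair of features is only rotated, the length of the vector is unchanged; only its orientation encodes position.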

5. Typical Tensor Shapes and Flow

Let’s clarify a typical scenario for multi-head attention:

  1. Input $(B, T, d)$
  2. Linear transforms:
    • $Q = XW_Q$, $K = XW_K$, $V = XW_V$, each shaped $(B, T, d)$
  3. Reshape into heads:
    • For 8 heads, each head has dimension $d_h = d / 8$.
    • So, $Q$ becomes $(B, T, 8, d_h)$, then rearranged to $(B, 8, T, d_h)$ (and likewise for $K$ and $V$).
  4. Compute scaled dot-product:
    • $\text{scores} = (QK^T) / \sqrt{d_h}$, yielding shape $(B, 8, T, T)$
    • $\text{attn\_weights} = \text{softmax}(\text{scores})$, also $(B, 8, T, T)$
    • $\text{context} = \text{attn\_weights} \times V$, giving $(B, 8, T, d_h)$
  5. Concatenate heads:
    • $(B, 8, T, d_h) \to (B, T, 8 \times d_h) = (B, T, d)$
  6. Final projection:
    • Pass through a linear layer $(B, T, d) \to (B, T, d)$.

This process is repeated in each Transformer layer, often with additional details like masking (for instance, preventing a token from attending to future tokens) in tasks like language modeling.
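
The shape flow above, including the optional causal mask, can be sketched in NumPy as follows. This is an illustrative sketch under simple assumptions (no biases, no dropout, random weights); the function and argument names are made up for this example.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads, causal=False):
    B, T, d = x.shape
    d_h = d // num_heads

    def split_heads(t):
        # (B, T, d) -> (B, T, H, d_h) -> (B, H, T, d_h)
        return t.reshape(B, T, num_heads, d_h).transpose(0, 2, 1, 3)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_h)       # (B, H, T, T)
    if causal:
        # Hide future positions (strict upper triangle) before the softmax.
        future = np.triu(np.ones((T, T), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    weights = softmax(scores, axis=-1)                         # (B, H, T, T)
    context = weights @ v                                      # (B, H, T, d_h)
    merged = context.transpose(0, 2, 1, 3).reshape(B, T, d)    # (B, T, d)
    return merged @ w_o, weights                               # final projection

rng = np.random.default_rng(0)
B, T, d, H = 2, 4, 64, 8
x = rng.standard_normal((B, T, d))
w_q, w_k, w_v, w_o = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))
out, weights = multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads=H, causal=True)
print(out.shape, weights.shape)  # (2, 4, 64) (2, 8, 4, 4)
```

With `causal=True`, the first token can only attend to itself, so the first row of every attention matrix is `[1, 0, 0, 0]`.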


Part II: Mixture of Experts

1. The Motivation for MoE

As models grow in size (billions or even trillions of parameters), training them can become prohibitively expensive. A Mixture of Experts (MoE) approach addresses this by distributing computation across multiple smaller “expert” sub-networks while maintaining a large overall capacity. The high-level idea:

  1. Each token in the sequence doesn’t need to be processed by every single part of the network.
  2. Instead, we let a router or gater direct each token to only the most appropriate (or “expert”) sub-network(s).
  3. By doing so, we effectively keep the overall model’s capacity large while limiting the computation that each token requires.

2. Router and Experts

An MoE layer generally consists of:

  • Router: A function that takes the incoming hidden states and produces a set of gating logits. Each logit corresponds to how relevant each expert is for that token. After a softmax, these logits become probabilities or weights that indicate how heavily to rely on each expert.
  • Experts: Each expert is typically a standard MLP (or feed-forward network) that can process the input embeddings. However, each expert’s parameters are isolated from the others, allowing them to specialize in certain types of tokens.

Key Points to Note:

  • If you have many experts (say 32, 64, or more), then the total number of parameters in the network can be huge. However, at each forward pass, each token is only sent to one or a few experts (e.g., top-1 or top-2 gating).
  • The distribution of tokens to experts is typically learned. Over time, some experts may become specialized in certain patterns (e.g., numeric tokens, punctuation, domain-specific jargon), while others handle general textual patterns.

3. Shape Considerations

If our input to the MoE layer is $(B, T, d)$, we can consider:

  1. Router – Maps $(B, T, d) \to (B, T, \text{num\_experts})$.
    • The router might be a single linear layer or a small MLP.
    • The output is gating logits; you typically do a softmax along the experts dimension, so each token has a distribution over $\text{num\_experts}$.
  2. Experts – Each is a function $\text{Expert}_i: \mathbb{R}^d \to \mathbb{R}^d$.
    • The shape of the transformation is the same for each expert, so each can handle a $(B, T)$-sized chunk of the data if needed.
    • In practice, you often apply experts on a token-by-token basis (or via a more advanced grouping strategy).
  3. Combining Outputs – We multiply each expert’s output by the corresponding gating weight for that expert and then sum (or otherwise combine) them all together. The final shape is again $(B, T, d)$.
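
Here is a small NumPy sketch of these three steps in their simplest "dense" form, where every expert runs on every token and the softmax gate weights the results. It ignores top-k selection entirely, and the names (`moe_dense`, `expert_ffn`) and the ReLU MLP are assumptions made for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def expert_ffn(x, w1, w2):
    # A two-layer MLP with ReLU: R^d -> R^hidden -> R^d.
    return np.maximum(x @ w1, 0.0) @ w2

def moe_dense(x, w_router, experts):
    """x: (B, T, d); w_router: (d, E); experts: list of E (w1, w2) pairs."""
    gate = softmax(x @ w_router, axis=-1)                 # (B, T, E)
    outs = np.stack([expert_ffn(x, w1, w2) for w1, w2 in experts], axis=-2)  # (B, T, E, d)
    y = (gate[..., None] * outs).sum(axis=-2)             # weighted sum over experts
    return y, gate                                        # (B, T, d), (B, T, E)

rng = np.random.default_rng(0)
B, T, d, hidden, E = 2, 3, 8, 16, 4
x = rng.standard_normal((B, T, d))
w_router = rng.standard_normal((d, E))
experts = [(rng.standard_normal((d, hidden)), rng.standard_normal((hidden, d)))
           for _ in range(E)]
y, gate = moe_dense(x, w_router, experts)
print(y.shape, gate.shape)  # (2, 3, 8) (2, 3, 4)
```

This dense version costs as much as running all experts, so it is useful for checking shapes rather than for saving compute; sparse routing is what makes MoE cheap per token.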

4. Router Strategies and Bottlenecks

A standard approach is Top-k routing, where each token is routed to the top $k$ experts (most commonly, $k = 1$ or $k = 2$). This can cut down on computational overhead. However, it also introduces complexities:

  • Load balancing: If your gating mechanism always picks the same small subset of experts, you’ll overload those experts and under-utilize the rest. Techniques like load balancing losses or smoothing factors are used to encourage more balanced usage.
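  • Backpropagation: Picking only the top-1 or top-2 experts is a discrete choice, so gradients reach the router only through the gate weights of the experts that were actually selected. Some approaches instead use continuous approximations to keep the whole process differentiable.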
  • Implementation complexity: Naively sending tokens to different experts can be tricky to parallelize. Approaches like “expert parallel” or “all-to-all” communication are used in distributed training setups.
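
A minimal NumPy sketch of top-k gating together with one common form of load balancing loss is shown below. Two assumptions are worth flagging: the gate weights are renormalized with a softmax over the selected experts only (some models apply the softmax over all experts first), and the loss follows the style of the Switch Transformer auxiliary loss, which multiplies the fraction of tokens sent to each expert by that expert's mean router probability. Function names are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def top_k_gates(logits, k):
    """logits: (N, E). Returns expert ids (N, k) and gate weights (N, k)."""
    idx = np.argsort(-logits, axis=-1)[:, :k]          # k highest-scoring experts per token
    top = np.take_along_axis(logits, idx, axis=-1)
    return idx, softmax(top, axis=-1)                  # renormalize over the chosen experts

def load_balance_loss(logits, idx):
    num_tokens, num_experts = logits.shape
    probs = softmax(logits, axis=-1)
    # f: fraction of tokens whose first choice is each expert.
    f = np.bincount(idx[:, 0], minlength=num_experts) / num_tokens
    # p: mean router probability assigned to each expert.
    p = probs.mean(axis=0)
    return num_experts * float(np.sum(f * p))

rng = np.random.default_rng(0)
logits = rng.standard_normal((1000, 4))    # 1000 tokens (B*T flattened), 4 experts
idx, gates = top_k_gates(logits, k=2)
balanced = load_balance_loss(logits, idx)

skewed_logits = logits.copy()
skewed_logits[:, 0] += 20.0                # router collapses onto expert 0
skewed_idx, _ = top_k_gates(skewed_logits, k=2)
skewed = load_balance_loss(skewed_logits, skewed_idx)
print(balanced, skewed)  # close to 1 when balanced, close to num_experts when collapsed
```

In training, a term like this is typically scaled by a small coefficient and added to the main loss, so that the router is nudged toward spreading tokens across experts.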

Putting It All Together: A Transformer Block with Multi-Head Attention + MoE

In a typical Transformer block, you have two major sub-layers:

  1. Attention sub-layer: Multi-head self-attention (or cross-attention, depending on the architecture).
  2. Feed-Forward sub-layer: A position-wise MLP that processes each token embedding.

In an MoE-augmented Transformer, you replace the Feed-Forward sub-layer with a Mixture of Experts sub-layer. So the block might look like:

  1. Input $x$ goes through multi-head self-attention:
    • $x_\text{attn} = \text{Attention}(x)$
    • Then apply skip connection + normalization: $x_\text{normed} = \text{LayerNorm}(x + x_\text{attn})$
  2. Instead of a single MLP, you have:
    • $x_\text{moe} = \text{MoE}(x_\text{normed})$
    • Another skip connection + normalization: $x_\text{out} = \text{LayerNorm}(x_\text{normed} + x_\text{moe})$

The result is that each token is processed by a specialized MLP (one or two experts, depending on gating), which can yield better capacity utilization for large models.
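
The wiring of the block can be sketched in a few lines of NumPy. To keep the example short, the attention and MoE sub-layers are passed in as functions, and the ones used in the demo are deliberately trivial stand-ins (a uniform average over tokens and a `tanh`), not real implementations. The layer norm here also omits the learned scale and shift.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the feature dimension; learned scale and shift are omitted.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def moe_transformer_block(x, attention_fn, moe_fn):
    """x: (B, T, d). Both sub-layers must map (B, T, d) -> (B, T, d)."""
    x_normed = layer_norm(x + attention_fn(x))       # attention + skip + norm
    return layer_norm(x_normed + moe_fn(x_normed))   # MoE + skip + norm

# Stand-in sub-layers, for shape checking only (hypothetical placeholders).
uniform_attention = lambda h: np.broadcast_to(h.mean(axis=1, keepdims=True), h.shape)
toy_moe = np.tanh

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 16))
x_out = moe_transformer_block(x, uniform_attention, toy_moe)
print(x_out.shape)  # (2, 4, 16)
```

This follows the "normalize after the residual" ordering written out above; many recent models normalize before each sub-layer instead, which changes only where `layer_norm` is called.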


Glossary of Key Terms

Attention

A mechanism to compute a weighted combination of a set of “value” vectors, where the weights are learned dynamically based on a similarity function (e.g., dot product) between “query” and “key” vectors.

Q, K, V

Short for Query, Key, Value. These are linear transformations of the input in self-attention:

  • Q decides “what am I looking for?”
  • K indicates “what do I have available?”
  • V provides the actual “content” to be gathered.

Scaled Dot Product

The operation $\text{softmax}(QK^T / \sqrt{d_h}) \times V$. Dividing by $\sqrt{d_h}$ normalizes the dot products to prevent large magnitudes when $d_h$ is large.
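
A short derivation shows where the $\sqrt{d_h}$ comes from. It rests on a simplifying assumption that is not guaranteed in a trained model: the components of a query $q$ and a key $k$ are independent with mean 0 and variance 1.

```latex
% Assumption: the components q_i, k_i are independent with mean 0 and variance 1.
\begin{aligned}
q \cdot k &= \sum_{i=1}^{d_h} q_i k_i, \\
\mathbb{E}[q \cdot k] &= \sum_{i=1}^{d_h} \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0, \\
\operatorname{Var}(q \cdot k) &= \sum_{i=1}^{d_h} \operatorname{Var}(q_i k_i)
  = \sum_{i=1}^{d_h} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = d_h, \\
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_h}}\right) &= \frac{d_h}{d_h} = 1.
\end{aligned}
```

Under that assumption, unscaled scores have a standard deviation of $\sqrt{d_h}$ (8 for $d_h = 64$), which pushes the softmax toward near one-hot outputs with very small gradients; the scaling brings the scores back to unit variance.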

Multi-Head

Instead of computing one attention distribution, multi-head attention splits the embedding space across multiple heads, each with its own projection and attention computation. The results are concatenated and fused by a final linear layer.

Positional Encoding (Absolute, Relative, Rotary, etc.)

Methods to inject positional awareness into Transformers, ensuring that each token knows where it appears in the sequence.

Mixture of Experts (MoE)

A set of parallel sub-networks (experts), each potentially specialized in certain input patterns. A router decides which expert(s) each input token is sent to.

Router

A module that computes gating scores to determine which experts to use for each token. Often implemented as a linear (or small MLP) transformation followed by softmax or top-k selection.

Feed-Forward Network (FFN)

A standard MLP block (often with a hidden dimension 2–4 times the model dimension) used in Transformers. In an MoE Transformer, each expert is essentially one such FFN.

Gating

The process of assigning tokens to experts. A softmax gating approach assigns fractional weights to each expert, while top-k gating selects the top few experts for each token.

Load Balancing

A mechanism to prevent certain experts from “hogging” all the tokens while others remain idle. Common strategies include adding a regularization term that penalizes imbalances.


Practical Considerations

1. Computational Trade-Offs

  • Multi-Head Attention: The cost scales roughly with $\mathcal{O}(T^2 \times d)$ because for each token you compute attention with every other token in the sequence. For very long sequences, memory or compute might become a bottleneck, and you might consider sparse attention methods or efficient attention approximations.
  • Mixture of Experts: If you have $E$ experts, each with an MLP cost of $\mathcal{O}(d \times \text{hidden\_dim})$, using top-1 or top-2 gating means each token is only processed by 1 or 2 experts. This can reduce the per-token cost from $\mathcal{O}(E \times d \times \text{hidden\_dim})$ to $\mathcal{O}(d \times \text{hidden\_dim})$ or $\mathcal{O}(2 \times d \times \text{hidden\_dim})$. However, you still have overhead in terms of routing and balancing computations.
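
To put rough numbers on the MoE trade-off, the following back-of-the-envelope calculation compares total expert parameters with the parameters and FLOPs actually used per token. The sizes (d = 512, hidden = 2048, 64 experts, top-2) are illustrative assumptions, biases are ignored, and a multiply-add is counted as 2 FLOPs.

```python
def ffn_flops_per_token(d, hidden):
    # Two matrix multiplies (d -> hidden, hidden -> d), 2 FLOPs per multiply-add.
    return 2 * (d * hidden) * 2

def moe_summary(d, hidden, num_experts, top_k):
    params_per_expert = 2 * d * hidden  # two weight matrices, biases ignored
    return {
        "total_expert_params": num_experts * params_per_expert,
        "active_expert_params": top_k * params_per_expert,
        "expert_flops_per_token": top_k * ffn_flops_per_token(d, hidden),
        "all_experts_flops_per_token": num_experts * ffn_flops_per_token(d, hidden),
        "router_flops_per_token": 2 * d * num_experts,  # single linear router
    }

s = moe_summary(d=512, hidden=2048, num_experts=64, top_k=2)
for name, value in s.items():
    print(f"{name}: {value:,}")
print(s["total_expert_params"] / s["active_expert_params"])  # 32.0
```

In this configuration the layer stores 32 times more expert parameters than any single token touches, while the router itself adds well under 1% to the per-token expert FLOPs. Communication and load imbalance, which this arithmetic does not capture, are usually the larger practical overheads.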

2. Training Stability

  • Transformers can be sensitive to initialization, learning rates, and normalization layers.
  • MoE layers add an additional complexity: if gating becomes highly skewed, some experts may receive far more gradients, while others receive almost none. Solutions often involve special gating regularization terms.
  • Layer Normalization or RMSNorm is commonly used around both the attention sub-layer and the FFN/MoE sub-layer.

3. Large-Scale Distributed Training

  • Multi-head attention is usually straightforward to distribute: you can split across batch, sequence, or heads.
  • MoE introduces more complexity because you need to route tokens to experts, potentially scattering them across different devices. The all-to-all communication pattern can be a major engineering challenge. Libraries like DeepSpeed (Microsoft) or the MoE implementation in Google’s TensorFlow/TPU stack provide specialized support for distributed MoE.

4. Real-World Examples

  • GPT-3: Does not use MoE, but heavily relies on multi-head attention to scale to 175B parameters.
  • GLaM: A large MoE-based language model from Google.
  • Switch Transformers: Another MoE-based architecture that replaced standard FFNs with experts, achieving impressive scaling behaviors.

A Hypothetical Example Walkthrough

Imagine a simplified scenario:

  • We have a vocabulary of tokens representing words in English text.
  • We embed these tokens into a $(B, T, d)$ representation.
  • We feed them into our Multi-Head Self-Attention layer. The model learns to align subject and verb across different positions, detect relevant context for each word (like named entities or references), and produce a contextually enriched output of shape $(B, T, d)$.
  • Next, we pass these enriched embeddings into the MoE feed-forward sub-layer. The router sees each token embedding and decides:
    • “Expert 0 is good at handling numeric data (maybe it has learned that through training).”
    • “Expert 1 is specialized in domain-specific jargon for legal text.”
    • “Expert 2 is specialized in common everyday language.”
    • etc.
  • Each token’s gating distribution might be heavily skewed toward Expert 2 if the text is everyday conversation. But a date or a number token might get strongly routed to Expert 0.
  • The token’s output is then a weighted combination of the selected experts’ outputs, with the gating weights as coefficients. This specialization can yield more powerful transformations while not drastically increasing the per-token computation.

Combining Multi-Head Attention and MoE in Practice

When you see a diagram that places Multi-Head Attention at the top and Mixture of Experts at the bottom, it typically represents a single Transformer layer:

  1. The top portion (Multi-Head Latent Attention) indicates how the input tokens attend to one another via Q, K, V, with optional positional encodings.
  2. The bottom portion (Mixture of Experts) shows how the output from the attention sub-layer is passed into the router and relevant experts, and how their outputs are recombined into a single tensor.

Hence, the entire flow is:

  1. Embed tokens (and possibly add or incorporate positional encodings).
  2. Compute multi-head attention over the sequence, generating context-aware representations.
  3. Normalize + skip connection (a standard practice in Transformers).
  4. Route tokens to experts using the gating mechanism.
  5. Experts process tokens (or token embeddings), each providing a specialized transformation.
  6. Combine expert outputs into a final representation, then another normalize + skip connection step.

Key Takeaways and Future Directions

  1. Multi-Head Attention is essential for capturing a wide range of relationships in sequence data. Each head learns a different representation, leading to richer embeddings.
  2. Positional Information is crucial in any attention-based model that processes sequences. Without it, the model cannot distinguish the order of tokens.
  3. Mixture of Experts provides a way to scale model capacity without linearly increasing computation cost for every token. Only a subset of parameters (i.e., the chosen experts) is used for each token.
  4. Balancing and Routing complexities are non-trivial. Engineers building MoE-based models must handle load balancing, gating, and distributed communication carefully.
  5. Production-Ready Implementations often rely on specialized frameworks or libraries (e.g., DeepSpeed, Google’s TensorFlow MoE) to handle the distributed nature of large MoE models.
  6. Research Trends:
    • Sparse Attention: Combining MoE with more efficient forms of attention to handle extremely long sequences (e.g., 10K or 100K tokens).
    • Dynamic Architectures: Using dynamic routing not just at the expert level, but also in attention sub-layers or in how the model is extended to new domains.
    • Vision and Multimodal: MoE is also being explored in vision transformers and multimodal networks to handle complex data like images, audio, video, or text simultaneously.

Conclusion

We have walked through the conceptual blocks of Multi-Head Latent Attention and the Mixture of Experts framework, highlighting their motivations, typical shapes and dimensions, as well as the step-by-step flow. The synergy between these mechanisms allows for powerful, scalable Transformer architectures:

  • Multi-Head Attention ensures broad expressivity in capturing relationships within sequences, distributing attention across multiple sub-representations.
  • Mixture of Experts scales network capacity by distributing feed-forward computations across specialized experts, with a router that learns which expert should handle each token.

The interplay of these elements is what drives many of today’s cutting-edge large language models. As you delve deeper into the actual implementations, you’ll see how these building blocks manifest in frameworks like PyTorch or TensorFlow, complete with intricacies such as top-k gating, load balancing losses, and distributed data parallelism.

Ultimately, both Multi-Head Attention and Mixture of Experts are about making the best use of a large capacity model: attention focuses on relevant parts of a sequence at every step, while MoE focuses on relevant sub-networks for every token. Together, they continue to push the boundaries of what’s possible in NLP and beyond.
