EXPLORATIONS - April 2026

What I Learned Building Attention Residuals from Scratch

Naively reimplementing a paper in PyTorch changed how I think about how transformers route information, and about the gap between academic math and physical silicon.

I wanted to understand how transformers actually route information between layers. Not at the level of “attention computes weighted averages,” but at the level of what physically happens to a tensor as it moves through the network. What gets preserved, what gets overwritten, and why.

A video helped build some of the initial intuition. From there I picked up the paper, Attention Residuals (Kimi Team, 2026), and decided to reimplement it from scratch. No HuggingFace, no pre-built transformer blocks. Just torch.einsum, nn.Parameter, and a toy dataset small enough that I could trace every matrix multiplication by hand.

A rotated sketch snapshot from the attention residuals article work-in-progress
This is a frozen snapshot of when things started to click for me. I never finished the drawing, and my handwriting here might honestly be worse than kenny's, but I like it as a raw artifact of the moment when the math started to feel concrete :p

Right around here, I remember asking Gemini, “make me a toy problem that i can do by hand, faithful to the math in the paper.” I needed something small enough that I could trace every tensor manually, but still honest to the actual routing mechanism the paper was studying.

One thing that clicked for me from the video was that ordinary attention, plus the KV cache, is basically a breadth-wise memory. You keep appending keys and values along sequence length so the model can revisit earlier positions in time. Attention Residuals felt like that same instinct turned depth-wise. Instead of caching over previous tokens, you are effectively caching over previous layers and intermediate computations.

Later, Xander sent me this Kimi tweet, with two lines I loved: “Rotating an LSTM gives us residuals” and “What is attention rotated by 90 degrees?” That framing made the whole thing feel especially beautiful to me. It was old machinery getting turned on a new axis, not discarded, just repurposed for a new problem.

The exercise turned out to be far more instructive than I expected. Not because the paper is complex. The core idea literally fits in a paragraph. But implementing it forced me to confront questions about memory management, gradient flow, and hardware constraints that never come up when you're working at the API level.

This post walks through the architecture, some code, and the specific moments where my understanding broke and reformed.


The problem with standard residuals

A standard transformer block computes x = x + layer(x). The residual connection is one of the most important ideas in deep learning. It gives gradients a direct path through the network and prevents the “vanishing gradient” problem from killing training.

But there's a subtle limitation. Each layer can only see the output of the layer immediately before it. Layer 12 cannot directly read the raw embedding. Layer 8 cannot inspect what Layer 2 computed. Information propagates through a chain of additions, like a game of telephone, and by the time a signal reaches the deeper layers, it has been mixed, transformed, and potentially diluted beyond recognition.

The residual stream helps, because the original signal is additively preserved. But the relative contribution of early layers shrinks as more terms get summed in. There's no mechanism for a later layer to say: “I specifically need 80% of the embedding and only 5% of Layer 3.”
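To make the dilution concrete, here's a tiny sketch. The random vectors standing in for sublayer outputs are purely illustrative, but the accumulation math is the real point: after twelve unweighted additions, the embedding accounts for only a small fraction of the stream's total magnitude, and no layer can dial that back up.

```python
import torch

torch.manual_seed(0)
D = 16
x = torch.randn(D)              # the embedding
stream = x.clone()
contributions = [x]

# each "layer" here is just a random unit-scale vector standing in
# for a sublayer output; only the accumulation pattern matters
for _ in range(12):
    out = torch.randn(D)
    stream = stream + out        # x = x + layer(x)
    contributions.append(out)

# fraction of the stream's total magnitude the embedding accounts for
frac = (contributions[0].norm() / sum(c.norm() for c in contributions)).item()
print(f"embedding's share of total magnitude: {frac:.2f}")  # roughly 1/13
```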

Attention Residuals replaces this fixed plumbing with a learned routing mechanism. Every layer gets to decide, for itself, exactly how much of each previous output to use.

[Diagram: standard residuals (Embedding, MHA 1, MLP 1, MHA 2 chained via x = x + layer(x)) vs. attention residuals (a KV-style store of h0, h1, h2, h3 that every layer writes to and reads from with gated weights like 0.79 and 0.21)]
Left: Standard residual connections forward information one layer at a time. Right: Attention Residuals let every layer query the full history.

A growing database of past outputs

The first data structure I had to build was what I started calling the “history database.” Instead of passing a single tensor forward through the network, I maintain a Python list of every output generated so far:

states = [x]  # start with the embedding
for layer in self.layers:
    states = layer(states)  # the layer reads AND appends
return self.final_proj(states[-1])

Each FullAttnResLayer receives this list, reads from it to construct its inputs, and appends its own MHA and MLP outputs before passing the updated list to the next layer. The database grows by two entries per layer, one from the multi-head attention block, one from the MLP.

One dumb thing I thought should work at first was states.append(layer(states)). It sounds natural when you say it fast. But layer(states) already mutates and returns the same list object, so that line would make states append itself into itself. Layer 2's torch.stack() would then immediately walk into a cursed self-referential list and explode.
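The bug reproduces in three lines of plain Python, no tensors needed:

```python
# Reproducing the bug in miniature: the "layer" mutates and returns
# the SAME list object, so appending its return value appends the
# list into itself.
def fake_layer(states):
    states.append("new output")
    return states                # same object, not a copy

states = ["embedding"]
states.append(fake_layer(states))

print(states[-1] is states)      # True: the list now contains itself
# any torch.stack(states) downstream would choke on the self-reference
```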

[Diagram: the history database (previous_states: list[Tensor]) growing from 1 entry (Emb) to 5 entries (Emb, MHA₁, MLP₁, MHA₂, MLP₂) over two layers. Each entry is a [B, T, D] tensor; the list only stores pointers, and torch.stack() creates contiguous memory when needed.]
The database grows by two entries per layer: one from MHA, one from MLP. Layer 2's MHA can directly query Layer 1's embedding.

A subtle but important implementation detail: the list itself only stores pointers. When you call previous_states.append(mha_out), Python doesn't copy the tensor. It just jots down the memory address. This is essentially free. The expensive operation comes later, when we need to do math across all entries and must call torch.stack() to flatten the scattered pointers into a single, contiguous 4D tensor that the GPU can parallelize over.
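A quick sanity check of that claim (CPU, toy shapes): stacking leaves the originals untouched and lands the copy in a brand-new allocation.

```python
import torch

states = [torch.randn(2, 4, 8) for _ in range(3)]  # three [B, T, D] tensors

ptrs_before = [t.data_ptr() for t in states]
stacked = torch.stack(states)                      # [3, B, T, D]

# stack copied the data into fresh contiguous storage:
print(stacked.is_contiguous())                     # True
print(stacked.data_ptr() in ptrs_before)           # False: new allocation

# the originals are untouched; the list still holds the same pointers
assert [t.data_ptr() for t in states] == ptrs_before
```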

The alpha gate: learned routing

Before each MHA or MLP block runs, a small routing engine decides how to blend the history. I call this the alpha gate. It computes a set of learned, softmax-normalized percentages (the α weights) that determine how much of each past output to include.

The mechanism is a strict 4-step pipeline:

  1. RMSNorm. Normalize the stacked history tensor to ensure fair scoring across entries that may live at different scales.
  2. Score. Dot-product a learned 1D query vector (wl) against the normalized history. This produces one raw logit per history entry.
  3. Softmax. Convert the logits into percentages that sum to 1.0 along the depth dimension. These are the alpha gates.
  4. Blend. Weighted sum of the raw, un-normalized history using the alpha percentages. This becomes the input to the next sublayer.
[Diagram: the alpha gate's four steps for MLP 1 routing over history = [embedding, MHA 1 output]: (1) K = RMSNorm([h₀, h₁]); (2) logits = wₗ · K, e.g. [1.33, 0.01]; (3) α = softmax(logits), e.g. [0.79, 0.21]; (4) out = 0.79·h₀ + 0.21·h₁, where the blend uses the raw history as input]
The alpha gate: a 4-step routing engine that runs before every MHA and MLP block.

Here is the actual implementation:

def alpha_gating(self, history, layer_type):
    normed_history = (
        self.mlp_res_norm(history)
        if layer_type == "mlp"
        else self.mha_res_norm(history)
    )
    pre_scores = torch.einsum(
        "d, s b t d -> s b t",
        self.mlp_query if layer_type == "mlp" else self.mha_query,
        normed_history,
    )
    scores = F.softmax(pre_scores, dim=0)
    return torch.einsum(
        "s b t, s b t d -> b t d",
        scores,
        history,  # raw history, not the normed version
    )

Two things worth noting. First, the query vector is initialized to all zeros via nn.Parameter(torch.zeros(D)). At the start of training, every alpha gate produces a uniform distribution over the history, roughly equal weight to all past entries, and then learns to specialize. Second, the blend step operates on the raw history, not the normalized version. The RMSNorm exists purely to stabilize the scoring. We don't want to feed the sublayer a signal that has been normalized twice.
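You can verify the uniform-at-init behavior in a few lines (shapes are arbitrary; I skip the RMSNorm here because a zero query yields zero logits either way):

```python
import torch
import torch.nn.functional as F

S, B, T, D = 4, 2, 5, 16                # history entries, batch, seq, dim
history = torch.randn(S, B, T, D)
query = torch.zeros(D)                   # nn.Parameter(torch.zeros(D)) at init

logits = torch.einsum("d,sbtd->sbt", query, history)  # all zeros
alphas = F.softmax(logits, dim=0)

# every history entry gets exactly 1/S of the blend at initialization
print(alphas[0, 0, 0].item())            # 0.25 for S=4
```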

A note on the “useless” first query

MHA 1 is the first module to run. At that point, the database contains exactly one entry: the embedding. The alpha gate computes softmax over a single element, which is always [1.0]. The first query vector is mathematically locked. Its gradients will always be zero, and it will never learn anything. It exists in the code purely for structural symmetry.

I only understood this after making a very silly mistake on paper. I had written down the alpha gate from embedding to MHA 1 as something like [1, 1], because in my head I was already counting both the embedding and MHA 1. But MHA 1 hasn't run yet when it queries the database. There is only one thing to attend to. That was the moment it clicked that the first query vector isn't just unimportant, it is mathematically useless.
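Both claims, the locked softmax and the dead gradient, are easy to check directly:

```python
import torch
import torch.nn.functional as F

D = 16
query = torch.randn(D, requires_grad=True)   # first layer's routing query
history = torch.randn(1, 3, 5, D)            # exactly one entry: the embedding

logits = torch.einsum("d,sbtd->sbt", query, history)
alphas = F.softmax(logits, dim=0)            # softmax over a single element
print(alphas.min().item(), alphas.max().item())  # 1.0 1.0, whatever the query

blended = torch.einsum("sbt,sbtd->btd", alphas, history)
blended.sum().backward()
print(query.grad.abs().max().item())         # 0.0: the query never learns
```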

The dual-norm system

One of the more confusing aspects of the implementation was keeping track of the normalization layers. There are two completely independent sets of nn.RMSNorm modules per layer, and they serve entirely different purposes:

  • Routing norms (mha_res_norm, mlp_res_norm). These live inside the alpha gate. They normalize the entire history stack to stabilize the scoring mechanism.
  • Pre-norms (mha_input_norm, mlp_input_norm). These live outside the alpha gate. They normalize the single blended vector right before it enters the heavy weight matrices of MHA or MLP, preventing gradient explosion.

Critically, the routing norms and pre-norms must be separate modules. If MHA and MLP share a single norm, they are forced to evaluate the history with the same learned scale, which cripples their ability to develop distinct routing strategies.

# Each sublayer gets its own routing norm AND its own pre-norm
self.mha_res_norm = nn.RMSNorm(D) # for scoring
self.mha_input_norm = nn.RMSNorm(D) # for stabilizing
self.mlp_res_norm = nn.RMSNorm(D)
self.mlp_input_norm = nn.RMSNorm(D)

The complete forward pass

Each layer executes two sequential phases. MHA runs first, appends its output to the database, and then MLP runs with the newly updated database, meaning the MLP can immediately use the current layer's MHA output as an ingredient.

def forward_attn_res(self, previous_states):
    # Phase 1: MHA
    V = torch.stack(previous_states)
    gated_input = self.alpha_gating(V, "mha")
    mha_out = self.mha(self.mha_input_norm(gated_input))
    previous_states.append(mha_out)

    # Phase 2: MLP
    V_updated = torch.stack(previous_states)
    gated_input = self.alpha_gating(V_updated, "mlp")
    mlp_out = self.transform(self.mlp_input_norm(gated_input))
    previous_states.append(mlp_out)
    return previous_states

Notice the two calls to torch.stack. The first stacks N entries; the second stacks N+1. Each call allocates a brand-new contiguous block of GPU memory, copies every tensor from the scattered Python list into it, and hands the result to einsum. This is expensive, but necessary. GPU tensor cores cannot operate on a Python list of pointers. They need a single, aligned memory block.

I also had a moment where the two variable names, V and V_updated, felt hella memory inefficient. Like, wait, did I just double memory by naming the thing twice? But the names are free. They're just sticky notes attached to GPU allocations. The real memory cost is the new torch.stack() call itself, which allocates fresh contiguous storage no matter what you name the result.


What the numbers showed

I trained two identical architectures, 4 layers, hidden size 16, 4 attention heads, on a synthetic sequence-reversal task. The only difference: one uses attention residuals, the other uses standard x = x + layer(x) skip connections. Both use sinusoidal positional encodings, MSE loss, and the Adam optimizer at a learning rate of 0.005 over 50 epochs.

I hooked into the weight gradient of the first Linear layer in Layer 1's MLP to monitor how well gradients propagate from the output all the way back to the earliest parameters.
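The hook itself is tiny. Here's a sketch of the pattern, using a standalone nn.Linear as a stand-in for the first Linear in Layer 1's MLP (my actual module paths differ):

```python
import torch
import torch.nn as nn

# illustrative stand-in for "the first Linear in Layer 1's MLP"
probe = nn.Linear(16, 64)

grad_log = []
def log_grad(grad):
    # record the mean absolute weight gradient on each backward pass
    grad_log.append(grad.abs().mean().item())

probe.weight.register_hook(log_grad)

loss = probe(torch.randn(8, 16)).sum()
loss.backward()
print(len(grad_log), grad_log[0] > 0)    # 1 True
```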

[Chart: two panels comparing AttnRes ON vs. Standard Residual over 50 epochs. Left: loss curves (y-axis 0.82 to 1.00). Right: Layer-1 gradient magnitude (y-axis 0.0000 to 0.0014).]
Training curves over 50 epochs. Attention Residuals achieve lower loss and sustain growing gradients in Layer 1 - standard residuals plateau.

Two results stand out:

Lower final loss. Attention Residuals reaches a loss of 0.830 versus 0.865 for the standard baseline, a meaningful gap on a toy task where both models have identical parameter counts.

Gradient growth vs. plateau. This is the more interesting finding. With attention residuals, the average gradient magnitude in Layer 1 grows roughly 11× over training (from 0.000107 to 0.001230). With standard residuals, gradients start higher but quickly plateau around 0.00095, showing only ~1.8× growth.

The interpretation is straightforward: the alpha gates create direct gradient highways from the loss back to early layers. Instead of gradients having to flow through every intermediate computation, they can propagate directly through the softmax-weighted connections. This is the same principle that made DenseNet work in computer vision. Direct connections beat chains.


The deeper lessons

The architecture itself is simple. The real education was in the implementation details, the places where abstract math collides with physical hardware constraints.

I also felt like I got to understand PyTorch a little better through all of this. Part of why I even started the project was a conversation Surya had with Zachary Cetinic, which was the initial spark for all of this. That made me want to give it a shot myself and hand-write most of the machinery and repetitive plumbing, just to see if I could make the abstractions feel concrete.

torch.stack vs. .append

Python lists and PyTorch tensors occupy fundamentally different worlds. When you .append() a tensor to a list, Python just writes down a memory address. Essentially free, O(1), no data movement. When you torch.stack() that list into a tensor, PyTorch must find a contiguous block of GPU memory large enough to hold every entry, then physically copy all the data into it.

[Diagram: a Python list built with .append() is O(1) and just stores addresses (ptr[0] 0x1A4F, ptr[1] 0x8C21, ptr[2] 0x57D9), while torch.stack() copies those scattered chunks into one fresh contiguous [S, B, T, D] tensor]
.append() only records three pointers to scattered tensor chunks. torch.stack() allocates a new contiguous buffer and copies those chunks into one aligned block the GPU kernel can read in a single pass.

This means every layer pays the cost of two full stack operations. It felt wasteful at first. But there's no alternative. einsum is a C++/CUDA kernel that requires perfectly aligned memory. You cannot run a matrix multiplication on a Python list of scattered pointers. The cost is the price of parallelism.

I originally thought maybe skipping torch.stack() would just make einsum slower. It wouldn't be slower. It would just crash. einsum has no idea what a Python list even is. It wants contiguous tensor memory or nothing.

Pass-by-reference saves you

When a layer receives previous_states and appends to it, it is mutating the list in place. Python lists are passed by reference. The function receives a pointer to the same object, not a copy. The outer training loop's states variable sees the mutations automatically. Writing states = layer(states) looks like it's “rewriting” the list, but under the hood it's just re-attaching a name tag to the same pointer. Zero data movement.
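You can watch this happen with id():

```python
def layer(states):
    states.append("mha output")    # in-place mutation
    return states                  # same object, not a copy

states = ["embedding"]
before = id(states)
states = layer(states)             # "rebinding" to the very same pointer

print(id(states) == before)        # True: no copy, no data movement
print(states)                      # ['embedding', 'mha output']
```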

Softmax is stateless; norms are stateful

A pattern that clarified the PyTorch programming model for me: operations that have no learned parameters (softmax, GELU, einsum) live in torch.nn.functional and are called in forward(). Operations that do have learned parameters (RMSNorm, Linear) must be instantiated in __init__ so that nn.Parameter can tell PyTorch's memory allocator to permanently reserve space for their weights and register them with the autograd engine.

The connection to hardware

This little exploration reshaped how I think about the boundary between algorithms and silicon. The alpha gate is a beautiful mathematical idea, learned softmax routing over a depth history. But making it run fast on a GPU requires understanding that torch.stack is not free, that einsum needs contiguous memory, and that the real bottleneck in modern AI hardware is not compute but memory bandwidth, how fast you can ferry data between slow HBM and fast SRAM.

The same principle underlies FlashAttention, which doesn't change the math of attention at all but rewrites the memory access pattern to avoid materializing the T×T score matrix in HBM. It trades slightly more compute for vastly less memory I/O, because on modern GPUs, computing a number twice is often faster than memorizing it once.

Another thing I got weirdly stuck on for a while was softmax. I kept thinking, okay, other words live across the columns, so how are rows somehow independent? The unlock was realizing that walking across one row is scanning the columns. That was the thing that finally made attention feel embarrassingly parallel to me, and that is basically the whole intuition underneath why FlashAttention works at all.
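The row-independence claim is a one-liner to verify: softmax over the key dimension, computed one query row at a time, matches the full-matrix result exactly.

```python
import torch
import torch.nn.functional as F

scores = torch.randn(4, 6)                 # [queries, keys] attention scores

full = F.softmax(scores, dim=-1)           # softmax over the key dimension

# computing each row on its own gives the identical result
row_by_row = torch.stack([F.softmax(row, dim=-1) for row in scores])
print(torch.allclose(full, row_by_row))    # True: rows are independent
```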

I also briefly made the classic bad leap of, “oh, if each row is parallelizable then attention is O(N).” Not true. The total work is still O(N²). Parallelism only changes how much of that work you can hide across cores, not the exponent. I think part of why I got mixed up is that I already had this fuzzy fact in my head that some attention memory stories grow like O(N) while compute grows like O(N²), and I probably blurred that together with the parallelism story along the way.

How my toy results compare with the paper

The Kimi Team trained AttnRes on a 48B-parameter MoE Transformer with 1.4 trillion tokens of real text. I trained a 4-layer model with D=16 on 5,000 synthetic sequences for 50 epochs. The raw loss numbers are not comparable at all:

                 My experiment                     Paper
Loss function    MSE                               Cross-entropy
Task             Sequence reversal (synthetic)     Language modeling (1.4T tokens)
Model            4 layers, D=16, ~few K params     48B total / 3B active, 54 layers
Loss range       0.83 – 1.00                       1.15 – 1.93

MSE and cross-entropy live on completely different scales with different theoretical minimums. An MSE of 0.83 and a cross-entropy of 1.69 are not comparable quantities.

But the relative behavior tells the same story.

AttnRes consistently beats the baseline. The paper sees ~1.5% lower validation loss at the 528M scale (1.692 vs. 1.719). I see ~4% lower loss (0.830 vs. 0.865). The larger relative gap in my experiment makes sense: a tiny model is more capacity-constrained, so routing efficiency matters more per parameter.

Statistics comparison image from the original attention residuals article

Gradient distribution is the big one. Figure 5(c) in the paper shows the baseline produces disproportionately large gradients in early layers with no mechanism to regulate flow across depth. AttnRes yields a substantially more uniform gradient distribution. My Layer 1 gradient hook shows exactly this: AttnRes gradients grow 11× over training while standard residuals plateau at ~1.8× growth. The alpha gates are creating direct gradient highways, just like the paper claims.

PreNorm dilution. Figure 5(b) in the paper shows baseline output magnitudes growing monotonically with depth. Deeper layers learn increasingly large outputs to stay influential over the accumulated residual. AttnRes bounds this growth. I did not measure output magnitudes directly, but the improved gradient flow I observed is a downstream consequence of the same mechanism.

So a toy experiment at D=16 with 50 epochs on synthetic data independently reproduced the paper's core finding: learned softmax routing fixes gradient distribution across depth. The loss numbers are apples-to-oranges, but the dynamics are the same.


Why building from scratch matters

You can read a paper and nod along at the equations. You can import a library and watch the loss go down. But neither of those experiences will teach you that torch.stack allocates new memory, that Python lists are passed by reference, or that the first alpha gate in the network is mathematically useless.

Those are the kinds of details that matter when you're debugging a training run at 3 AM, or designing a custom hardware accelerator, or trying to invent the next architecture. I think the gap between “I understand the math” and “I understand the plumbing” is where engineering lives.

The full implementation consists of hacky PyTorch plumbing with a bunch of stream-of-consciousness comments/notes that helped me carefully comb through my thoughts, but probably won't matter much to you.

Hindsight

If I kept pushing this further, the next things I'd want to try are:

  • Implement Block AttnRes.
  • Change MHA to GQA.
  • Train it for more than 50 epochs lol.
  • Use a more modern optimizer.
  • Replace the sinusoidal embeddings with RoPE. I skipped that this time because I didn't really understand RoPE well enough yet to feel good implementing it.
  • Switch from encoder layers to decoder layers and add a KV cache.

And before anybody polices me about “tHiS iSnT hOw yOu dO rEsEaRcH,” well, 1. I'm not a researcher, 2. I know very little about training, and 3. if you've made it this far, give me tips on how i could do better :)


The paper is Attention Residuals (Kimi Team, 2026), arXiv:2603.15031.

Full code and some training logs in the repository