
Just a Demo Blog

Author: Dawood Sarfraz
Published: November 17, 2025
Reading Time: 8 minutes


Introduction

Modern transformers rely heavily on softmax attention, which costs \(O(n^2)\) time and memory in the sequence length \(n\). Efficient transformer variants attempt to reduce this complexity. This post explains how linear attention and the Delta Rule allow us to compute attention in \(O(n)\) time.

We start from the classical attention formulation:

\[ \text{Att}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V \]
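For reference, here is a minimal sketch of the softmax version in PyTorch, assuming single-head inputs of shape \((n, d)\); the full \(n \times n\) score matrix is what makes it quadratic:

import math
import torch

def softmax_attention(Q, K, V):
    # The (n, n) score matrix is where the quadratic cost comes from
    scores = Q @ K.T / math.sqrt(Q.size(-1))
    return torch.softmax(scores, dim=-1) @ V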

Linear attention replaces softmax with a kernel feature map \(\phi(\cdot)\), giving:

\[ \text{LinearAtt}(Q,K,V) = \phi(Q)\left(\phi(K)^T V\right) \]
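The key is associativity: with the softmax gone, we can multiply \(\phi(K)^T V\) first, producing a \(d \times d\) state instead of an \(n \times n\) score matrix. A quick sanity check (shapes chosen arbitrarily for illustration) confirms the two orderings agree:

import torch

phi = lambda x: torch.nn.functional.elu(x) + 1

n, d = 1024, 64
Q, K, V = (torch.randn(n, d) for _ in range(3))

slow = (phi(Q) @ phi(K).T) @ V   # builds an n x n matrix first: O(n^2)
fast = phi(Q) @ (phi(K).T @ V)   # builds a d x d state first: O(n)

print(torch.allclose(slow, fast, rtol=1e-4, atol=1e-4))  # True, up to floating-point error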

This formulation allows us to accumulate information incrementally, which is where the Delta Rule (a Hebbian-style update) comes in:

\[ W_t = W_{t-1} + \phi(k_t)v_t^T \]
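Strictly speaking, the purely additive update above is the classic Hebbian form; the delta rule proper adds an error-correcting term, so writing a new value for a key replaces the stale association rather than piling on top of it. Here is a minimal sketch of that variant; the per-token learning rate \(\beta\) and the helper name are illustrative assumptions, not part of any particular library:

import torch

phi = lambda x: torch.nn.functional.elu(x) + 1

def delta_rule_step(S, k_t, v_t, beta=1.0):
    # Retrieve what the state currently returns for k_t, then correct the error
    k_feat = phi(k_t)
    v_pred = S.T @ k_feat                                 # current prediction for key k_t
    return S + beta * torch.outer(k_feat, v_t - v_pred)   # error-correcting update

# Toy usage: write a 4-dimensional key/value pair into an empty state
S = torch.zeros(4, 4)
k, v = torch.randn(4), torch.randn(4)
S = delta_rule_step(S, k, v, beta=0.5)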


Implementing Linear Attention

Let’s walk through a minimal PyTorch implementation. Below is a clean, readable version.


import torch
import torch.nn as nn


class LinearAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def phi(self, x):
        # Feature map for linear attention (ELU + 1 keeps features positive)
        return torch.nn.functional.elu(x) + 1

    def forward(self, Q, K, V):
        # Apply the feature map to queries and keys only
        Q, K = self.phi(Q), self.phi(K)
        # Running KV state, computed incrementally
        S = torch.zeros(K.size(-1), V.size(-1), device=Q.device, dtype=Q.dtype)
        out = torch.empty(Q.size(0), V.size(-1), device=Q.device, dtype=Q.dtype)
        for t in range(K.size(0)):
            S = S + torch.outer(K[t], V[t])  # Hebbian update of the running state
            out[t] = Q[t] @ S                # causal output for token t
        return out
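A quick usage sketch for the class above, assuming single-head inputs of shape (sequence length, dim):

Q = torch.randn(10, 16)   # 10 tokens, dim=16
K = torch.randn(10, 16)
V = torch.randn(10, 16)

attn = LinearAttention(dim=16)
out = attn(Q, K, V)
print(out.shape)          # torch.Size([10, 16])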


Why the Delta Rule Helps

The Delta Rule enables incremental updates without storing the entire sequence. Instead of computing

\[ \phi(Q)\left(\phi(K)^T V\right) \]

over the whole sequence at once, we maintain a running state:

\[ S_t = S_{t-1} + \phi(k_t)v_t^T \]

This means attention can be computed in a single pass over the sequence, which is ideal for long documents and video. Here is a short demonstration of the running update:


import torch

phi = lambda x: torch.nn.functional.elu(x) + 1

K = torch.randn(5, 4)  # 5 tokens, dim=4
V = torch.randn(5, 4)
S = torch.zeros(4, 4)

for t in range(5):
    S += torch.outer(phi(K[t]), V[t])  # accumulate the Hebbian outer-product update

print("Final S matrix:\n", S)
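Because the sum of outer products \(\sum_t \phi(k_t)v_t^T\) is exactly \(\phi(K)^T V\), the running state can be checked against the one-shot batch computation:

batch_S = phi(K).T @ V             # same quantity computed in one matrix product
print(torch.allclose(S, batch_S))  # should print True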

Conclusion

Linear attention combined with the Delta Rule gives a powerful way to scale transformers to extremely long sequences. It is memory-efficient, fast, and mathematically elegant. Future posts in this series will show: