Per Layer Embeddings

Published

April 2026

Per-layer embeddings (used in Gemma4 models).

In short, the idea is to give every token a separate small embedding vector (a learned meaning representation) for each layer of the neural network. At the very first inference call (generating an answer on the device), all of these per-layer vectors are looked up at once, in a single step, which is one of the reasons the first model call takes longer than subsequent ones. These per-layer vectors are much smaller than the main hidden state (e.g. 256 dimensions vs. 1536 or 2560). At each layer, the token’s small per-layer vector is injected into the residual stream, so the layer is “reminded” what this token is about and can add the relevant “shades” of meaning.

Most importantly, this whole lookup table can be stored in flash memory rather than in RAM, because only the rows for the tokens being processed ever have to be read. RAM is needed for the computations over vectors and is therefore the bottleneck.
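To make the mechanism concrete, here is a minimal pure-Python sketch. The toy sizes and the names (`per_layer_table`, `up_proj`, `toy_layer`, `forward`) are hypothetical, and so is the injection step: it assumes a plain linear up-projection of the small vector, added to the residual stream. The real Gemma implementation may gate, normalize, or combine these vectors differently. What the sketch does show is the core idea: one lookup fetches all of a token’s per-layer vectors, and each layer adds its own small vector back into the residual stream.

```python
import random

# Hypothetical toy sizes (the text's real example is 256 vs. 1536 or 2560 dims).
VOCAB, NUM_LAYERS, D_MODEL, D_PLE = 10, 4, 8, 2

rng = random.Random(0)

def rand_matrix(rows, cols):
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]

# Conventional input embedding table: one D_MODEL vector per token.
embed = rand_matrix(VOCAB, D_MODEL)

# Per-layer embedding table: for each token, NUM_LAYERS small D_PLE vectors.
# This is the table that could live in flash instead of RAM.
per_layer_table = [rand_matrix(NUM_LAYERS, D_PLE) for _ in range(VOCAB)]

# Assumed per-layer up-projections from D_PLE back to D_MODEL (hypothetical).
up_proj = [rand_matrix(D_PLE, D_MODEL) for _ in range(NUM_LAYERS)]

def matvec(vec, mat):
    """Row vector (len = rows of mat) times mat (rows x cols) -> len cols."""
    return [sum(v * mat[i][j] for i, v in enumerate(vec)) for j in range(len(mat[0]))]

def toy_layer(h):
    # Stand-in for attention + MLP: any map from D_MODEL to D_MODEL.
    return [0.5 * x for x in h]

def forward(token_id):
    # One lookup fetches ALL of this token's per-layer vectors at once.
    ple_rows = per_layer_table[token_id]      # NUM_LAYERS x D_PLE
    h = list(embed[token_id])                 # residual stream starts here
    for layer in range(NUM_LAYERS):
        h = [a + b for a, b in zip(h, toy_layer(h))]
        # Inject this layer's small vector into the residual stream.
        injected = matvec(ple_rows[layer], up_proj[layer])
        h = [a + b for a, b in zip(h, injected)]
    return h

out = forward(3)
print(len(out))  # 8
```

Note that `forward` reads only `NUM_LAYERS * D_PLE` values from `per_layer_table` for a given token, which is why the table itself does not need to be resident in RAM.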

This contrasts with how most transformer-like models work: during inference, the incoming token is looked up in the embedding table only once, at the input layer, and that single embedding has to “drag” all of the token’s semantic context along with it through every layer of the network.
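A quick back-of-the-envelope calculation shows why it matters where the per-layer table is stored. The vocabulary size, layer count, and 16-bit precision below are assumptions chosen for illustration, not published figures for any Gemma model; only the 256 and 2560 dimensions come from the text.

```python
GIB = 1024 ** 3
KIB = 1024

def per_layer_table_bytes(vocab, layers, ple_dim, bytes_per_value):
    # Whole per-layer table: one ple_dim vector per (token, layer) pair.
    return vocab * layers * ple_dim * bytes_per_value

def per_layer_bytes_per_token(layers, ple_dim, bytes_per_value):
    # What a single token actually needs: one small vector per layer.
    return layers * ple_dim * bytes_per_value

def input_table_bytes(vocab, d_model, bytes_per_value):
    # Conventional single input embedding table (vocab x d_model).
    return vocab * d_model * bytes_per_value

# Assumed configuration (hypothetical); PLE_DIM and D_MODEL come from the text.
VOCAB, LAYERS, PLE_DIM, D_MODEL, BYTES = 262_144, 30, 256, 2560, 2

ple_total = per_layer_table_bytes(VOCAB, LAYERS, PLE_DIM, BYTES)
ple_token = per_layer_bytes_per_token(LAYERS, PLE_DIM, BYTES)
inp_total = input_table_bytes(VOCAB, D_MODEL, BYTES)

print(f"per-layer table: {ple_total / GIB:.2f} GiB")  # per-layer table: 3.75 GiB
print(f"read per token:  {ple_token / KIB:.1f} KiB")  # read per token:  15.0 KiB
print(f"input table:     {inp_total / GIB:.2f} GiB")  # input table:     1.25 GiB
```

Under these assumptions the per-layer table is larger than the conventional input table, yet one token touches only about 15 KiB of it, which is what makes keeping the table in flash and fetching rows on demand plausible.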

Thanks to Maarten Grootendorst for the visual explanation!