In a standard causal language model, each token's hidden state is pulled in two directions at once. It needs to represent the current token faithfully (what word is this, what role does it play in the sentence), but it also needs to encode the prediction for what comes next. These two responsibilities compete for the same vector, and you can measure the resulting "representational drift" using tools like the logit lens: intermediate layers gradually shift away from representing the current token and toward anticipating the next one.
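The logit-lens measurement described above can be sketched in a few lines: project each layer's hidden state through the unembedding matrix and see which tokens the layer already "votes" for. This is a minimal numpy sketch with random weights standing in for a trained model; the vocabulary size, hidden size, layer count, and token indices are all toy assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup: vocab of 10 tokens, hidden size 8, 4 layers.
# Weights are random; this only illustrates the logit-lens
# mechanics, not the behavior of a trained model.
V, D, L = 10, 8, 4
W_U = rng.normal(size=(D, V))  # unembedding matrix

def logit_lens(hidden_states, W_U):
    """Project each layer's hidden state through the unembedding
    matrix to get a per-layer distribution over the vocabulary."""
    logits = hidden_states @ W_U  # (L, V)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return probs / probs.sum(axis=-1, keepdims=True)

hidden = rng.normal(size=(L, D))  # one position, all layers
probs = logit_lens(hidden, W_U)

# Drift can be read off as how the probability mass assigned to
# the current token vs. the eventual next token shifts per layer.
current_tok, next_tok = 3, 7
for layer in range(L):
    print(layer, probs[layer, current_tok], probs[layer, next_tok])
```

In a real model the same projection is applied (after the final layer norm) to hidden states captured at each layer, and the drift shows up as later layers concentrating mass on the next token rather than the current one.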
My work explores whether this tension is a fixable design choice rather than an inherent property of autoregressive models. The approach is to introduce a minimal architectural change that gives the model a separate place to do its predictive work, so the main token representations can stay focused on encoding meaning. The key constraint is that the modification should be nearly parameter-neutral. If cleaner representations emerge, they should come from the structural separation itself, not from added capacity.
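To make the idea of "a separate place for predictive work" concrete, here is one hypothetical way such a separation could look; this is an illustration under my own assumptions, not the actual modification used in the work. A second "prediction" vector is carried alongside the residual stream, reads from it at every layer but never writes back, and only that vector feeds the unembedding. The added cost is a single shared read-out matrix, keeping the change near parameter-neutral.

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical sketch only: a residual stream that stays focused
# on the current token, plus a read-only "prediction stream" that
# accumulates the next-token work. All shapes and weights are toy.
D, V, L = 8, 10, 4
W_read = rng.normal(size=(D, D)) * 0.1  # one matrix, shared across layers
W_U = rng.normal(size=(D, V))           # unembedding

def layer(h):
    # Stand-in for an attention + MLP block: any update to the stream.
    return h + np.tanh(h)

h = rng.normal(size=D)  # residual stream (current-token meaning)
p = np.zeros(D)         # prediction stream

for _ in range(L):
    h = layer(h)
    p = p + W_read @ h  # prediction stream reads, never writes back

logits = W_U.T @ p      # the LM head sees only the prediction stream
print(logits.shape)     # shape (V,)
```

Because nothing flows from `p` back into `h`, the residual stream is never forced to trade current-token content for next-token content; whether this particular factorization matches the actual design is an open assumption of the sketch.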
Preliminary results are encouraging. The architectural change preserves prediction quality (measured by perplexity) while significantly improving how well intermediate representations encode properties of the current token, with the effect concentrating in deeper layers where representational drift is most severe. One interesting exception: representations of syntactic structure actually benefit from the standard setup, suggesting that some amount of cross-token information mixing is useful for encoding hierarchical properties like parse depth. That tradeoff is itself informative about what transformers use their hidden states for.
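The "how well intermediate representations encode properties of the current token" measurement is standardly done with linear probes. A minimal sketch, on synthetic data of my own construction: two fake "layers" encode a binary current-token property (think of a POS tag) strongly and weakly, mimicking drift, and a least-squares linear probe recovers it with correspondingly different accuracy.

```python
import numpy as np

rng = np.random.default_rng(1)

# Synthetic hidden states at two layers. A binary property of the
# current token is strongly linearly encoded at layer A and weakly
# at layer B, mimicking representational drift in deeper layers.
N, D = 500, 16
labels = rng.integers(0, 2, size=N)
signal = (labels * 2 - 1)[:, None]  # +1 / -1 direction

layer_a = signal * 2.0 + rng.normal(size=(N, D))  # strong encoding
layer_b = signal * 0.3 + rng.normal(size=(N, D))  # weak encoding

def probe_accuracy(X, y):
    """Fit a least-squares linear probe and report its accuracy."""
    Xb = np.hstack([X, np.ones((len(X), 1))])  # add a bias column
    w, *_ = np.linalg.lstsq(Xb, y * 2 - 1, rcond=None)
    pred = (Xb @ w > 0).astype(int)
    return (pred == y).mean()

acc_a = probe_accuracy(layer_a, labels)
acc_b = probe_accuracy(layer_b, labels)
print(f"layer A probe acc: {acc_a:.2f}, layer B: {acc_b:.2f}")
```

On real models the same probe is trained on held-out data per layer, and the layer-by-layer accuracy profile is what the architectural comparison above is measuring.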
I'm also involved in a funded research collaboration I can't discuss publicly yet.
Causal Language Models · Mechanistic Interpretability · Transformer Architecture · Representation Learning