Thoughts on my first semester of ML Interpretability Research

This semester I started doing research at MIT with Eric Michaud in Max Tegmark’s lab. It was an interesting experience: I had a lot of independence to pursue my own experiments while learning from my mentor.

Context

I’m interested in research that aims to understand how neural networks manage to learn complex ideas and patterns. Looking at ChatGPT, DALL-E, or projects like AlphaFold, there’s a clear pattern of models developing reasoning abilities and accomplishing extraordinary things, but how do they do it? Surprisingly, although ML research keeps producing these technological advances, there are still many open questions about how complex behaviors are learned.

This is intrinsically interesting, and it also has important implications as models are used in more and more critical contexts.

Background

This research lies at the crossroads of the growing field of mechanistic interpretability and the theory of deep learning. In A Mechanistic Interpretability Analysis of Grokking, Neel Nanda studies how a transformer learns the following task: for a fixed integer n, learn a function f that takes two integers (a, b) and outputs (a + b) mod n. At first the model memorizes the training set; then, at some point, it generalizes, and the test loss drops sharply. This kind of sudden transition, a rapid decrease in loss over a short stretch of training, is called a phase change.
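To make the task concrete, here is a minimal sketch of how such a dataset could be built, assuming all n² input pairs are enumerated and a random fraction is kept for training. The helper name `modular_addition_dataset`, the 30% training fraction, and n = 113 are illustrative assumptions, not necessarily the exact setup used in the analysis. Because the test pairs never appear in training, a model that only memorizes gets them wrong; low test loss is evidence that it has learned the actual rule.

```python
import random

def modular_addition_dataset(n, train_frac=0.3, seed=0):
    """Hypothetical helper: split every (a, b) pair with 0 <= a, b < n into
    train/test sets, labelling each pair with (a + b) mod n."""
    pairs = [(a, b) for a in range(n) for b in range(n)]
    random.Random(seed).shuffle(pairs)  # deterministic shuffle for reproducibility
    split = int(train_frac * len(pairs))

    def label(ps):
        return [((a, b), (a + b) % n) for a, b in ps]

    return label(pairs[:split]), label(pairs[split:])

train, test = modular_addition_dataset(113)
print(len(train), len(test))  # 3830 8939
```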

The phase change is related to the idea of a circuit: a part of the model that implements an algorithm across its layers, and that you could in principle read directly from the weights. Neel actually does this for the modular addition model. The idea is that once the circuit has formed, an easy way to decrease loss is simply to assign more weight to the circuit and shrink the rest of the network, which produces a phase change.
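For a sense of what an algorithm "read from the weights" can look like: Nanda reports that the grokked transformer computes (a + b) mod n with trigonometry, representing inputs as rotations at a few key frequencies w_k = 2πk/n, combining them with angle-addition identities, and scoring each candidate answer c by cos(w_k(a + b - c)), which peaks when c = (a + b) mod n. The sketch below hard-codes that idea in NumPy rather than extracting it from a trained model; `fourier_logits` is a hypothetical name, the frequencies are arbitrary picks, and n is assumed prime so that the correct answer is the unique maximum.

```python
import numpy as np

def fourier_logits(a, b, n, freqs):
    """Score every candidate answer c in [0, n) as sum_k cos(w_k * (a + b - c)),
    built only from cos/sin of w_k*a, w_k*b and w_k*c."""
    c = np.arange(n)
    w = 2 * np.pi * np.asarray(freqs) / n  # key frequencies, shape (K,)
    # Angle-addition identities: cos/sin of w*(a+b) from cos/sin of w*a and w*b.
    cos_ab = np.cos(w * a) * np.cos(w * b) - np.sin(w * a) * np.sin(w * b)
    sin_ab = np.sin(w * a) * np.cos(w * b) + np.cos(w * a) * np.sin(w * b)
    # cos(w(a+b-c)) = cos(w(a+b))cos(wc) + sin(w(a+b))sin(wc), shape (K, n).
    wc = np.outer(w, c)
    per_freq = cos_ab[:, None] * np.cos(wc) + sin_ab[:, None] * np.sin(wc)
    return per_freq.sum(axis=0)  # one logit per candidate c

logits = fourier_logits(50, 90, n=113, freqs=[3, 17, 40])
print(int(np.argmax(logits)))  # 27, since (50 + 90) mod 113 = 27
```

Each frequency on its own already peaks at the right answer; summing several of them makes the peak sharper relative to the other candidates.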

Similarly, in In-context Learning and Induction Heads, researchers find a point in language model training where the model undergoes a phase change. It coincides with the emergence of induction heads, a circuit that recognizes patterns of the form [string A] [string B]: when it later sees [string A] again in the prompt, it assigns higher probability to the next token being [string B].

This circuit appears to be especially important at this stage of training, and it’s conjectured that its emergence causes the phase change, since induction heads support a wide range of behaviors that are important for the model to learn.
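The induction-head pattern can be caricatured as a lookup rule over the token sequence. The sketch below is a deterministic, non-neural stand-in written only for illustration (`induction_prediction` is a hypothetical name): a real induction head is implemented by attention heads and only shifts probability toward [string B], rather than copying it outright.

```python
def induction_prediction(tokens):
    """Idealized induction rule on a non-empty token list: if the current
    token [A] occurred earlier, predict the token [B] that followed its most
    recent earlier occurrence. Returns None when [A] has not been seen before."""
    current = tokens[-1]
    # Walk backwards over earlier positions, excluding the current one.
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]
    return None

prompt = ["Mr", "Dursley", "was", "proud", "of", "Mr"]
print(induction_prediction(prompt))  # Dursley
```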

These results paint a particular picture of model training. How much of learning in general is driven by phase changes and circuits? Can we view a model as an assortment of circuits learned at the same time, each undergoing its own phase change at its own scale, so that our smooth loss curves are the result of composing all of these phase changes? Are phase changes everywhere?

Research

We want to connect the macro-scale patterns of loss curves and learning dynamics to the theory of how models learn and how circuits form. This would help mechanistic interpretability, and it would also let us build better theories of how models learn several things at once and how the emergence of capabilities works.

What we did: