Over the last half a year I have become more confident that recurrent networks can be more interpretable. At least, the algorithms they implement could be more intuitive to us; loops are what we're used to. Transformers, in contrast, have no loops unless you do explicit, verbalized chain-of-thought. This made me believe that much of their capability is hardcoded and will generalize poorly, not to mention that they encode world knowledge redundantly.
Lately I have become interested in these “modern” recurrent models (aka state-space models, SSMs). In contrast to the RNNs/GRUs/LSTMs we knew, they have no non-linearity on the state transition.
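To make that concrete, here is a toy numpy sketch of the difference. This is only meant to show where the non-linearity sits; the parameterization, gating and discretization details of real SSMs are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4                          # toy state size
x = rng.normal(size=(10, d))   # a toy input sequence

# Classic RNN: the non-linearity wraps the whole state update.
W, U = rng.normal(size=(d, d)) * 0.1, rng.normal(size=(d, d)) * 0.1
h = np.zeros(d)
for x_t in x:
    h = np.tanh(W @ h + U @ x_t)

# SSM-style layer: the state update itself is linear (here diagonal);
# non-linearities act only on the inputs/outputs of the recurrence.
a, b = rng.uniform(0.9, 0.99, size=d), rng.normal(size=d) * 0.1
s = np.zeros(d)
for x_t in x:
    s = a * s + b * x_t        # purely linear recurrence in the state
```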
Interestingly, there has been almost no work on the interpretability of SSM models (the only one I know of is ROMBA, ROME for Mamba). After some time I realized why: most mechanistic interpretability researchers rely on the standard Python libraries, and those do not support SSM models. Moreover, these models have a dual nature: Mamba, for example, runs recurrently during inference but takes a parallel form in the training/teacher-forced setup. This makes it extremely difficult to capture the states and intercept them.
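To illustrate what “capturing the states” even looks like in practice, here is a minimal PyTorch sketch using forward hooks. It is only a sketch under assumptions: the `"recurrent"` name filter is hypothetical, and whether a hooked module output is a useful proxy for the layer's state depends entirely on the concrete implementation.

```python
import torch

captured = {}  # module name -> list of captured tensors

def make_hook(name):
    # Store a detached copy of whatever the module outputs at each call.
    def hook(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        captured.setdefault(name, []).append(out.detach().cpu())
    return hook

def attach_hooks(model, name_filter="recurrent"):
    # NOTE: "recurrent" is a hypothetical substring; real module names depend
    # on the implementation (inspect model.named_modules() first).
    handles = []
    for name, module in model.named_modules():
        if name_filter in name.lower():
            handles.append(module.register_forward_hook(make_hook(name)))
    return handles
```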
Anyway, because I believe more in the interpretability of recurrent models, and because there is so much left to do, I decided to focus my study on them; maybe someday we can apply the findings to transformers as well.
SSM models are gaining momentum, but at this moment almost none of them perform nearly as well as SoTA transformer models. In my mech interp studies I'm mostly interested in open-ended language models, and I don't believe the internals of models trained on synthetic data are even remotely close to those of the pretrained language models we use now. So my goal is to use only SoTA pretrained language models, but unfortunately there was no obvious choice.
Hybrid models (those interleaving attention layers with recurrent layers), however, perform even better than transformers given the same budget. But studying these models can be tricky: any property I look for could be mediated by both the attention and the recurrent components, which complicates the analysis.
To make sure attention isn't doing the heavy lifting in the study that follows, we design the task so that attention alone cannot solve it. Hybrid models are supposed to be (sub)quadratic, right? If you keep full attention next to the recurrent layers, the complexity stays quadratic.
To avoid that, most of these models use sliding-window attention, and we can use this fact in our favour: we consider tasks where the essential information lies far outside the attention window, so the recurrent states are the only way to carry the data.
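As a sanity check, one can verify that the target record really sits further back than the attention window. A rough sketch, assuming a Hugging Face tokenizer; the `record_outside_window` helper is my own illustration, not part of any library:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/recurrentgemma-9b-it")
WINDOW = 2048  # sliding window size of the attention layers

def record_outside_window(prompt: str, record: str) -> bool:
    # True if the record ends more than WINDOW tokens before the prompt's end,
    # i.e. attention at the answer position cannot see it directly.
    ids = tokenizer(prompt)["input_ids"]
    record_ids = tokenizer(record, add_special_tokens=False)["input_ids"]
    for start in range(len(ids) - len(record_ids) + 1):
        if ids[start:start + len(record_ids)] == record_ids:
            distance_to_end = len(ids) - (start + len(record_ids))
            return distance_to_end > WINDOW
    raise ValueError("record not found in prompt")
```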
We generate a synthetic dataset of phone-lookup problems, query-first! (A small generation sketch follows the example below.)
# Task
Remember Isaac Newton's phone number
# Phonebook
- Darwyn Jacobs: 524544318
- Concepcion Abshire: 898812961
- Bell Dicki: 125095571
- Kay Mertz: 139010740
... <- hundreds of records here
- Moesha Welch: 456434872
**- Isaac Newton: 356306997**
- Eston Emard: 554920799
- Cortney Buckridge: 492057075
... <- hundreds of records here
- Sing Ferry: 389031246
- Harriett Brown: 068654208
# Remembered Contacts
- Isaac Newton: **356306997**
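Here is a rough sketch of how such a prompt could be assembled; the name lists, record count, and the `make_prompt` helper are illustrative placeholders rather than the exact generation code:

```python
import random

FIRST = ["Isaac", "Darwyn", "Concepcion", "Bell", "Kay", "Moesha"]
LAST = ["Newton", "Jacobs", "Abshire", "Dicki", "Mertz", "Welch"]

def make_prompt(target_name: str, target_number: str, n_records: int = 500) -> str:
    # Fill the phonebook with random distractor records.
    records = []
    while len(records) < n_records:
        name = f"{random.choice(FIRST)} {random.choice(LAST)}"
        if name != target_name:  # avoid clashing with the target entity
            records.append((name, f"{random.randint(0, 999999999):09d}"))
    # Insert the target record somewhere inside the phonebook.
    records.insert(random.randrange(len(records)), (target_name, target_number))
    book = "\n".join(f"- {name}: {number}" for name, number in records)
    return ("# Task\n"
            f"Remember {target_name}'s phone number\n"
            "# Phonebook\n"
            f"{book}\n"
            "# Remembered Contacts\n"
            f"- {target_name}:")
```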
This should be a very easy task for transformers: a single induction head plus a copy head will do the job. The question is whether recurrent models can encode the information and retain it over a large number of steps (remember, we made the context long enough that the record is outside the sliding window).
Well, it turns out recurrentgemma-9b-it (Griffin architecture, hybrid, sliding window size of 2048) can handle the task pretty well:
| Model | Exact Match |
|---|---|
| google/recurrentgemma-9b-it | 86.66% |
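For reference, here is a minimal sketch of how exact match could be measured with Hugging Face transformers; model loading details (dtype, device placement, authentication) are simplified, and `prompts`/`answers` are assumed to come from the dataset generation above.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/recurrentgemma-9b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

@torch.no_grad()
def exact_match(prompts, answers, max_new_tokens=12):
    hits = 0
    for prompt, answer in zip(prompts, answers):
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # Decode only the newly generated tokens and check for the phone number.
        completion = tokenizer.decode(
            out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )
        hits += answer in completion
    return hits / len(prompts)
```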
Now our task is to understand how the information is captured, encoded, and retained in the recurrent states.
The dataset uses 10 unique first names and 10 unique last names for the entity we're looking for. We then make sure the 100 generated random phone numbers cover all 10 possible first digits uniformly. For each such (name, number) row, we generate 10 random contexts.
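A sketch of that sampling scheme, with placeholder name lists (the real first/last names are not shown here) and random seeds standing in for the 10 contexts per row:

```python
import itertools
import random

FIRST_NAMES = [f"First{i}" for i in range(10)]  # placeholders for real names
LAST_NAMES = [f"Last{i}" for i in range(10)]

def build_rows(seed: int = 0):
    rng = random.Random(seed)
    # 10 x 10 = 100 unique target names.
    names = [f"{f} {l}" for f, l in itertools.product(FIRST_NAMES, LAST_NAMES)]
    rng.shuffle(names)
    rows = []
    for i, name in enumerate(names):
        first_digit = i % 10                          # uniform over first digits
        number = f"{first_digit}{rng.randrange(10**8):08d}"
        rows.append({
            "name": name,
            "number": number,
            # Seeds for the 10 random phonebook contexts of this row.
            "contexts": [rng.getrandbits(64) for _ in range(10)],
        })
    return rows
```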