From 7c1d1419e4e0855da0ea67f8e59dc6099d18b311 Mon Sep 17 00:00:00 2001 From: degenfabian Date: Mon, 18 Aug 2025 19:19:48 +0200 Subject: [PATCH 1/7] updated loading in exploratory analysis demo to use transformer bridge --- demos/Exploratory_Analysis_Demo.ipynb | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/demos/Exploratory_Analysis_Demo.ipynb b/demos/Exploratory_Analysis_Demo.ipynb index d7e29f11d..b12304844 100644 --- a/demos/Exploratory_Analysis_Demo.ipynb +++ b/demos/Exploratory_Analysis_Demo.ipynb @@ -100,7 +100,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -118,7 +118,8 @@ "from jaxtyping import Float\n", "\n", "import transformer_lens.utils as utils\n", - "from transformer_lens import ActivationCache, HookedTransformer" + "from transformer_lens import ActivationCache\n", + "from transformer_lens.model_bridge import TransformerBridge" ] }, { @@ -245,12 +246,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The first step is to load in our model, GPT-2 Small, a 12 layer and 80M parameter transformer with `HookedTransformer.from_pretrained`. The various flags are simplifications that preserve the model's output but simplify its internals." + "The first step is to load in our model, GPT-2 Small, a 12 layer and 80M parameter transformer with `TransformerBridge.boot_transformers`. The various flags are simplifications that preserve the model's output but simplify its internals." 
] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -270,13 +271,14 @@ ], "source": [ "# NBVAL_IGNORE_OUTPUT\n", - "model = HookedTransformer.from_pretrained(\n", - " \"gpt2-small\",\n", + "model = TransformerBridge.boot_transformers(\n", + " \"gpt2\",\n", " center_unembed=True,\n", " center_writing_weights=True,\n", " fold_ln=True,\n", " refactor_factored_attn_matrices=True,\n", ")\n", + "model.enable_compatibility_mode()\n", "\n", "# Get the default device used\n", "device: torch.device = utils.get_device()" @@ -372,7 +374,7 @@ "\n", "We want models that can take in arbitrary text, but models need to have a fixed vocabulary. So the solution is to define a vocabulary of **tokens** and to deterministically break up arbitrary text into tokens. Tokens are, essentially, subwords, and are determined by finding the most frequent substrings - this means that tokens vary a lot in length and frequency! \n", "\n", - "Tokens are a *massive* headache and are one of the most annoying things about reverse engineering language models... Different names will be different numbers of tokens, different prompts will have the relevant tokens at different positions, different prompts will have different total numbers of tokens, etc. Language models often devote significant amounts of parameters in early layers to convert inputs from tokens to a more sensible internal format (and do the reverse in later layers). You really, really want to avoid needing to think about tokenization wherever possible when doing exploratory analysis (though, of course, it's relevant later when trying to flesh out your analysis and make it rigorous!). HookedTransformer comes with several helper methods to deal with tokens: `to_tokens, to_string, to_str_tokens, to_single_token, get_token_position`\n", + "Tokens are a *massive* headache and are one of the most annoying things about reverse engineering language models... 
Different names will be different numbers of tokens, different prompts will have the relevant tokens at different positions, different prompts will have different total numbers of tokens, etc. Language models often devote significant amounts of parameters in early layers to convert inputs from tokens to a more sensible internal format (and do the reverse in later layers). You really, really want to avoid needing to think about tokenization wherever possible when doing exploratory analysis (though, of course, it's relevant later when trying to flesh out your analysis and make it rigorous!). TransformerBridge comes with several helper methods to deal with tokens: `to_tokens, to_string, to_str_tokens, to_single_token, get_token_position`\n", "\n", "**Exercise:** I recommend using `model.to_str_tokens` to explore how the model tokenizes different strings. In particular, try adding or removing spaces at the start, or changing capitalization - these change tokenization!" ] From 1f3a5bcf31926712f956fca41c1a48a88f2e70ce Mon Sep 17 00:00:00 2001 From: degenfabian Date: Mon, 18 Aug 2025 19:19:48 +0200 Subject: [PATCH 2/7] updated loading in exploratory analysis demo to use transformer bridge --- demos/Exploratory_Analysis_Demo.ipynb | 40444 ++++++++++++------------ 1 file changed, 20223 insertions(+), 20221 deletions(-) diff --git a/demos/Exploratory_Analysis_Demo.ipynb b/demos/Exploratory_Analysis_Demo.ipynb index 48caba76c..b12304844 100644 --- a/demos/Exploratory_Analysis_Demo.ipynb +++ b/demos/Exploratory_Analysis_Demo.ipynb @@ -1,20353 +1,20355 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Exploratory Analysis Demo\n", - "\n", - "This 
notebook demonstrates how to use the\n", - "[TransformerLens](https://github.com/TransformerLensOrg/TransformerLens/) library to perform exploratory\n", - "analysis. The notebook tries to replicate the analysis of the Indirect Object Identification circuit\n", - "in the [Interpretability in the Wild](https://arxiv.org/abs/2211.00593) paper." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Tips for Reading This\n", - "\n", - "* If running in Google Colab, go to Runtime > Change Runtime Type and select GPU as the hardware\n", - "accelerator.\n", - "* Look up unfamiliar terms in [the mech interp explainer](https://neelnanda.io/glossary)\n", - "* You can run all this code for yourself\n", - "* The graphs are interactive\n", - "* Use the table of contents pane in the sidebar to navigate (in Colab) or VSCode's \"Outline\" in the\n", - " explorer tab.\n", - "* Collapse irrelevant sections with the dropdown arrows\n", - "* Search the page using the search in the sidebar (with Colab) not CTRL+F" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Environment Setup (ignore)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**You can ignore this part:** It's just for use internally to setup the tutorial in different\n", - "environments. You can delete this section if using in your own repo." 
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# Detect if we're running in Google Colab\n", - "try:\n", - " import google.colab\n", - " IN_COLAB = True\n", - " print(\"Running as a Colab notebook\")\n", - "except:\n", - " IN_COLAB = False\n", - "\n", - "# Install if in Colab\n", - "if IN_COLAB:\n", - " %pip install transformer_lens\n", - " %pip install circuitsvis\n", - " # Install a faster Node version\n", - " !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs # noqa\n", - "\n", - "# Hot reload in development mode & not running on the CD\n", - "if not IN_COLAB:\n", - " from IPython import get_ipython\n", - " ip = get_ipython()\n", - " if 'autoreload' not in ip.extension_manager.loaded:\n", - " ip.extension_manager.load_extension('autoreload')\n", - " %autoreload 2\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Imports" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from functools import partial\n", - "from typing import List, Optional, Union\n", - "\n", - "import einops\n", - "import numpy as np\n", - "import plotly.express as px\n", - "import plotly.io as pio\n", - "import torch\n", - "from circuitsvis.attention import attention_heads\n", - "from fancy_einsum import einsum\n", - "from IPython.display import HTML, IFrame\n", - "from jaxtyping import Float\n", - "\n", - "import transformer_lens.utils as utils\n", - "from transformer_lens import ActivationCache, HookedTransformer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### PyTorch Setup" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training." 
- ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Disabled automatic differentiation\n" - ] - } - ], - "source": [ - "torch.set_grad_enabled(False)\n", - "print(\"Disabled automatic differentiation\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Plotting Helper Functions (ignore)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Some plotting helper functions are included here (for simplicity)." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "def imshow(tensor, **kwargs):\n", - " px.imshow(\n", - " utils.to_numpy(tensor),\n", - " color_continuous_midpoint=0.0,\n", - " color_continuous_scale=\"RdBu\",\n", - " **kwargs,\n", - " ).show()\n", - "\n", - "\n", - "def line(tensor, **kwargs):\n", - " px.line(\n", - " y=utils.to_numpy(tensor),\n", - " **kwargs,\n", - " ).show()\n", - "\n", - "\n", - "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", **kwargs):\n", - " x = utils.to_numpy(x)\n", - " y = utils.to_numpy(y)\n", - " px.scatter(\n", - " y=y,\n", - " x=x,\n", - " labels={\"x\": xaxis, \"y\": yaxis, \"color\": caxis},\n", - " **kwargs,\n", - " ).show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Introduction\n", - "\n", - "This is a demo notebook for [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens), a library for mechanistic interpretability of GPT-2 style transformer language models. A core design principle of the library is to enable exploratory analysis - one of the most fun parts of mechanistic interpretability compared to normal ML is the extremely short feedback loops! 
The point of this library is to keep the gap between having an experiment idea and seeing the results as small as possible, to make it easy for **research to feel like play** and to enter a flow state.\n", - "\n", - "The goal of this notebook is to demonstrate what exploratory analysis looks like in practice with the library. I use my standard toolkit of basic mechanistic interpretability techniques to try interpreting a real circuit in GPT-2 small. Check out [the main demo](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Main_Demo.ipynb) for an introduction to the library and how to use it. \n", - "\n", - "Stylistically, I will go fairly slowly and explain in detail what I'm doing and why, aiming to help convey how to do this kind of research yourself! But the code itself is written to be simple and generic, and easy to copy and paste into your own projects for different tasks and models.\n", - "\n", - "Details tags contain asides, flavour + interpretability intuitions. These are more in the weeds and you don't need to read them or understand them, but they're helpful if you want to learn how to do mechanistic interpretability yourself! I star the ones I think are most important.\n", - "
(*) Example details tagExample aside!
" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Indirect Object Identification\n", - "\n", - "The first step when trying to reverse engineer a circuit in a model is to identify *what* capability\n", - "I want to reverse engineer. Indirect Object Identification is a task studied in Redwood Research's\n", - "excellent [Interpretability in the Wild](https://arxiv.org/abs/2211.00593) paper (see [my interview\n", - "with the authors](https://www.youtube.com/watch?v=gzwj0jWbvbo) or [Kevin Wang's Twitter\n", - "thread](https://threadreaderapp.com/thread/1587601532639494146.html) for an overview). The task is\n", - "to complete sentences like \"After John and Mary went to the shops, John gave a bottle of milk to\"\n", - "with \" Mary\" rather than \" John\". \n", - "\n", - "In the paper they rigorously reverse engineer a 26 head circuit, with 7 separate categories of heads\n", - "used to perform this capability. Their rigorous methods are fairly involved, so in this notebook,\n", - "I'm going to skimp on rigour and instead try to speed run the process of finding suggestive evidence\n", - "for this circuit!\n", - "\n", - "The circuit they found roughly breaks down into three parts:\n", - "1. Identify what names are in the sentence\n", - "2. Identify which names are duplicated\n", - "3. Predict the name that is *not* duplicated" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The first step is to load in our model, GPT-2 Small, a 12 layer and 80M parameter transformer with `HookedTransformer.from_pretrained`. The various flags are simplifications that preserve the model's output but simplify its internals." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using pad_token, but it is not set yet.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded pretrained model gpt2-small into HookedTransformer\n" - ] - } - ], - "source": [ - "# NBVAL_IGNORE_OUTPUT\n", - "model = HookedTransformer.from_pretrained(\n", - " \"gpt2-small\",\n", - " center_unembed=True,\n", - " center_writing_weights=True,\n", - " fold_ln=True,\n", - " refactor_factored_attn_matrices=True,\n", - ")\n", - "\n", - "# Get the default device used\n", - "device: torch.device = utils.get_device()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The next step is to verify that the model can *actually* do the task! Here we use `utils.test_prompt`, and see that the model is significantly better at predicting Mary than John! \n", - "\n", - "
Asides:\n", - "\n", - "Note: If we were being careful, we'd want to run the model on a range of prompts and find the average performance\n", - "\n", - "`prepend_bos` is a flag to add a BOS (beginning of sequence) to the start of the prompt. GPT-2 was not trained with this, but I find that it often makes model behaviour more stable, as the first token is treated weirdly.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']\n", - "Tokenized answer: [' Mary']\n" - ] - }, - { - "data": { - "text/html": [ - "
Performance on answer token:\n",
-       "Rank: 0        Logit: 18.09 Prob: 70.07% Token: | Mary|\n",
-       "
\n" - ], - "text/plain": [ - "Performance on answer token:\n", - "\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m18.09\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m70.07\u001b[0m\u001b[1m% Token: | Mary|\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|\n", - "Top 1th token. Logit: 15.38 Prob: 4.67% Token: | the|\n", - "Top 2th token. Logit: 15.35 Prob: 4.54% Token: | John|\n", - "Top 3th token. Logit: 15.25 Prob: 4.11% Token: | them|\n", - "Top 4th token. Logit: 14.84 Prob: 2.73% Token: | his|\n", - "Top 5th token. Logit: 14.06 Prob: 1.24% Token: | her|\n", - "Top 6th token. Logit: 13.54 Prob: 0.74% Token: | a|\n", - "Top 7th token. Logit: 13.52 Prob: 0.73% Token: | their|\n", - "Top 8th token. Logit: 13.13 Prob: 0.49% Token: | Jesus|\n", - "Top 9th token. Logit: 12.97 Prob: 0.42% Token: | him|\n" - ] - }, - { - "data": { - "text/html": [ - "
Ranks of the answer tokens: [(' Mary', 0)]\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Mary'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "example_prompt = \"After John and Mary went to the store, John gave a bottle of milk to\"\n", - "example_answer = \" Mary\"\n", - "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We now want to find a reference prompt to run the model on. Even though our ultimate goal is to reverse engineer how this behaviour is done in general, often the best way to start out in mechanistic interpretability is by zooming in on a concrete example and understanding it in detail, and only *then* zooming out and verifying that our analysis generalises.\n", - "\n", - "We'll run the model on 4 instances of this task, each prompt given twice - one with the first name as the indirect object, one with the second name. To make our lives easier, we'll carefully choose prompts with single token names and the corresponding names in the same token positions.\n", - "\n", - "
(*) Aside on tokenization\n", - "\n", - "We want models that can take in arbitrary text, but models need to have a fixed vocabulary. So the solution is to define a vocabulary of **tokens** and to deterministically break up arbitrary text into tokens. Tokens are, essentially, subwords, and are determined by finding the most frequent substrings - this means that tokens vary a lot in length and frequency! \n", - "\n", - "Tokens are a *massive* headache and are one of the most annoying things about reverse engineering language models... Different names will be different numbers of tokens, different prompts will have the relevant tokens at different positions, different prompts will have different total numbers of tokens, etc. Language models often devote significant amounts of parameters in early layers to convert inputs from tokens to a more sensible internal format (and do the reverse in later layers). You really, really want to avoid needing to think about tokenization wherever possible when doing exploratory analysis (though, of course, it's relevant later when trying to flesh out your analysis and make it rigorous!). HookedTransformer comes with several helper methods to deal with tokens: `to_tokens, to_string, to_str_tokens, to_single_token, get_token_position`\n", - "\n", - "**Exercise:** I recommend using `model.to_str_tokens` to explore how the model tokenizes different strings. In particular, try adding or removing spaces at the start, or changing capitalization - these change tokenization!
" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']\n", - "[(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n" - ] - } - ], - "source": [ - "prompt_format = [\n", - " \"When John and Mary went to the shops,{} gave the bag to\",\n", - " \"When Tom and James went to the park,{} gave the ball to\",\n", - " \"When Dan and Sid went to the shops,{} gave an apple to\",\n", - " \"After Martin and Amy went to the park,{} gave a drink to\",\n", - "]\n", - "names = [\n", - " (\" Mary\", \" John\"),\n", - " (\" Tom\", \" James\"),\n", - " (\" Dan\", \" Sid\"),\n", - " (\" Martin\", \" Amy\"),\n", - "]\n", - "# List of prompts\n", - "prompts = []\n", - "# List of answers, in the format (correct, incorrect)\n", - "answers = []\n", - "# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)\n", - "answer_tokens = []\n", - "for i in range(len(prompt_format)):\n", - " for j in range(2):\n", - " answers.append((names[i][j], names[i][1 - j]))\n", - " answer_tokens.append(\n", - " (\n", - " model.to_single_token(answers[-1][0]),\n", - " model.to_single_token(answers[-1][1]),\n", - " )\n", - " )\n", - " # Insert the *incorrect* answer to the prompt, making the correct answer the indirect object.\n", - " 
prompts.append(prompt_format[i].format(answers[-1][1]))\n", - "answer_tokens = torch.tensor(answer_tokens).to(device)\n", - "print(prompts)\n", - "print(answers)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Gotcha**: It's important that all of your prompts have the same number of tokens. If they're different lengths, then the position of the \"final\" logit where you can check logit difference will differ between prompts, and this will break the below code. The easiest solution is just to choose your prompts carefully to have the same number of tokens (you can eg add filler words like The, or newlines to start).\n", - "\n", - "There's a range of other ways of solving this, eg you can index more intelligently to get the final logit. A better way is to just use left padding by setting `model.tokenizer.padding_side = 'left'` before tokenizing the inputs and running the model; this way, you can use something like `logits[:, -1, :]` to easily access the final token outputs without complicated indexing. TransformerLens checks the value of `padding_side` of the tokenizer internally, and if the flag is set to be `'left'`, it adjusts the calculation of absolute position embedding and causal masking accordingly.\n", - "\n", - "In this demo, though, we stick to using the prompts of the same number of tokens because we want to show some visualisations aggregated along the batch dimension later in the demo." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' John', ' gave', ' the', ' bag', ' to']\n", - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' Mary', ' gave', ' the', ' bag', ' to']\n", - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'When', ' Tom', ' and', ' James', ' went', ' to', ' the', ' park', ',', ' James', ' gave', ' the', ' ball', ' to']\n", - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'When', ' Tom', ' and', ' James', ' went', ' to', ' the', ' park', ',', ' Tom', ' gave', ' the', ' ball', ' to']\n", - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'When', ' Dan', ' and', ' Sid', ' went', ' to', ' the', ' shops', ',', ' Sid', ' gave', ' an', ' apple', ' to']\n", - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'When', ' Dan', ' and', ' Sid', ' went', ' to', ' the', ' shops', ',', ' Dan', ' gave', ' an', ' apple', ' to']\n", - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'After', ' Martin', ' and', ' Amy', ' went', ' to', ' the', ' park', ',', ' Amy', ' gave', ' a', ' drink', ' to']\n", - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'After', ' Martin', ' and', ' Amy', ' went', ' to', ' the', ' park', ',', ' Martin', ' gave', ' a', ' drink', ' to']\n" - ] - } - ], - "source": [ - "for prompt in prompts:\n", - " str_tokens = model.to_str_tokens(prompt)\n", - " print(\"Prompt length:\", len(str_tokens))\n", - " print(\"Prompt as tokens:\", str_tokens)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We now run the model on these prompts and use `run_with_cache` to get both the logits and a cache of all 
internal activations for later analysis" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "tokens = model.to_tokens(prompts, prepend_bos=True)\n", - "\n", - "# Run the model and cache all activations\n", - "original_logits, cache = model.run_with_cache(tokens)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We'll later be evaluating how model performance differs upon performing various interventions, so it's useful to have a metric to measure model performance. Our metric here will be the **logit difference**, the difference in logit between the indirect object's name and the subject's name (eg, `logit(Mary)-logit(John)`). " - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Per prompt logit difference: tensor([3.3370, 3.2020, 2.7090, 3.7970, 1.7200, 5.2810, 2.6010, 5.7670])\n", - "Average logit difference: 3.552\n" - ] - } - ], - "source": [ - "def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):\n", - " # Only the final logits are relevant for the answer\n", - " final_logits = logits[:, -1, :]\n", - " answer_logits = final_logits.gather(dim=-1, index=answer_tokens)\n", - " answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]\n", - " if per_prompt:\n", - " return answer_logit_diff\n", - " else:\n", - " return answer_logit_diff.mean()\n", - "\n", - "\n", - "print(\n", - " \"Per prompt logit difference:\",\n", - " logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)\n", - " .detach()\n", - " .cpu()\n", - " .round(decimals=3),\n", - ")\n", - "original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)\n", - "print(\n", - " \"Average logit difference:\",\n", - " round(logits_to_ave_logit_diff(original_logits, answer_tokens).item(), 3),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": 
{}, - "source": [ - "We see that the average logit difference is 3.5 - for context, this represents putting an $e^{3.5}\\approx 33\\times$ higher probability on the correct answer. " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Brainstorm What's Actually Going On (Optional)\n", - "\n", - "Before diving into running experiments, it's often useful to spend some time actually reasoning about how the behaviour in question could be implemented in the transformer. **This is optional, and you'll likely get the most out of engaging with this section if you have a decent understanding already of what a transformer is and how it works!**\n", - "\n", - "You don't have to do this and forming hypotheses after exploration is also reasonable, but I think it's often easier to explore and interpret results with some grounding in what you might find. In this particular case, I'm cheating somewhat, since I know the answer, but I'm trying to simulate the process of reasoning about it!\n", - "\n", - "Note that often your hypothesis will be wrong in some ways and often be completely off. We're doing science here, and the goal is to understand how the model *actually* works, and to form true beliefs! There are two separate traps here at two extremes that it's worth tracking:\n", - "* Confusion: Having no hypotheses at all, getting a lot of data and not knowing what to do with it, and just floundering around\n", - "* Dogmatism: Being overconfident in an incorrect hypothesis and being unwilling to let go of it when reality contradicts you, or flinching away from running the experiments that might disconfirm it.\n", - "\n", - "**Exercise:** Spend some time thinking through how you might imagine this behaviour being implemented in a transformer. Try to think through this for yourself before reading through my thoughts! \n", - "\n", - "
(*) My reasoning\n", - "\n", - "

Brainstorming:

\n", - "\n", - "So, what's hard about the task? Let's focus on the concrete example of the first prompt, \"When John and Mary went to the shops, John gave the bag to\" -> \" Mary\". \n", - "\n", - "A good starting point is thinking though whether a tiny model could do this, eg a 1L Attn-Only model. I'm pretty sure the answer is no! Attention is really good at the primitive operations of looking nearby, or copying information. I can believe a tiny model could figure out that at `to` it should look for names and predict that those names came next (eg the skip trigram \" John...to -> John\"). But it's much harder to tell how many of each previous name there are - attending 0.3 to each copy of John will look exactly the same as attending 0.6 to a single John token. So this will be pretty hard to figure out on the \" to\" token!\n", - "\n", - "The natural place to break this symmetry is on the second \" John\" token - telling whether there is an earlier copy of the current token should be a much easier task. So I might expect there to be a head which detects duplicate tokens on the second \" John\" token, and then another head which moves that information from the second \" John\" token to the \" to\" token. \n", - "\n", - "The model then needs to learn to predict \" Mary\" and not \" John\". I can see two natural ways to do this: \n", - "1. Detect all preceding names and move this information to \" to\" and then delete the any name corresponding to the duplicate token feature. This feels easier done with a non-linearity, since precisely cancelling out vectors is hard, so I'd imagine an MLP layer deletes the \" John\" direction of the residual stream\n", - "2. Have a head which attends to all previous names, but where the duplicate token features inhibit it from attending to specific names. So this only attends to Mary. And then the output of this head maps to the logits. \n", - "\n", - "(Spoiler: It's the second one).\n", - "\n", - "

Experiment Ideas

\n", - "\n", - "A test that could distinguish these two is to look at which components of the model add directly to the logits - if it's mostly attention heads which attend to \" Mary\" and to neither \" John\" it's probably hypothesis 2, if it's mostly MLPs it's probably hypothesis 1.\n", - "\n", - "And we should be able to identify duplicate token heads by finding ones which attend from \" John\" to \" John\", and whose outputs are then moved to the \" to\" token by V-Composition with another head (Spoiler: It's more complicated than that!)\n", - "\n", - "Note that all of the above reasoning is very simplistic and could easily break in a real model! There'll be significant parts of the model that figure out whether to use this circuit at all (we don't want to inhibit duplicated names when, eg, figuring out what goes at the start of the next sentence), and may be parts towards the end of the model that do \"post-processing\" just before the final output. But it's a good starting point for thinking about what's going on." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Direct Logit Attribution" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "*Look up unfamiliar terms in the [mech interp explainer](https://neelnanda.io/glossary)*\n", - "\n", - "Further, the easiest part of the model to understand is the output - this is what the model is trained to optimize, and so it can always be directly interpreted! Often the right approach to reverse engineering a circuit is to start at the end, understand how the model produces the right answer, and to then work backwards. The main technique used to do this is called **direct logit attribution**\n", - "\n", - "**Background:** The central object of a transformer is the **residual stream**. This is the sum of the outputs of each layer and of the original token and positional embedding. 
Importantly, this means that any linear function of the residual stream can be perfectly decomposed into the contribution of each layer of the transformer. Further, each attention layer's output can be broken down into the sum of the output of each head (See [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html) for details), and each MLP layer's output can be broken down into the sum of the output of each neuron (and a bias term for each layer). \n", - "\n", - "The logits of a model are `logits=Unembed(LayerNorm(final_residual_stream))`. The Unembed is a linear map, and LayerNorm is approximately a linear map, so we can decompose the logits into the sum of the contributions of each component, and look at which components contribute the most to the logit of the correct token! This is called **direct logit attribution**. Here we look at the direct attribution to the logit difference!\n", - "\n", - "
(*) Background and motivation of the logit difference\n", - "\n", - "Logit difference is actually a *really* nice and elegant metric and is a particularly nice aspect of the setup of Indirect Object Identification. In general, there are two natural ways to interpret the model's outputs: the output logits, or the output log probabilities (or probabilities). \n", - "\n", - "The logits are much nicer and easier to understand, as noted above. However, the model is trained to optimize the cross-entropy loss (the average of log probability of the correct token). This means it does not directly optimize the logits, and indeed if the model adds an arbitrary constant to every logit, the log probabilities are unchanged. \n", - "\n", - "But `log_probs == logits.log_softmax(dim=-1) == logits - logsumexp(logits)`, and so `log_probs(\" Mary\") - log_probs(\" John\") = logits(\" Mary\") - logits(\" John\")` - the ability to add an arbitrary constant cancels out!\n", - "\n", - "Further, the metric helps us isolate the precise capability we care about - figuring out *which* name is the Indirect Object. There are many other components of the task - deciding whether to return an article (the) or pronoun (her) or name, realising that the sentence wants a person next at all, etc. By taking the logit difference we control for all of that.\n", - "\n", - "Our metric is further refined, because each prompt is repeated twice, for each possible indirect object. This controls for irrelevant behaviour such as the model learning that John is a more frequent token than Mary (this actually happens! The final layernorm bias increases the John logit by 1 relative to the Mary logit)\n", - "\n", - "
\n", - "\n", - "
Ignoring LayerNorm\n", - "\n", - "LayerNorm is an analogous normalization technique to BatchNorm (that's friendlier to massive parallelization) that transformers use. Every time a transformer layer reads information from the residual stream, it applies a LayerNorm to normalize the vector at each position (translating to set the mean to 0 and scaling to set the variance to 1) and then applying a learned vector of weights and biases to scale and translate the normalized vector. This is *almost* a linear map, apart from the scaling step, because that divides by the norm of the vector and the norm is not a linear function. (The `fold_ln` flag when loading a model factors out all the linear parts).\n", - "\n", - "But if we fixed the scale factor, the LayerNorm would be fully linear. And the scale of the residual stream is a global property that's a function of *all* components of the stream, while in practice there is normally just a few directions relevant to any particular component, so in practice this is an acceptable approximation. So when doing direct logit attribution we use the `apply_ln` flag on the `cache` to apply the global layernorm scaling factor to each constant. See [my clean GPT-2 implementation](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/clean-transformer-demo/Clean_Transformer_Demo.ipynb#scrollTo=Clean_Transformer_Implementation) for more on LayerNorm.\n", - "
" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Getting an output logit is equivalent to projecting onto a direction in the residual stream. We use `model.tokens_to_residual_directions` to map the answer tokens to that direction, and then convert this to a logit difference direction for each batch" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Answer residual directions shape: torch.Size([8, 2, 768])\n", - "Logit difference directions shape: torch.Size([8, 768])\n" - ] - } - ], - "source": [ - "answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)\n", - "print(\"Answer residual directions shape:\", answer_residual_directions.shape)\n", - "logit_diff_directions = (\n", - " answer_residual_directions[:, 0] - answer_residual_directions[:, 1]\n", - ")\n", - "print(\"Logit difference directions shape:\", logit_diff_directions.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To verify that this works, we can apply this to the final residual stream for our cached prompts (after applying LayerNorm scaling) and verify that we get the same answer. \n", - "\n", - "
Technical details\n", - "\n", - "`logits = Unembed(LayerNorm(final_residual_stream))`, so we technically need to account for the centering, and then learned translation and scaling of the layernorm, not just the variance 1 scaling. \n", - "\n", - "The centering is accounted for with the preprocessing flag `center_writing_weights` which ensures that every weight matrix writing to the residual stream has mean zero. \n", - "\n", - "The learned scaling is folded into the unembedding weights `model.unembed.W_U` via `W_U_fold = layer_norm.weights[:, None] * unembed.W_U`\n", - "\n", - "The learned translation is folded to `model.unembed.b_U`, a bias added to the logits (note that GPT-2 is not trained with an existing `b_U`). This roughly represents unigram statistics. But we can ignore this because each prompt occurs twice with names in the opposite order, so this perfectly cancels out. \n", - "\n", - "Note that rather than using layernorm scaling we could just study cache[\"ln_final.hook_normalised\"]\n", - "\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Final residual stream shape: torch.Size([8, 15, 768])\n", - "Calculated average logit diff: 3.552\n", - "Original logit difference: 3.552\n" - ] - } - ], - "source": [ - "# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].\n", - "final_residual_stream = cache[\"resid_post\", -1]\n", - "print(\"Final residual stream shape:\", final_residual_stream.shape)\n", - "final_token_residual_stream = final_residual_stream[:, -1, :]\n", - "# Apply LayerNorm scaling\n", - "# pos_slice is the subset of the positions we take - here the final token of each prompt\n", - "scaled_final_token_residual_stream = cache.apply_ln_to_stack(\n", - " final_token_residual_stream, layer=-1, pos_slice=-1\n", - ")\n", - "\n", - "average_logit_diff = einsum(\n", - " \"batch d_model, batch d_model -> \",\n", - " scaled_final_token_residual_stream,\n", - " logit_diff_directions,\n", - ") / len(prompts)\n", - "print(\"Calculated average logit diff:\", round(average_logit_diff.item(), 3))\n", - "print(\"Original logit difference:\", round(original_average_logit_diff.item(), 3))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Logit Lens" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can now decompose the residual stream! First we apply a technique called the [**logit lens**](https://www.alignmentforum.org/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens) - this looks at the residual stream after each layer and calculates the logit difference from that. This simulates what happens if we delete all subsequence layers. 
" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "def residual_stack_to_logit_diff(\n", - " residual_stack: Float[torch.Tensor, \"components batch d_model\"],\n", - " cache: ActivationCache,\n", - ") -> float:\n", - " scaled_residual_stack = cache.apply_ln_to_stack(\n", - " residual_stack, layer=-1, pos_slice=-1\n", - " )\n", - " return einsum(\n", - " \"... batch d_model, batch d_model -> ...\",\n", - " scaled_residual_stack,\n", - " logit_diff_directions,\n", - " ) / len(prompts)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Fascinatingly, we see that the model is utterly unable to do the task until layer 7, almost all performance comes from attention layer 9, and performance actually *decreases* from there.\n", - "\n", - "**Note:** Hover over each data point to see what residual stream position it's from!\n", - "\n", - "
Details on `accumulated_resid`\n", - "**Key:** `n_pre` means the residual stream at the start of layer n, `n_mid` means the residual stream after the attention part of layer n (`n_post` is the same as `n+1_pre` so is not included)\n", - "\n", - "* `layer` is the layer for which we input the residual stream (this is used to identify *which* layer norm scaling factor we want)\n", - "* `incl_mid` is whether to include the residual stream in the middle of a layer, ie after attention & before MLP\n", - "* `pos_slice` is the subset of the positions used. See `utils.Slice` for details on the syntax.\n", - "* return_labels is whether to return the labels for each component returned (useful for plotting)\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ + "cells": [ { - "hovertemplate": "%{hovertext}

x=%{x}
y=%{y}", - "hovertext": [ - "0_pre", - "0_mid", - "1_pre", - "1_mid", - "2_pre", - "2_mid", - "3_pre", - "3_mid", - "4_pre", - "4_mid", - "5_pre", - "5_mid", - "6_pre", - "6_mid", - "7_pre", - "7_mid", - "8_pre", - "8_mid", - "9_pre", - "9_mid", - "10_pre", - "10_mid", - "11_pre", - "11_mid", - "final_post" - ], - "legendgroup": "", - "line": { - "color": "#636efa", - "dash": "solid" - }, - "marker": { - "symbol": "circle" - }, - "mode": "lines", - "name": "", - "orientation": "v", - "showlegend": false, - "type": "scatter", - "x": [ - 0, - 0.5, - 1, - 1.5, - 2, - 2.5, - 3, - 3.5, - 4, - 4.5, - 5, - 5.5, - 6, - 6.5, - 7, - 7.5, - 8, - 8.5, - 9, - 9.5, - 10, - 10.5, - 11, - 11.5, - 12 - ], - "xaxis": "x", - "y": [ - 1.2937933206558228e-05, - -0.006643360480666161, - -0.007525032386183739, - -0.009075596928596497, - -0.008736769668757915, - -0.008685456588864326, - -0.006480347365140915, - -0.007939882576465607, - -0.009661720134317875, - -0.015095856040716171, - -0.01419061329215765, - -0.019930001348257065, - -0.00912435818463564, - -0.027298055589199066, - -0.02985510788857937, - 0.2497255504131317, - 0.250558078289032, - 0.45005205273628235, - 0.45996904373168945, - 5.02545166015625, - 5.142900466918945, - 4.730565071105957, - 4.887058258056641, - 3.445383071899414, - 3.5518720149993896 - ], - "yaxis": "y" - } - ], - "layout": { - "legend": { - "tracegroupgap": 0 + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb)" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - 
"line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 
0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { 
- "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { 
- "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Exploratory Analysis Demo\n", + "\n", + "This notebook demonstrates how to use the\n", + "[TransformerLens](https://github.com/TransformerLensOrg/TransformerLens/) library to perform exploratory\n", + "analysis. The notebook tries to replicate the analysis of the Indirect Object Identification circuit\n", + "in the [Interpretability in the Wild](https://arxiv.org/abs/2211.00593) paper." ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tips for Reading This\n", + "\n", + "* If running in Google Colab, go to Runtime > Change Runtime Type and select GPU as the hardware\n", + "accelerator.\n", + "* Look up unfamiliar terms in [the mech interp explainer](https://neelnanda.io/glossary)\n", + "* You can run all this code for yourself\n", + "* The graphs are interactive\n", + "* Use the table of contents pane in the sidebar to navigate (in Colab) or VSCode's \"Outline\" in the\n", + " explorer tab.\n", + "* Collapse irrelevant sections with the dropdown arrows\n", + "* Search the page using the search in the sidebar (with Colab) not CTRL+F" ] - ], - 
"sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": 
"white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Logit Difference From Accumulate Residual Stream" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Environment Setup (ignore)" + ] }, - "xaxis": { - "anchor": "y", - "domain": [ - 0, - 1 - ], - "title": { - "text": "x" - } + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**You can ignore this part:** It's just for use internally to setup the tutorial in different\n", + "environments. You can delete this section if using in your own repo." 
+ ] }, - "yaxis": { - "anchor": "x", - "domain": [ - 0, - 1 - ], - "title": { - "text": "y" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "accumulated_residual, labels = cache.accumulated_resid(\n", - " layer=-1, incl_mid=True, pos_slice=-1, return_labels=True\n", - ")\n", - "logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)\n", - "line(\n", - " logit_lens_logit_diffs,\n", - " x=np.arange(model.cfg.n_layers * 2 + 1) / 2,\n", - " hover_name=labels,\n", - " title=\"Logit Difference From Accumulate Residual Stream\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Layer Attribution" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can repeat the above analysis but for each layer (this is equivalent to the differences between adjacent residual streams)\n", - "\n", - "Note: Annoying terminology overload - layer k of a transformer means the kth **transformer block**, but each block consists of an **attention layer** (to move information around) *and* an **MLP layer** (to process information). \n", - "\n", - "We see that only attention layers matter, which makes sense! The IOI task is about moving information around (ie moving the correct name and not the incorrect name), and less about processing it. And again we note that attention layer 9 improves things a lot, while attention 10 and attention 11 *decrease* performance" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "hovertemplate": "%{hovertext}

x=%{x}
y=%{y}", - "hovertext": [ - "embed", - "pos_embed", - "0_attn_out", - "0_mlp_out", - "1_attn_out", - "1_mlp_out", - "2_attn_out", - "2_mlp_out", - "3_attn_out", - "3_mlp_out", - "4_attn_out", - "4_mlp_out", - "5_attn_out", - "5_mlp_out", - "6_attn_out", - "6_mlp_out", - "7_attn_out", - "7_mlp_out", - "8_attn_out", - "8_mlp_out", - "9_attn_out", - "9_mlp_out", - "10_attn_out", - "10_mlp_out", - "11_attn_out", - "11_mlp_out" - ], - "legendgroup": "", - "line": { - "color": "#636efa", - "dash": "solid" - }, - "marker": { - "symbol": "circle" - }, - "mode": "lines", - "name": "", - "orientation": "v", - "showlegend": false, - "type": "scatter", - "x": [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 25 - ], - "xaxis": "x", - "y": [ - -0.00028366726473905146, - 0.00029660604195669293, - -0.0066563040018081665, - -0.0008816685294732451, - -0.0015505650080740452, - 0.00033882574643939734, - 5.131529178470373e-05, - 0.0022051138803362846, - -0.0014595506945624948, - -0.0017218313878402114, - -0.005434143822640181, - 0.0009052485693246126, - -0.0057394010946154594, - 0.010805649682879448, - -0.018173698335886, - -0.002557049971073866, - 0.27958065271377563, - 0.0008325176313519478, - 0.19949400424957275, - 0.00991708692163229, - 4.565483093261719, - 0.11744903028011322, - -0.4123360514640808, - 0.15649384260177612, - -1.4416757822036743, - 0.10648896545171738 - ], - "yaxis": "y" - } - ], - "layout": { - "legend": { - "tracegroupgap": 0 + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Detect if we're running in Google Colab\n", + "try:\n", + " import google.colab\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + "except:\n", + " IN_COLAB = False\n", + "\n", + "# Install if in Colab\n", + "if IN_COLAB:\n", + " %pip install transformer_lens\n", + " %pip install circuitsvis\n", + " # Install a 
faster Node version\n", + " !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs # noqa\n", + "\n", + "# Hot reload in development mode & not running on the CD\n", + "if not IN_COLAB:\n", + " from IPython import get_ipython\n", + " ip = get_ipython()\n", + " if not ip.extension_manager.loaded:\n", + " ip.extension_manager.load('autoreload')\n", + " %autoreload 2\n" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": 
"contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - 
"#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - 
], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Imports" ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "from typing import List, Optional, Union\n", + "\n", + "import einops\n", + "import numpy as np\n", + "import plotly.express as px\n", + "import plotly.io as pio\n", + "import torch\n", + "from 
circuitsvis.attention import attention_heads\n", + "from fancy_einsum import einsum\n", + "from IPython.display import HTML, IFrame\n", + "from jaxtyping import Float\n", + "\n", + "import transformer_lens.utils as utils\n", + "from transformer_lens import ActivationCache\n", + "from transformer_lens.model_bridge import TransformerBridge" ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### PyTorch Setup" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": 
"#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Logit Difference From Each Layer" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training." 
+ ] }, - "xaxis": { - "anchor": "y", - "domain": [ - 0, - 1 - ], - "title": { - "text": "x" - } + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Disabled automatic differentiation\n" + ] + } + ], + "source": [ + "torch.set_grad_enabled(False)\n", + "print(\"Disabled automatic differentiation\")" + ] }, - "yaxis": { - "anchor": "x", - "domain": [ - 0, - 1 - ], - "title": { - "text": "y" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "per_layer_residual, labels = cache.decompose_resid(\n", - " layer=-1, pos_slice=-1, return_labels=True\n", - ")\n", - "per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)\n", - "line(per_layer_logit_diffs, hover_name=labels, title=\"Logit Difference From Each Layer\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Head Attribution" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can further break down the output of each attention layer into the sum of the outputs of each attention head. Each attention layer consists of 12 heads, which each act independently and additively.\n", - "\n", - "
Decomposing attention output into sums of heads \n", - "The standard way to compute the output of an attention layer is by concatenating the mixed values of each head, and multiplying by a big output weight matrix. But as described in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) this is equivalent to splitting the output weight matrix into a per-head output (here `model.blocks[k].attn.W_O`) and adding them up (including an overall bias term for the entire layer)\n", - "
\n", - "\n", - "We see that only a few heads really matter - heads L9H6 and L9H9 contribute a lot positively (explaining why attention layer 9 is so important), while heads L10H7 and L11H10 contribute a lot negatively (explaining why attention layer 10 and layer 11 are actively harmful). These correspond to (some of) the name movers and negative name movers discussed in the paper. There are also several heads that matter positively or negatively but less strongly (other name movers and backup name movers)\n", - "\n", - "There are a few meta observations worth making here - our model has 144 heads, yet we could localise this behaviour to a handful of specific heads, using straightforward, general techniques. This supports the claim in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) that attention heads are the right level of abstraction to understand attention. It also really surprising that there are *negative* heads - eg L10H7 makes the incorrect logit 7x *more* likely. I'm not sure what's going on there, though the paper discusses some possibilities." - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tried to stack head results when they weren't cached. Computing head results now\n" - ] - }, - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - -0.0020563392899930477, - -0.0005101899732835591, - 0.0004685786843765527, - 0.00012512074317783117, - -0.0006028738571330905, - -0.0002429460291750729, - -0.0023189077619463205, - -0.002758360467851162, - 0.000564602785743773, - 0.0009697531932033598, - -0.0002504526637494564, - 4.737317794933915e-06 - ], - [ - -0.0010070882271975279, - 0.00039470894262194633, - -0.00154874159488827, - 0.0014034928753972054, - -0.0012653048615902662, - -0.0011358022456988692, - -0.00281596090644598, - -0.0029645217582583427, - 0.0029190476052463055, - 0.0025743592996150255, - 0.00036239007022231817, - 0.0017548729665577412 - ], - [ - 0.0005569400964304805, - -0.001126631861552596, - -0.0017353934235870838, - -0.0014514457434415817, - -0.00028735760133713484, - 0.0017211002996191382, - 0.0026658899150788784, - 0.00311466702260077, - 0.0005667927907779813, - -0.003666515462100506, - -0.0018847601022571325, - 7.039372576400638e-06 - ], - [ - -0.0007264417363330722, - 0.00011364505917299539, - 0.0014301587361842394, - 0.0007490540738217533, - 0.0020184689201414585, - 0.0007436950691044331, - -0.00046178390039131045, - -0.0039057559333741665, - 0.0011406694538891315, - -4.022853681817651e-05, - -0.0013293239753693342, - -0.0017636751290410757 - ], - [ - -0.0028280913829803467, - 0.00033634810824878514, - -0.0014248639345169067, - -0.003777273464947939, - 0.0015998880844563246, - 0.0002989505883306265, - -0.000804675742983818, - 0.002038792008534074, - -0.0015593919670209289, - -0.0006436670082621276, - 0.0011168173514306545, - -0.00035012533771805465 - ], - [ - 0.0011338205076754093, - 0.0011259170714765787, - -0.002516670385375619, - -0.0014790185960009694, - 0.0003878737334161997, - -6.408110493794084e-05, - -0.0005096744280308485, - -0.0008840755908749998, - 0.0006398351397365332, - -0.0010097370250150561, - -0.006759158335626125, - 0.0033667823299765587 - ], - [ - -0.01514742337167263, 
- -0.0021350777242332697, - 0.002593174111098051, - -0.00042678468162193894, - -0.005558924749493599, - 0.0026658528950065374, - 0.006411008536815643, - -0.003826778382062912, - -0.0003843410813715309, - -0.0016430341638624668, - -0.0013344454346224666, - -9.20506427064538e-05 - ], - [ - -9.476230479776859e-05, - -0.0057889921590685844, - -0.0006383581785485148, - 0.13493388891220093, - -0.001768707763403654, - -0.018917907029390335, - 0.003873429261147976, - -0.0021450775675475597, - -0.010327338241040707, - 0.18325845897197723, - -0.0007747983909212053, - -0.00104526337236166 - ], - [ - -0.003833949100226164, - -0.0008046097937040031, - -0.012673400342464447, - 0.00804573018103838, - 0.003604492638260126, - -0.009398287162184715, - -0.08272082358598709, - 0.003555194940418005, - -0.018404025584459305, - 0.0017587244510650635, - 0.2896133363246918, - 0.022854052484035492 - ], - [ - 0.08595258742570877, - -0.0006932877004146576, - 0.06817055493593216, - 0.013111240230500698, - -0.021098043769598007, - 0.05112447217106819, - 1.3844914436340332, - 0.045836858451366425, - -0.03830280900001526, - 2.985445976257324, - 0.0019662054255604744, - -0.008030137047171593 - ], - [ - 0.5608693957328796, - 0.17083050310611725, - -0.03361757844686508, - 0.05821544677019119, - -0.0024530249647796154, - 0.0018771197646856308, - 0.28827205300331116, - -1.8986485004425049, - -0.0015286931302398443, - -0.035129792988300323, - 0.4802178740501404, - -0.0009115453576669097 - ], - [ - 0.016075748950242996, - -0.03986122086644173, - -0.3879126012325287, - 0.011123123578727245, - -0.005477819126099348, - -0.0025129620917141438, - -0.08056175708770752, - 0.007518616039305925, - 0.0430111438035965, - -0.040082238614559174, - -0.9702364802360535, - 0.011862239800393581 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, 
- "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting Helper Functions (ignore)" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - 
"outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 
0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 
0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Some plotting helper functions are included here (for simplicity)." 
] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def imshow(tensor, **kwargs):\n", + " px.imshow(\n", + " utils.to_numpy(tensor),\n", + " color_continuous_midpoint=0.0,\n", + " color_continuous_scale=\"RdBu\",\n", + " **kwargs,\n", + " ).show()\n", + "\n", + "\n", + "def line(tensor, **kwargs):\n", + " px.line(\n", + " y=utils.to_numpy(tensor),\n", + " **kwargs,\n", + " ).show()\n", + "\n", + "\n", + "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", **kwargs):\n", + " x = utils.to_numpy(x)\n", + " y = utils.to_numpy(y)\n", + " px.scatter(\n", + " y=y,\n", + " x=x,\n", + " labels={\"x\": xaxis, \"y\": yaxis, \"color\": caxis},\n", + " **kwargs,\n", + " ).show()" ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introduction\n", + "\n", + "This is a demo notebook for [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens), a library for mechanistic interpretability of GPT-2 style transformer language models. 
A core design principle of the library is to enable exploratory analysis - one of the most fun parts of mechanistic interpretability compared to normal ML is the extremely short feedback loops! The point of this library is to keep the gap between having an experiment idea and seeing the results as small as possible, to make it easy for **research to feel like play** and to enter a flow state.\n", + "\n", + "The goal of this notebook is to demonstrate what exploratory analysis looks like in practice with the library. I use my standard toolkit of basic mechanistic interpretability techniques to try interpreting a real circuit in GPT-2 small. Check out [the main demo](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Main_Demo.ipynb) for an introduction to the library and how to use it. \n", + "\n", + "Stylistically, I will go fairly slowly and explain in detail what I'm doing and why, aiming to help convey how to do this kind of research yourself! But the code itself is written to be simple and generic, and easy to copy and paste into your own projects for different tasks and models.\n", + "\n", + "Details tags contain asides, flavour + interpretability intuitions. These are more in the weeds and you don't need to read them or understand them, but they're helpful if you want to learn how to do mechanistic interpretability yourself! I star the ones I think are most important.\n", + "
<details><summary>(*) Example details tag</summary>Example aside!</details>
" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - 
"zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Logit Difference From Each Head" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Indirect Object Identification\n", + "\n", + "The first step when trying to reverse engineer a circuit in a model is to identify *what* capability\n", + "I want to reverse engineer. Indirect Object Identification is a task studied in Redwood Research's\n", + "excellent [Interpretability in the Wild](https://arxiv.org/abs/2211.00593) paper (see [my interview\n", + "with the authors](https://www.youtube.com/watch?v=gzwj0jWbvbo) or [Kevin Wang's Twitter\n", + "thread](https://threadreaderapp.com/thread/1587601532639494146.html) for an overview). The task is\n", + "to complete sentences like \"After John and Mary went to the shops, John gave a bottle of milk to\"\n", + "with \" Mary\" rather than \" John\". \n", + "\n", + "In the paper they rigorously reverse engineer a 26 head circuit, with 7 separate categories of heads\n", + "used to perform this capability. Their rigorous methods are fairly involved, so in this notebook,\n", + "I'm going to skimp on rigour and instead try to speed run the process of finding suggestive evidence\n", + "for this circuit!\n", + "\n", + "The circuit they found roughly breaks down into three parts:\n", + "1. Identify what names are in the sentence\n", + "2. Identify which names are duplicated\n", + "3. Predict the name that is *not* duplicated" + ] }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The first step is to load in our model, GPT-2 Small, a 12 layer and 80M parameter transformer with `TransformerBridge.boot_transformers`. The various flags are simplifications that preserve the model's output but simplify its internals." 
+ ] }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "per_head_residual, labels = cache.stack_head_results(\n", - " layer=-1, pos_slice=-1, return_labels=True\n", - ")\n", - "per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)\n", - "per_head_logit_diffs = einops.rearrange(\n", - " per_head_logit_diffs,\n", - " \"(layer head_index) -> layer head_index\",\n", - " layer=model.cfg.n_layers,\n", - " head_index=model.cfg.n_heads,\n", - ")\n", - "imshow(\n", - " per_head_logit_diffs,\n", - " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", - " title=\"Logit Difference From Each Head\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Attention Analysis" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Attention heads are particularly easy to study because we can look directly at their attention patterns and study from what positions they move information from and two. This is particularly easy here as we're looking at the direct effect on the logits so we need only look at the attention patterns from the final token. \n", - "\n", - "We use Alan Cooney's circuitsvis library to visualize the attention patterns! We visualize the top 3 positive and negative heads by direct logit attribution, and show these for the first prompt (as an illustration).\n", - "\n", - "
Interpreting Attention Patterns \n", - "An easy mistake to make when looking at attention patterns is thinking that they must convey information about the token looked at (maybe accounting for the context of the token). But actually, all we can confidently say is that it moves information from the *residual stream position* corresponding to that input token. Especially later on in the model, there may be components in the residual stream that are nothing to do with the input token! Eg the period at the end of a sentence may contain summary information for that sentence, and the head may solely move that, rather than caring about whether it ends in \".\", \"!\" or \"?\"\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "def visualize_attention_patterns(\n", - " heads: Union[List[int], int, Float[torch.Tensor, \"heads\"]],\n", - " local_cache: ActivationCache,\n", - " local_tokens: torch.Tensor,\n", - " title: Optional[str] = \"\",\n", - " max_width: Optional[int] = 700,\n", - ") -> str:\n", - " # If a single head is given, convert to a list\n", - " if isinstance(heads, int):\n", - " heads = [heads]\n", - "\n", - " # Create the plotting data\n", - " labels: List[str] = []\n", - " patterns: List[Float[torch.Tensor, \"dest_pos src_pos\"]] = []\n", - "\n", - " # Assume we have a single batch item\n", - " batch_index = 0\n", - "\n", - " for head in heads:\n", - " # Set the label\n", - " layer = head // model.cfg.n_heads\n", - " head_index = head % model.cfg.n_heads\n", - " labels.append(f\"L{layer}H{head_index}\")\n", - "\n", - " # Get the attention patterns for the head\n", - " # Attention patterns have shape [batch, head_index, query_pos, key_pos]\n", - " patterns.append(local_cache[\"attn\", layer][batch_index, head_index])\n", - "\n", - " # Convert the tokens to strings (for the axis labels)\n", - " str_tokens = model.to_str_tokens(local_tokens)\n", - "\n", - " # Combine the patterns into a single tensor\n", - " patterns: Float[torch.Tensor, \"head_index dest_pos src_pos\"] = torch.stack(\n", - " patterns, dim=0\n", - " )\n", - "\n", - " # Circuitsvis Plot (note we get the code version so we can concatenate with the title)\n", - " plot = attention_heads(\n", - " attention=patterns, tokens=str_tokens, attention_head_names=labels\n", - " ).show_code()\n", - "\n", - " # Display the title\n", - " title_html = f\"

{title}


\"\n", - "\n", - " # Return the visualisation as raw code\n", - " return f\"
{title_html + plot}
\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Inspecting the patterns, we can see that both types of name movers attend to the indirect object - this suggests they're simply copying the name attended to (with the OV circuit) and that the interesting part is the circuit behind the attention pattern that calculates *where* to move information from (the QK circuit)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "

Top 3 Positive Logit Attribution Heads


\n", - "

Top 3 Negative Logit Attribution Heads


\n", - "
" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "top_k = 3\n", - "\n", - "top_positive_logit_attr_heads = torch.topk(\n", - " per_head_logit_diffs.flatten(), k=top_k\n", - ").indices\n", - "\n", - "positive_html = visualize_attention_patterns(\n", - " top_positive_logit_attr_heads,\n", - " cache,\n", - " tokens[0],\n", - " f\"Top {top_k} Positive Logit Attribution Heads\",\n", - ")\n", - "\n", - "top_negative_logit_attr_heads = torch.topk(\n", - " -per_head_logit_diffs.flatten(), k=top_k\n", - ").indices\n", - "\n", - "negative_html = visualize_attention_patterns(\n", - " top_negative_logit_attr_heads,\n", - " cache,\n", - " tokens[0],\n", - " title=f\"Top {top_k} Negative Logit Attribution Heads\",\n", - ")\n", - "\n", - "HTML(positive_html + negative_html)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Activation Patching" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**This section explains how to do activation patching conceptually by implementing it from scratch. To use it in practice with TransformerLens, see [this demonstration instead](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Activation_Patching_in_TL_Demo.ipynb)**.\n", - "\n", - "The obvious limitation to the techniques used above is that they only look at the very end of the circuit - the parts that directly affect the logits. Clearly this is not sufficient to understand the circuit! We want to understand how things compose together to produce this final output, and ideally to produce an end-to-end circuit fully explaining this behaviour. \n", - "\n", - "The technique we'll use to investigate this is called **activation patching**. This was first introduced in [David Bau and Kevin Meng's excellent ROME paper](https://rome.baulab.info/), there called causal tracing. 
\n", - "\n", - "The setup of activation patching is to take two runs of the model on two different inputs, the clean run and the corrupted run. The clean run outputs the correct answer and the corrupted run does not. The key idea is that we give the model the corrupted input, but then **intervene** on a specific activation and **patch** in the corresponding activation from the clean run (ie replace the corrupted activation with the clean activation), and then continue the run. And we then measure how much the output has updated towards the correct answer. \n", - "\n", - "We can then iterate over many possible activations and look at how much they affect the corrupted run. If patching in an activation significantly increases the probability of the correct answer, this allows us to *localise* which activations matter. \n", - "\n", - "The ability to localise is a key move in mechanistic interpretability - if the computation is diffuse and spread across the entire model, it is likely much harder to form a clean mechanistic story for what's going on. But if we can identify precisely which parts of the model matter, we can then zoom in and determine what they represent and how they connect up with each other, and ultimately reverse engineer the underlying circuit that they represent. 
\n", - "\n", - "Here's an animation from the ROME paper demonstrating this technique (they studied factual recall, and use stars to represent corruption applied to the subject of the sentence, but the same principles apply):\n", - "\n", - "![CT Animation](https://rome.baulab.info/images/small-ct-animation.gif)\n", - "\n", - "See also [the explanation in a mech interp explainer](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx) and [this piece](https://www.neelnanda.io/mechanistic-interpretability/attribution-patching#how-to-think-about-activation-patching) describing how to think about patching on a conceptual level" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The above was all fairly abstract, so let's zoom in and lay out a concrete example to understand Indirect Object Identification.\n", - "\n", - "Here our clean input will be eg \"After John and Mary went to the store, **John** gave a bottle of milk to\" and our corrupted input will be eg \"After John and Mary went to the store, **Mary** gave a bottle of milk to\". These prompts are identical except for the name of the indirect object, and so patching is a causal intervention which will allow us to understand precisely which parts of the network are identifying the indirect object. \n", - "\n", - "One natural thing to patch in is the residual stream at a specific layer and specific position. For example, the model is likely initially doing some processing on the second subject token to realise that it's a duplicate, but then uses attention to move that information to the \" to\" token. So patching in the residual stream at the \" to\" token will likely matter a lot in later layers but not at all in early layers.\n", - "\n", - "We can zoom in much further and patch in specific activations from specific layers. 
For example, we think that the output of head L9H9 on the final token is significant for directly connecting to the logits\n", - "\n", - "We can patch in specific activations, and can zoom in as far as seems reasonable. For example, if we patch in the output of head L9H9 on the final token, we would predict that it will significantly affect performance. \n", - "\n", - "Note that this technique does *not* tell us how the components of the circuit connect up, just what they are. \n", - "\n", - "
Technical details \n", - "The choice of clean and corrupted prompt has both pros and cons. By carefully setting up the counterfactual, that only differs in the second subject, we avoid detecting the parts of the model doing irrelevant computation like detecting that the indirect object task is relevant at all or that it should be outputting a name rather than an article or pronoun. Or even context like that John and Mary are names at all. \n", - "\n", - "However, it *also* bakes in some details that *are* relevant to the task. Such as finding the location of the second subject, and of the names in the first clause. Or that the name mover heads have learned to copy whatever they look at. \n", - "\n", - "Some of these could be patched by also changing up the order of the names in the original sentence - patching in \"After John and Mary went to the store, John gave a bottle of milk to\" vs \"After Mary and John went to the store, John gave a bottle of milk to\".\n", - "\n", - "In the ROME paper they take a different tack. Rather than carefully setting up counterfactuals between two different but related inputs, they **corrupt** the clean input by adding Gaussian noise to the token embedding for the subject. This is in some ways much lower effort (you don't need to set up a similar but different prompt) but can also introduce some issues, such as ways this noise might break things. In practice, you should take care about how you choose your counterfactuals and try out several. Try to reason beforehand about what they will and will not tell you, and compare the results between different counterfactuals.\n", - "\n", - "I discuss some of these limitations and how the author's solved them with much more refined usage of these techniques in our interview\n", - "
" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Residual Stream" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Lets begin by patching in the residual stream at the start of each layer and for each token position. " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We first create a set of corrupted tokens - where we swap each pair of prompts to have the opposite answer." - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Corrupted Average Logit Diff -3.55\n", - "Clean Average Logit Diff 3.55\n" - ] - } - ], - "source": [ - "corrupted_prompts = []\n", - "for i in range(0, len(prompts), 2):\n", - " corrupted_prompts.append(prompts[i + 1])\n", - " corrupted_prompts.append(prompts[i])\n", - "corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)\n", - "corrupted_logits, corrupted_cache = model.run_with_cache(\n", - " corrupted_tokens, return_type=\"logits\"\n", - ")\n", - "corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)\n", - "print(\"Corrupted Average Logit Diff\", round(corrupted_average_logit_diff.item(), 2))\n", - "print(\"Clean Average Logit Diff\", round(original_average_logit_diff.item(), 2))" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['<|endoftext|>When John and Mary went to the shops, Mary gave the bag to',\n", - " '<|endoftext|>When John and Mary went to the shops, John gave the bag to',\n", - " '<|endoftext|>When Tom and James went to the park, Tom gave the ball to',\n", - " '<|endoftext|>When Tom and James went to the park, James gave the ball to',\n", - " '<|endoftext|>When Dan and Sid went to the shops, Dan gave an apple to',\n", - " '<|endoftext|>When Dan and Sid went to the shops, Sid gave an apple to',\n", - " 
'<|endoftext|>After Martin and Amy went to the park, Martin gave a drink to',\n", - " '<|endoftext|>After Martin and Amy went to the park, Amy gave a drink to']" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.to_string(corrupted_tokens)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We now intervene on the corrupted run and patch in the clean residual stream at a specific layer and position.\n", - "\n", - "We do the intervention using TransformerLens's `HookPoint` feature. We can design a hook function that takes in a specific activation and returns an edited copy, and temporarily add it in with `model.run_with_hooks`. " - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "def patch_residual_component(\n", - " corrupted_residual_component: Float[torch.Tensor, \"batch pos d_model\"],\n", - " hook,\n", - " pos,\n", - " clean_cache,\n", - "):\n", - " corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]\n", - " return corrupted_residual_component\n", - "\n", - "\n", - "def normalize_patched_logit_diff(patched_logit_diff):\n", - " # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalise\n", - " # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance\n", - " return (patched_logit_diff - corrupted_average_logit_diff) / (\n", - " original_average_logit_diff - corrupted_average_logit_diff\n", - " )\n", - "\n", - "\n", - "patched_residual_stream_diff = torch.zeros(\n", - " model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32\n", - ")\n", - "for layer in range(model.cfg.n_layers):\n", - " for position in range(tokens.shape[1]):\n", - " hook_fn = partial(patch_residual_component, 
pos=position, clean_cache=cache)\n", - " patched_logits = model.run_with_hooks(\n", - " corrupted_tokens,\n", - " fwd_hooks=[(utils.get_act_name(\"resid_pre\", layer), hook_fn)],\n", - " return_type=\"logits\",\n", - " )\n", - " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", - "\n", - " patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(\n", - " patched_logit_diff\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can immediately see that, exactly as predicted, originally all relevant computation happens on the second subject token, and at layers 7 and 8, the information is moved to the final token. Moving the residual stream at the correct position near *exactly* recovers performance!\n", - "\n", - "For reference, tokens and their index from the first prompt are on the x-axis. In an abuse of notation, note that the difference here is averaged over *all* 8 prompts, while the labels only come from the *first* prompt. \n", - "\n", - "To be easier to interpret, we normalise the logit difference, by subtracting the corrupted logit difference, and dividing by the total improvement from clean to corrupted to normalise\n", - "0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "coloraxis": "coloraxis", - "hovertemplate": "Position: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "x": [ - "<|endoftext|>_0", - "When_1", - " John_2", - " and_3", - " Mary_4", - " went_5", - " to_6", - " the_7", - " shops_8", - ",_9", - " John_10", - " gave_11", - " the_12", - " bag_13", - " to_14" - ], - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.000650405883789, - -0.0002469856117386371, - 9.76665523921838e-06, - -0.00036458822432905436, - -4.8967522161547095e-05 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.001051902770996, - -2.7621845219982788e-05, - -1.9768245692830533e-05, - -0.0004596704675350338, - -0.0005947590689174831 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.0002663135528564, - 0.0008680911851115525, - 0.0005157867562957108, - -0.0009929431835189462, - -0.0008658089209347963 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.994907796382904, - 0.005429857410490513, - 0.0016050540143623948, - -0.0006193603039719164, - -0.0016324409516528249 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.9675672054290771, - 0.03134213387966156, - 0.0028418952133506536, - -0.0012302964460104704, - -0.000985861523076892 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.967520534992218, - 0.03100077249109745, - 0.0017823305679485202, - -0.00048668819363228977, - -0.0006467136554419994 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.9228319525718689, - 0.05134531855583191, - 0.004728672094643116, - 0.0009345446596853435, - 0.017046840861439705 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.6565483808517456, - 0.02385685034096241, - 0.002357019344344735, - -1.7183941963594407e-05, - 0.3186916410923004 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.027302566915750504, - 0.03142499923706055, - 0.0018202561186626554, - 0.0007990868762135506, - 0.9383866190910339 - 
], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.026841485872864723, - 0.02098155952990055, - 0.0012512058019638062, - 0.00032317222212441266, - 1.0048279762268066 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.005687985569238663, - 0.014263377524912357, - 0.00048709093243815005, - -8.977938705356792e-05, - 0.9914212226867676 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using pad_token, but it is not set yet.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model gpt2-small into HookedTransformer\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "model = TransformerBridge.boot_transformers(\n", + " \"gpt2\",\n", + " center_unembed=True,\n", + " center_writing_weights=True,\n", + " fold_ln=True,\n", + " refactor_factored_attn_matrices=True,\n", + ")\n", + "model.enable_compatibility_mode()\n", + "\n", + "# Get the default device used\n", + "device: torch.device = utils.get_device()" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - 
"size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 
0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - 
"outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - 
"diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The next step is to verify that the model can *actually* do the task! Here we use `utils.test_prompt`, and see that the model is significantly better at predicting Mary than John! \n", + "\n", + "
Asides:\n", + "\n", + "Note: If we were being careful, we'd want to run the model on a range of prompts and find the average performance\n", + "\n", + "`prepend_bos` is a flag to add a BOS (beginning of sequence) to the start of the prompt. GPT-2 was not trained with this, but I find that it often makes model behaviour more stable, as the first token is treated weirdly.\n", + "
" ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']\n", + "Tokenized answer: [' Mary']\n" + ] + }, + { + "data": { + "text/html": [ + "
Performance on answer token:\n",
+                            "Rank: 0        Logit: 18.09 Prob: 70.07% Token: | Mary|\n",
+                            "
\n" + ], + "text/plain": [ + "Performance on answer token:\n", + "\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m18.09\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m70.07\u001b[0m\u001b[1m% Token: | Mary|\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|\n", + "Top 1th token. Logit: 15.38 Prob: 4.67% Token: | the|\n", + "Top 2th token. Logit: 15.35 Prob: 4.54% Token: | John|\n", + "Top 3th token. Logit: 15.25 Prob: 4.11% Token: | them|\n", + "Top 4th token. Logit: 14.84 Prob: 2.73% Token: | his|\n", + "Top 5th token. Logit: 14.06 Prob: 1.24% Token: | her|\n", + "Top 6th token. Logit: 13.54 Prob: 0.74% Token: | a|\n", + "Top 7th token. Logit: 13.52 Prob: 0.73% Token: | their|\n", + "Top 8th token. Logit: 13.13 Prob: 0.49% Token: | Jesus|\n", + "Top 9th token. Logit: 12.97 Prob: 0.42% Token: | him|\n" + ] + }, + { + "data": { + "text/html": [ + "
Ranks of the answer tokens: [(' Mary', 0)]\n",
+                            "
\n" + ], + "text/plain": [ + "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Mary'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "example_prompt = \"After John and Mary went to the store, John gave a bottle of milk to\"\n", + "example_answer = \" Mary\"\n", + "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)" ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now want to find a reference prompt to run the model on. Even though our ultimate goal is to reverse engineer how this behaviour is done in general, often the best way to start out in mechanistic interpretability is by zooming in on a concrete example and understanding it in detail, and only *then* zooming out and verifying that our analysis generalises.\n", + "\n", + "We'll run the model on 4 instances of this task, each prompt given twice - one with the first name as the indirect object, one with the second name. To make our lives easier, we'll carefully choose prompts with single token names and the corresponding names in the same token positions.\n", + "\n", + "
(*) Aside on tokenization\n", + "\n", + "We want models that can take in arbitrary text, but models need to have a fixed vocabulary. So the solution is to define a vocabulary of **tokens** and to deterministically break up arbitrary text into tokens. Tokens are, essentially, subwords, and are determined by finding the most frequent substrings - this means that tokens vary a lot in length and frequency! \n", + "\n", + "Tokens are a *massive* headache and are one of the most annoying things about reverse engineering language models... Different names will be different numbers of tokens, different prompts will have the relevant tokens at different positions, different prompts will have different total numbers of tokens, etc. Language models often devote significant amounts of parameters in early layers to convert inputs from tokens to a more sensible internal format (and do the reverse in later layers). You really, really want to avoid needing to think about tokenization wherever possible when doing exploratory analysis (though, of course, it's relevant later when trying to flesh out your analysis and make it rigorous!). TransformerBridge comes with several helper methods to deal with tokens: `to_tokens, to_string, to_str_tokens, to_single_token, get_token_position`\n", + "\n", + "**Exercise:** I recommend using `model.to_str_tokens` to explore how the model tokenizes different strings. In particular, try adding or removing spaces at the start, or changing capitalization - these change tokenization!
" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - 
"zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Logit Difference From Patched Residual Stream" + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']\n", + "[(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n" + ] + } + ], + "source": [ + "prompt_format = [\n", + " \"When John and Mary went to the shops,{} gave the bag to\",\n", + " \"When Tom and James went to the park,{} gave the ball to\",\n", + " \"When Dan and Sid went to the shops,{} gave an apple to\",\n", + " \"After Martin and Amy went to the park,{} gave a drink to\",\n", + "]\n", + "names = [\n", + " (\" Mary\", \" John\"),\n", + " (\" Tom\", \" James\"),\n", + " (\" Dan\", \" Sid\"),\n", + " (\" Martin\", \" Amy\"),\n", + "]\n", + "# List of prompts\n", + "prompts = []\n", + "# List of answers, in the format (correct, incorrect)\n", + "answers = []\n", + "# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)\n", + "answer_tokens = []\n", + "for i in range(len(prompt_format)):\n", + " for j in range(2):\n", + " answers.append((names[i][j], names[i][1 - j]))\n", + " answer_tokens.append(\n", + " (\n", + " model.to_single_token(answers[-1][0]),\n", + " model.to_single_token(answers[-1][1]),\n", + " )\n", + " )\n", + " # Insert 
the *incorrect* answer to the prompt, making the correct answer the indirect object.\n", + " prompts.append(prompt_format[i].format(answers[-1][1]))\n", + "answer_tokens = torch.tensor(answer_tokens).to(device)\n", + "print(prompts)\n", + "print(answers)" + ] }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Position" - } + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Gotcha**: It's important that all of your prompts have the same number of tokens. If they're different lengths, then the position of the \"final\" logit where you can check logit difference will differ between prompts, and this will break the below code. The easiest solution is just to choose your prompts carefully to have the same number of tokens (you can eg add filler words like The, or newlines to start).\n", + "\n", + "There's a range of other ways of solving this, eg you can index more intelligently to get the final logit. A better way is to just use left padding by setting `model.tokenizer.padding_side = 'left'` before tokenizing the inputs and running the model; this way, you can use something like `logits[:, -1, :]` to easily access the final token outputs without complicated indexing. TransformerLens checks the value of `padding_side` of the tokenizer internally, and if the flag is set to be `'left'`, it adjusts the calculation of absolute position embedding and causal masking accordingly.\n", + "\n", + "In this demo, though, we stick to using the prompts of the same number of tokens because we want to show some visualisations aggregated along the batch dimension later in the demo." 
+ ] }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "prompt_position_labels = [\n", - " f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(tokens[0]))\n", - "]\n", - "imshow(\n", - " patched_residual_stream_diff,\n", - " x=prompt_position_labels,\n", - " title=\"Logit Difference From Patched Residual Stream\",\n", - " labels={\"x\": \"Position\", \"y\": \"Layer\"},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Layers" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can apply exactly the same idea, but this time patching in attention or MLP layers. These are also residual components with identical shapes to the residual stream terms, so we can reuse the same hooks." - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "patched_attn_diff = torch.zeros(\n", - " model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32\n", - ")\n", - "patched_mlp_diff = torch.zeros(\n", - " model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32\n", - ")\n", - "for layer in range(model.cfg.n_layers):\n", - " for position in range(tokens.shape[1]):\n", - " hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)\n", - " patched_attn_logits = model.run_with_hooks(\n", - " corrupted_tokens,\n", - " fwd_hooks=[(utils.get_act_name(\"attn_out\", layer), hook_fn)],\n", - " return_type=\"logits\",\n", - " )\n", - " patched_attn_logit_diff = logits_to_ave_logit_diff(\n", - " patched_attn_logits, answer_tokens\n", - " )\n", - " patched_mlp_logits = model.run_with_hooks(\n", - " corrupted_tokens,\n", - " fwd_hooks=[(utils.get_act_name(\"mlp_out\", layer), hook_fn)],\n", - " return_type=\"logits\",\n", - " )\n", - " 
patched_mlp_logit_diff = logits_to_ave_logit_diff(\n", - " patched_mlp_logits, answer_tokens\n", - " )\n", - "\n", - " patched_attn_diff[layer, position] = normalize_patched_logit_diff(\n", - " patched_attn_logit_diff\n", - " )\n", - " patched_mlp_diff[layer, position] = normalize_patched_logit_diff(\n", - " patched_mlp_logit_diff\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We see that several attention layers are significant but that, matching the residual stream results, early layers matter on the second subject token, and later layers matter on the final token, and layers essentially don't matter on any other token. Extremely localised! As with direct logit attribution, layer 9 is positive and layers 10 and 11 are not, suggesting that the late layers only matter for direct logit effects, but we also see that layers 7 and 8 matter significantly. Presumably these are the heads that move information about which name is duplicated from the second subject token to the final token." - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "coloraxis": "coloraxis", - "hovertemplate": "Position: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "x": [ - "<|endoftext|>_0", - "When_1", - " John_2", - " and_3", - " Mary_4", - " went_5", - " to_6", - " the_7", - " shops_8", - ",_9", - " John_10", - " gave_11", - " the_12", - " bag_13", - " to_14" - ], - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.035456884652376175, - -0.0002469856117386371, - 9.76665523921838e-06, - -0.00036458822432905436, - -4.8967522161547095e-05 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.0029848709236830473, - 7.950929284561425e-05, - 2.0842242520302534e-05, - 8.088535105343908e-05, - -0.0005967392353340983 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.0019131568260490894, - 0.0006668510613963008, - 0.00039482791908085346, - -0.0007051457650959492, - -0.00027282864903099835 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.1546323299407959, - 0.0038019807543605566, - 0.0005171628436073661, - -0.00011964991426793858, - -0.0005599213181994855 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.005406397394835949, - 0.019581740722060204, - 0.001007509301416576, - -0.0002424211270408705, - 0.0007936497568152845 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.3520970046520233, - 0.0010525835677981377, - 0.00022436455765273422, - 0.00013367898645810783, - 8.172441448550671e-05 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.11986024677753448, - 0.021243548020720482, - 0.002727783052250743, - 0.0013409851817414165, - 0.01797366514801979 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.013310473412275314, - 0.011509180068969727, - 0.00037542887730523944, - -4.094611358596012e-05, - 0.29760244488716125 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.0015009435592219234, - 0.017351653426885605, - 0.0005848917062394321, - 0.0010122752282768488, - 0.5697318911552429 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, 
- -0.00012901381705887616, - 0.00630143890157342, - 0.00014156615361571312, - 0.00031229801243171096, - 0.27152299880981445 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.0009373303619213402, - 8.669164526509121e-05, - 0.00033243544748984277, - 9.73309283835988e-07, - -0.1929796040058136 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.40617984533309937 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' John', ' gave', ' the', ' bag', ' to']\n", + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' Mary', ' gave', ' the', ' bag', ' to']\n", + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'When', ' Tom', ' and', ' James', ' went', ' to', ' the', ' park', ',', ' James', ' gave', ' the', ' ball', ' to']\n", + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'When', ' Tom', ' and', ' James', ' went', ' to', ' the', ' park', ',', ' Tom', ' gave', ' the', ' ball', ' to']\n", + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'When', ' Dan', ' and', ' Sid', ' went', ' to', ' the', ' shops', ',', ' Sid', ' gave', ' an', ' apple', ' to']\n", + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 
'When', ' Dan', ' and', ' Sid', ' went', ' to', ' the', ' shops', ',', ' Dan', ' gave', ' an', ' apple', ' to']\n", + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'After', ' Martin', ' and', ' Amy', ' went', ' to', ' the', ' park', ',', ' Amy', ' gave', ' a', ' drink', ' to']\n", + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'After', ' Martin', ' and', ' Amy', ' went', ' to', ' the', ' park', ',', ' Martin', ' gave', ' a', ' drink', ' to']\n" + ] + } + ], + "source": [ + "for prompt in prompts:\n", + " str_tokens = model.to_str_tokens(prompt)\n", + " print(\"Prompt length:\", len(str_tokens))\n", + " print(\"Prompt as tokens:\", str_tokens)" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 
0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": 
"histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - 
"marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now run the model on these prompts and use `run_with_cache` to get both the logits and a cache of all internal activations for later analysis" ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 
0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "tokens = model.to_tokens(prompts, prepend_bos=True)\n", + "\n", + "# Run the model and cache all activations\n", + "original_logits, cache = model.run_with_cache(tokens)" ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll later be evaluating how model performance differs upon performing various interventions, so it's useful to have a metric to measure model performance. Our metric here will be the **logit difference**, the difference in logit between the indirect object's name and the subject's name (eg, `logit(Mary)-logit(John)`). 
" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - 
"zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Logit Difference From Patched Attention Layer" + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Per prompt logit difference: tensor([3.3370, 3.2020, 2.7090, 3.7970, 1.7200, 5.2810, 2.6010, 5.7670])\n", + "Average logit difference: 3.552\n" + ] + } + ], + "source": [ + "def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):\n", + " # Only the final logits are relevant for the answer\n", + " final_logits = logits[:, -1, :]\n", + " answer_logits = final_logits.gather(dim=-1, index=answer_tokens)\n", + " answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]\n", + " if per_prompt:\n", + " return answer_logit_diff\n", + " else:\n", + " return answer_logit_diff.mean()\n", + "\n", + "\n", + "print(\n", + " \"Per prompt logit difference:\",\n", + " logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)\n", + " .detach()\n", + " .cpu()\n", + " .round(decimals=3),\n", + ")\n", + "original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)\n", + "print(\n", + " \"Average logit difference:\",\n", + " round(logits_to_ave_logit_diff(original_logits, answer_tokens).item(), 3),\n", + ")" + ] }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Position" - } + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that the average logit difference is 3.5 - for context, this represents putting an $e^{3.5}\\approx 33\\times$ higher probability on the correct answer. 
" + ] }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "imshow(\n", - " patched_attn_diff,\n", - " x=prompt_position_labels,\n", - " title=\"Logit Difference From Patched Attention Layer\",\n", - " labels={\"x\": \"Position\", \"y\": \"Layer\"},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In contrast, the MLP layers do not matter much. This makes sense, since this is more a task about moving information than about processing it, and the MLP layers specialise in processing information.\n", - "\n", - "The one exception is MLP 0, which matters a lot, but I think this is misleading and just a generally true statement about MLP 0 rather than being about the circuit on this task.\n", - "\n", - "
My takes on MLP0 \n", - "It's often observed on GPT-2 Small that MLP0 matters a lot, and that ablating it utterly destroys performance. My current best guess is that the first MLP layer is essentially acting as an extension of the embedding (for whatever reason) and that when later layers want to access the input tokens they mostly read in the output of the first MLP layer, rather than the token embeddings. Within this frame, the first attention layer doesn't do much. \n", - "\n", - "In this framing, it makes sense that MLP0 matters on the second subject token, because that's the one position with a different input token!\n", - "\n", - "I'm not entirely sure why this happens, but I would guess that it's because the embedding and unembedding matrices in GPT-2 Small are the same. This is pretty unprincipled, as the tasks of embedding and unembedding tokens are not inverses, but this is common practice, and plausibly models want to dedicate some parameters to overcoming this. \n", - "\n", - "I only have suggestive evidence of this, and would love to see someone look into this properly!\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "coloraxis": "coloraxis", - "hovertemplate": "Position: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "x": [ - "<|endoftext|>_0", - "When_1", - " John_2", - " and_3", - " Mary_4", - " went_5", - " to_6", - " the_7", - " shops_8", - ",_9", - " John_10", - " gave_11", - " the_12", - " bag_13", - " to_14" - ], - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.8507890701293945, - -0.00027843358111567795, - -7.293107046280056e-05, - -0.00047373308916576207, - 4.0039929444901645e-05 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.008863994851708412, - 0.000222149450564757, - 0.00014938619278836995, - -4.853121208725497e-05, - 0.000304041663184762 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.013550343923270702, - 5.86334899708163e-05, - -0.0003296833310741931, - -0.0006382559076882899, - 0.0007730424986220896 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.0019468198297545314, - 0.0004995090421289206, - 0.00017318192112725228, - 0.00016871812113095075, - 0.00040764876757748425 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.019787074998021126, - 0.004128609783947468, - -4.86990247736685e-05, - -0.00017019486404024065, - 0.0007914346642792225 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.09652391821146011, - -0.0018826150335371494, - -0.0004844730719923973, - 0.0007094081956893206, - -0.00018335132335778326 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.015900013968348503, - -0.0008501688134856522, - 0.00012337534280959517, - 2.7521158699528314e-05, - -0.007238299585878849 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.010360540822148323, - 0.0031509376130998135, - 0.0005309234256856143, - 0.0002361114020459354, - 0.008496351540088654 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.012533102184534073, - 2.201692586822901e-05, - -0.00035374757135286927, - 8.615465048933402e-05, - -0.021631328389048576 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 
0, - 0, - 0, - -0.00033465056912973523, - 0.0008094912045635283, - 1.6244195649051107e-05, - 0.00012924875773023814, - 0.03162466362118721 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.0013599144294857979, - -0.00019499746849760413, - -9.934466652339324e-05, - -0.00014217027637641877, - 0.028764141723513603 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.02044912613928318 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Brainstorm What's Actually Going On (Optional)\n", + "\n", + "Before diving into running experiments, it's often useful to spend some time actually reasoning about how the behaviour in question could be implemented in the transformer. **This is optional, and you'll likely get the most out of engaging with this section if you have a decent understanding already of what a transformer is and how it works!**\n", + "\n", + "You don't have to do this and forming hypotheses after exploration is also reasonable, but I think it's often easier to explore and interpret results with some grounding in what you might find. In this particular case, I'm cheating somewhat, since I know the answer, but I'm trying to simulate the process of reasoning about it!\n", + "\n", + "Note that often your hypothesis will be wrong in some ways and often be completely off. We're doing science here, and the goal is to understand how the model *actually* works, and to form true beliefs! 
There are two separate traps here at two extremes that it's worth tracking:\n", + "* Confusion: Having no hypotheses at all, getting a lot of data and not knowing what to do with it, and just floundering around\n", + "* Dogmatism: Being overconfident in an incorrect hypothesis and being unwilling to let go of it when reality contradicts you, or flinching away from running the experiments that might disconfirm it.\n", + "\n", + "**Exercise:** Spend some time thinking through how you might imagine this behaviour being implemented in a transformer. Try to think through this for yourself before reading through my thoughts! \n", + "\n", + "
(*) My reasoning\n", + "\n", + "

Brainstorming:

\n", + "\n", + "So, what's hard about the task? Let's focus on the concrete example of the first prompt, \"When John and Mary went to the shops, John gave the bag to\" -> \" Mary\". \n", + "\n", + "A good starting point is thinking through whether a tiny model could do this, eg a 1L Attn-Only model. I'm pretty sure the answer is no! Attention is really good at the primitive operations of looking nearby, or copying information. I can believe a tiny model could figure out that at `to` it should look for names and predict that those names come next (eg the skip trigram \" John...to -> John\"). But it's much harder to tell how many of each previous name there are - attending 0.3 to each copy of John will look exactly the same as attending 0.6 to a single John token. So this will be pretty hard to figure out on the \" to\" token!\n", + "\n", + "The natural place to break this symmetry is on the second \" John\" token - telling whether there is an earlier copy of the current token should be a much easier task. So I might expect there to be a head which detects duplicate tokens on the second \" John\" token, and then another head which moves that information from the second \" John\" token to the \" to\" token. \n", + "\n", + "The model then needs to learn to predict \" Mary\" and not \" John\". I can see two natural ways to do this: \n", + "1. Detect all preceding names and move this information to \" to\" and then delete any name corresponding to the duplicate token feature. This feels easier done with a non-linearity, since precisely cancelling out vectors is hard, so I'd imagine an MLP layer deletes the \" John\" direction of the residual stream\n", + "2. Have a head which attends to all previous names, but where the duplicate token features inhibit it from attending to specific names. So this only attends to Mary. And then the output of this head maps to the logits. \n", + "\n", + "(Spoiler: It's the second one).\n", + "\n", + "

Experiment Ideas

\n", + "\n", + "A test that could distinguish these two is to look at which components of the model add directly to the logits - if it's mostly attention heads which attend to \" Mary\" and to neither \" John\" token, it's probably hypothesis 2; if it's mostly MLPs, it's probably hypothesis 1.\n", + "\n", + "And we should be able to identify duplicate token heads by finding ones which attend from \" John\" to \" John\", and whose outputs are then moved to the \" to\" token by V-Composition with another head (Spoiler: It's more complicated than that!)\n", + "\n", + "Note that all of the above reasoning is very simplistic and could easily break in a real model! There'll be significant parts of the model that figure out whether to use this circuit at all (we don't want to inhibit duplicated names when, eg, figuring out what goes at the start of the next sentence), and there may be parts towards the end of the model that do \"post-processing\" just before the final output. But it's a good starting point for thinking about what's going on." 
+ ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 
0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - 
"parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": 
{ - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Direct Logit Attribution" ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "*Look up unfamiliar terms in the [mech interp explainer](https://neelnanda.io/glossary)*\n", + "\n", + "Further, the easiest part of the model to understand is the output - this is what the model is trained to optimize, and so it can always be directly interpreted! Often the right approach to reverse engineering a circuit is to start at the end, understand how the model produces the right answer, and to then work backwards. The main technique used to do this is called **direct logit attribution**\n", + "\n", + "**Background:** The central object of a transformer is the **residual stream**. 
This is the sum of the outputs of each layer and of the original token and positional embedding. Importantly, this means that any linear function of the residual stream can be perfectly decomposed into the contribution of each layer of the transformer. Further, each attention layer's output can be broken down into the sum of the output of each head (See [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html) for details), and each MLP layer's output can be broken down into the sum of the output of each neuron (and a bias term for each layer). \n", + "\n", + "The logits of a model are `logits=Unembed(LayerNorm(final_residual_stream))`. The Unembed is a linear map, and LayerNorm is approximately a linear map, so we can decompose the logits into the sum of the contributions of each component, and look at which components contribute the most to the logit of the correct token! This is called **direct logit attribution**. Here we look at the direct attribution to the logit difference!\n", + "\n", + "
(*) Background and motivation of the logit difference\n", + "\n", + "Logit difference is actually a *really* nice and elegant metric and is a particularly nice aspect of the setup of Indirect Object Identification. In general, there are two natural ways to interpret the model's outputs: the output logits, or the output log probabilities (or probabilities). \n", + "\n", + "The logits are much nicer and easier to understand, as noted above. However, the model is trained to optimize the cross-entropy loss (the average of log probability of the correct token). This means it does not directly optimize the logits, and indeed if the model adds an arbitrary constant to every logit, the log probabilities are unchanged. \n", + "\n", + "But `log_probs == logits.log_softmax(dim=-1) == logits - logsumexp(logits)`, and so `log_probs(\" Mary\") - log_probs(\" John\") = logits(\" Mary\") - logits(\" John\")` - the ability to add an arbitrary constant cancels out!\n", + "\n", + "Further, the metric helps us isolate the precise capability we care about - figuring out *which* name is the Indirect Object. There are many other components of the task - deciding whether to return an article (the) or pronoun (her) or name, realising that the sentence wants a person next at all, etc. By taking the logit difference we control for all of that.\n", + "\n", + "Our metric is further refined, because each prompt is repeated twice, for each possible indirect object. This controls for irrelevant behaviour such as the model learning that John is a more frequent token than Mary (this actually happens! The final layernorm bias increases the John logit by 1 relative to the Mary logit)\n", + "\n", + "
\n", + "\n", + "
Ignoring LayerNorm\n", + "\n", + "LayerNorm is an analogous normalization technique to BatchNorm (that's friendlier to massive parallelization) that transformers use. Every time a transformer layer reads information from the residual stream, it applies a LayerNorm to normalize the vector at each position (translating to set the mean to 0 and scaling to set the variance to 1) and then applies a learned vector of weights and biases to scale and translate the normalized vector. This is *almost* a linear map, apart from the scaling step, because that divides by the norm of the vector and the norm is not a linear function. (The `fold_ln` flag when loading a model factors out all the linear parts).\n", + "\n", + "But if we fixed the scale factor, the LayerNorm would be fully linear. And the scale of the residual stream is a global property that's a function of *all* components of the stream, while in practice there are normally just a few directions relevant to any particular component, so in practice this is an acceptable approximation. So when doing direct logit attribution we use the `apply_ln` flag on the `cache` to apply the global layernorm scaling factor to each component. See [my clean GPT-2 implementation](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/clean-transformer-demo/Clean_Transformer_Demo.ipynb#scrollTo=Clean_Transformer_Implementation) for more on LayerNorm.\n", + "
" ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Getting an output logit is equivalent to projecting onto a direction in the residual stream. We use `model.tokens_to_residual_directions` to map the answer tokens to that direction, and then convert this to a logit difference direction for each batch" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - 
"ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Logit Difference From Patched MLP Layer" + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Answer residual directions shape: torch.Size([8, 2, 768])\n", + "Logit difference directions shape: torch.Size([8, 768])\n" + ] + } + ], + "source": [ + "answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)\n", + "print(\"Answer residual directions shape:\", answer_residual_directions.shape)\n", + "logit_diff_directions = (\n", + " answer_residual_directions[:, 0] - answer_residual_directions[:, 1]\n", + ")\n", + "print(\"Logit difference directions shape:\", logit_diff_directions.shape)" + ] }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Position" - } + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To verify that this works, we can apply this to the final residual stream for our cached prompts (after applying LayerNorm scaling) and verify that we get the same answer. \n", + "\n", + "
Technical details\n", + "\n", + "`logits = Unembed(LayerNorm(final_residual_stream))`, so we technically need to account for the centering, and then learned translation and scaling of the layernorm, not just the variance 1 scaling. \n", + "\n", + "The centering is accounted for with the preprocessing flag `center_writing_weights` which ensures that every weight matrix writing to the residual stream has mean zero. \n", + "\n", + "The learned scaling is folded into the unembedding weights `model.unembed.W_U` via `W_U_fold = layer_norm.weights[:, None] * unembed.W_U`\n", + "\n", + "The learned translation is folded to `model.unembed.b_U`, a bias added to the logits (note that GPT-2 is not trained with an existing `b_U`). This roughly represents unigram statistics. But we can ignore this because each prompt occurs twice with names in the opposite order, so this perfectly cancels out. \n", + "\n", + "Note that rather than using layernorm scaling we could just study cache[\"ln_final.hook_normalised\"]\n", + "\n", + "
" + ] }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "imshow(\n", - " patched_mlp_diff,\n", - " x=prompt_position_labels,\n", - " title=\"Logit Difference From Patched MLP Layer\",\n", - " labels={\"x\": \"Position\", \"y\": \"Layer\"},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Heads\n", - "\n", - "We can refine the above analysis by patching in individual heads! This is somewhat more annoying, because there are now three dimensions (head_index, position and layer), so for now lets patch in a head's output across all positions.\n", - "\n", - "The easiest way to do this is to patch in the activation `z`, the \"mixed value\" of the attention head. That is, the average of all previous values weighted by the attention pattern, ie the activation that is then multiplied by `W_O`, the output weights. 
" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [], - "source": [ - "def patch_head_vector(\n", - " corrupted_head_vector: Float[torch.Tensor, \"batch pos head_index d_head\"],\n", - " hook,\n", - " head_index,\n", - " clean_cache,\n", - "):\n", - " corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][\n", - " :, :, head_index, :\n", - " ]\n", - " return corrupted_head_vector\n", - "\n", - "\n", - "patched_head_z_diff = torch.zeros(\n", - " model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32\n", - ")\n", - "for layer in range(model.cfg.n_layers):\n", - " for head_index in range(model.cfg.n_heads):\n", - " hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)\n", - " patched_logits = model.run_with_hooks(\n", - " corrupted_tokens,\n", - " fwd_hooks=[(utils.get_act_name(\"z\", layer, \"attn\"), hook_fn)],\n", - " return_type=\"logits\",\n", - " )\n", - " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", - "\n", - " patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(\n", - " patched_logit_diff\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can now see that, in addition to the name mover heads identified before, in mid-late layers the heads L8H6, L8H10, L7H9 matter and are presumably responsible for moving information from the second subject to the final token. And heads L5H5, L6H9, L3H0 also matter a lot, and are presumably involved in detecting duplicated tokens." - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0.0009487751522101462, - 0.016124747693538666, - 0.0018548924708738923, - 0.0034389030188322067, - -0.00982347596436739, - 0.011058605276048183, - -0.004063969012349844, - -0.0015792781487107277, - -0.0012082795146852732, - 0.003828897839412093, - -0.004256919026374817, - -0.0011422622483223677 - ], - [ - -0.0010771177476271987, - -0.00037898647133260965, - 2.5171791548928013e-06, - -0.00026067905128002167, - -0.00014146546891424805, - 0.0038321535103023052, - -0.0004293300735298544, - -0.00142992555629462, - -0.0009228314156644046, - 0.0006944393389858305, - 0.00043302192352712154, - -0.0035714071709662676 - ], - [ - -0.0004967569257132709, - 0.0008057993836700916, - 0.0005424688570201397, - -0.0005309234256856143, - -0.0007159864180721343, - -0.0010389237431809306, - -0.0009490771917626262, - -8.649027586216107e-05, - 0.0002766547549981624, - 0.0021084228064864874, - -0.0001975146442418918, - -0.0016405630158260465 - ], - [ - 0.1162627637386322, - 0.0002507446042727679, - -0.0014675153652206063, - -0.00039680811460129917, - 0.018962211906909943, - -0.00018764731066767126, - 0.011170871555805206, - -0.0013301445869728923, - -0.0007356539717875421, - -0.00030253134900704026, - -0.00014683544577565044, - -0.00022228369198273867 - ], - [ - -0.001650598249398172, - 0.0002927311579696834, - -0.00143563118763268, - 0.03084198758006096, - -0.007432155776768923, - -0.00028236035723239183, - 0.006017433945089579, - -0.011007187888026237, - -0.001266107545234263, - 0.0014901700196787715, - -0.0001800622121663764, - 0.002944394713267684 - ], - [ - -0.004211106337606907, - 0.0029597999528050423, - 0.002045023487880826, - 0.0013397098518908024, - -0.0012190865818411112, - 0.34349915385246277, - 0.0005632104002870619, - -0.0001262281439267099, - -0.00515326950699091, - 0.016240738332271576, - 0.01709030382335186, - -0.004175194539129734 - ], - [ - 0.039775289595127106, - 
0.015226684510707855, - -0.0010229480685666203, - 0.0008072761120274663, - -0.004935584031045437, - -0.002123525831848383, - -0.014274083077907562, - 0.0013746818294748664, - 0.0014838266652077436, - 0.1302703619003296, - -0.00033616088330745697, - 0.0012919505825266242 - ], - [ - 0.00037177055492065847, - 0.019514480605721474, - 0.00022255218937061727, - 0.124249167740345, - -0.00040352059295400977, - -0.007652895525097847, - 0.0013010123511776328, - -0.0011253133416175842, - -0.007449474185705185, - 0.19224143028259277, - -0.003275118535384536, - -0.0005017912480980158 - ], - [ - -0.001007912098430097, - 3.091096004936844e-05, - -0.0008595998515374959, - 0.012359987013041973, - -0.0004041247011628002, - -0.004328910261392593, - 0.3185553252696991, - 0.002330605871975422, - 0.0021182901691645384, - 0.0001405928487656638, - 0.2779357433319092, - 0.005738262087106705 - ], - [ - 0.0058898297138512135, - -0.0009689796715974808, - 0.00912561360746622, - 0.020675739273428917, - -0.03700518235564232, - 0.014263041317462921, - -0.04828466475009918, - 0.05834139883518219, - 0.0006514795240946114, - 0.26360899209976196, - 0.0004918567719869316, - -0.00261044898070395 - ], - [ - 0.08374208211898804, - 0.020676210522651672, - -0.003743582172319293, - 0.01085072010755539, - -0.001096583902835846, - 0.00047430366976186633, - 0.04818058758974075, - -0.4799128472805023, - 0.00018429107149131596, - 0.011861988343298435, - 0.06088569387793541, - 0.0008461413672193885 - ], - [ - 0.005328264087438583, - -0.011493473313748837, - -0.11350836604833603, - 0.006329597905278206, - 0.00031669469899497926, - -0.0011600167490541935, - -0.022669579833745956, - 0.004070379305630922, - 0.0073160636238753796, - -0.00834545586258173, - -0.27817651629447937, - 0.0036344374530017376 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - 
[ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Final residual stream shape: torch.Size([8, 15, 768])\n", + "Calculated average logit diff: 3.552\n", + "Original logit difference: 3.552\n" + ] + } + ], + "source": [ + "# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].\n", + "final_residual_stream = cache[\"resid_post\", -1]\n", + "print(\"Final residual stream shape:\", final_residual_stream.shape)\n", + "final_token_residual_stream = final_residual_stream[:, -1, :]\n", + "# Apply LayerNorm scaling\n", + "# pos_slice is the subset of the positions we take - here the final token of each prompt\n", + "scaled_final_token_residual_stream = cache.apply_ln_to_stack(\n", + " final_token_residual_stream, layer=-1, pos_slice=-1\n", + ")\n", + "\n", + "average_logit_diff = einsum(\n", + " \"batch d_model, batch d_model -> \",\n", + " scaled_final_token_residual_stream,\n", + " logit_diff_directions,\n", + ") / len(prompts)\n", + "print(\"Calculated average logit diff:\", round(average_logit_diff.item(), 3))\n", + "print(\"Original logit difference:\", round(original_average_logit_diff.item(), 3))" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": 
"overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], 
- [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - 
"outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - 
"diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Logit Lens" ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now decompose the residual stream! First we apply a technique called the [**logit lens**](https://www.alignmentforum.org/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens) - this looks at the residual stream after each layer and calculates the logit difference from that. This simulates what happens if we delete all subsequent layers. 
" ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "def residual_stack_to_logit_diff(\n", + " residual_stack: Float[torch.Tensor, \"components batch d_model\"],\n", + " cache: ActivationCache,\n", + ") -> float:\n", + " scaled_residual_stack = cache.apply_ln_to_stack(\n", + " residual_stack, layer=-1, pos_slice=-1\n", + " )\n", + " return einsum(\n", + " \"... batch d_model, batch d_model -> ...\",\n", + " scaled_residual_stack,\n", + " logit_diff_directions,\n", + " ) / len(prompts)" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": 
"white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Logit Difference From Patched Head Output" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Fascinatingly, we see that the model is utterly unable to do the task until layer 7, almost all performance comes from attention layer 9, and performance actually *decreases* from there.\n", + "\n", + "**Note:** Hover over each data point to see what residual stream position it's from!\n", + "\n", + "
Details on `accumulated_resid`\n", + "**Key:** `n_pre` means the residual stream at the start of layer n, `n_mid` means the residual stream after the attention part of layer n (`n_post` is the same as `n+1_pre` so is not included)\n", + "\n", + "* `layer` is the layer for which we input the residual stream (this is used to identify *which* layer norm scaling factor we want)\n", + "* `incl_mid` is whether to include the residual stream in the middle of a layer, ie after attention & before MLP\n", + "* `pos_slice` is the subset of the positions used. See `utils.Slice` for details on the syntax.\n", + "* return_labels is whether to return the labels for each component returned (useful for plotting)\n", + "
" + ] }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hovertemplate": "%{hovertext}

x=%{x}
y=%{y}", + "hovertext": [ + "0_pre", + "0_mid", + "1_pre", + "1_mid", + "2_pre", + "2_mid", + "3_pre", + "3_mid", + "4_pre", + "4_mid", + "5_pre", + "5_mid", + "6_pre", + "6_mid", + "7_pre", + "7_mid", + "8_pre", + "8_mid", + "9_pre", + "9_mid", + "10_pre", + "10_mid", + "11_pre", + "11_mid", + "final_post" + ], + "legendgroup": "", + "line": { + "color": "#636efa", + "dash": "solid" + }, + "marker": { + "symbol": "circle" + }, + "mode": "lines", + "name": "", + "orientation": "v", + "showlegend": false, + "type": "scatter", + "x": [ + 0, + 0.5, + 1, + 1.5, + 2, + 2.5, + 3, + 3.5, + 4, + 4.5, + 5, + 5.5, + 6, + 6.5, + 7, + 7.5, + 8, + 8.5, + 9, + 9.5, + 10, + 10.5, + 11, + 11.5, + 12 + ], + "xaxis": "x", + "y": [ + 0.000012937933206558228, + -0.006643360480666161, + -0.007525032386183739, + -0.009075596928596497, + -0.008736769668757915, + -0.008685456588864326, + -0.006480347365140915, + -0.007939882576465607, + -0.009661720134317875, + -0.015095856040716171, + -0.01419061329215765, + -0.019930001348257065, + -0.00912435818463564, + -0.027298055589199066, + -0.02985510788857937, + 0.2497255504131317, + 0.250558078289032, + 0.45005205273628235, + 0.45996904373168945, + 5.02545166015625, + 5.142900466918945, + 4.730565071105957, + 4.887058258056641, + 3.445383071899414, + 3.5518720149993896 + ], + "yaxis": "y" + } + ], + "layout": { + "legend": { + "tracegroupgap": 0 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", 
+ "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + 
"#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": 
"scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + 
"#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", 
+ "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Logit Difference From Accumulate Residual Stream" + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "title": { + "text": "x" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "title": { + "text": "y" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "accumulated_residual, labels = cache.accumulated_resid(\n", + " layer=-1, incl_mid=True, pos_slice=-1, return_labels=True\n", + ")\n", + "logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)\n", + "line(\n", + " logit_lens_logit_diffs,\n", + " x=np.arange(model.cfg.n_layers * 2 + 1) / 2,\n", + " hover_name=labels,\n", + " title=\"Logit Difference From Accumulate Residual Stream\",\n", + ")" + ] }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - 
"output_type": "display_data" - } - ], - "source": [ - "imshow(\n", - " patched_head_z_diff,\n", - " title=\"Logit Difference From Patched Head Output\",\n", - " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Decomposing Heads" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Decomposing attention layers into patching in individual heads has already helped us localise the behaviour a lot. But we can understand it further by decomposing heads. An attention head consists of two semi-independent operations - calculating *where* to move information from and to (represented by the attention pattern and implemented via the QK-circuit) and calculating *what* information to move (represented by the value vectors and implemented by the OV circuit). We can disentangle which of these is important by patching in just the attention pattern *or* the value vectors. (See [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) or [my walkthrough video](https://www.youtube.com/watch?v=KV5gbOmHbjU) for more on this decomposition. If you're not familiar with the details of how attention is implemented, I recommend checking out [my clean transformer implementation](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/clean-transformer-demo/Clean_Transformer_Demo.ipynb#scrollTo=3Pb0NYbZ900e) to see how the code works))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First let's patch in the value vectors, to measure when figuring out what to move is important. . This has the same shape as z ([batch, pos, head_index, d_head]) so we can reuse the same hook." 
- ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [], - "source": [ - "patched_head_v_diff = torch.zeros(\n", - " model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32\n", - ")\n", - "for layer in range(model.cfg.n_layers):\n", - " for head_index in range(model.cfg.n_heads):\n", - " hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)\n", - " patched_logits = model.run_with_hooks(\n", - " corrupted_tokens,\n", - " fwd_hooks=[(utils.get_act_name(\"v\", layer, \"attn\"), hook_fn)],\n", - " return_type=\"logits\",\n", - " )\n", - " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", - "\n", - " patched_head_v_diff[layer, head_index] = normalize_patched_logit_diff(\n", - " patched_logit_diff\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can plot this as a heatmap and it's initially hard to interpret." - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - -0.00019892427371814847, - 0.005339574534446001, - 0.0006527548539452255, - 0.003504416672512889, - -0.00898387935012579, - 0.0034814265090972185, - -0.0008631910313852131, - -3.406582254683599e-05, - 0.0005166929331608117, - 0.00044255363172851503, - -0.0039068968035280704, - -0.0001880836207419634 - ], - [ - -0.0004399022145662457, - -0.00044510437874123454, - -6.73597096465528e-05, - 7.242763240355998e-05, - -3.6549441574607044e-05, - -0.0019323208834975958, - -0.0001572397886775434, - 1.6143509128596634e-05, - 0.00020593880617525429, - 0.000336798548232764, - 0.0003515324497129768, - -0.0005669358652085066 - ], - [ - 0.00021013410878367722, - -0.0007199132232926786, - 0.0004868560063187033, - -0.0005974104860797524, - -0.0005921411793678999, - -0.0005443819100037217, - -0.000227552984142676, - -0.0004809825913980603, - 0.00020570388005580753, - 0.001183376181870699, - -0.0003574058646336198, - -0.0009104468626901507 - ], - [ - 0.0010395278222858906, - -0.00012042184971505776, - -7.762980385450646e-05, - -0.0007275318494066596, - -0.001310007064603269, - -0.0023108376190066338, - 0.010987084358930588, - -5.0712766096694395e-05, - 0.00014314358122646809, - 0.00015069512301124632, - -7.957642083056271e-05, - -2.0238119759596884e-05 - ], - [ - -0.0005373673629947007, - -0.0008137872209772468, - -0.00013334336108528078, - 0.030609702691435814, - -0.007185807917267084, - 0.000148916311445646, - 0.0013340713921934366, - -0.01142292469739914, - -0.0005336419562809169, - 0.0005126654868945479, - 0.00037344868178479373, - 0.0029547319281846285 - ], - [ - 8.22278525447473e-06, - 6.477540864580078e-06, - 0.0015973682748153806, - 0.00034015480196103454, - -0.0012577504385262728, - -5.450531898532063e-05, - 0.0006331544718705118, - -0.00027081489679403603, - 7.427356467815116e-05, - -0.006704355590045452, - 0.003175975289195776, - -0.0017300404142588377 - ], - [ - 
0.04863045737147331, - 0.015314852818846703, - -0.0004648726317100227, - -0.00011676354915834963, - -4.930314753437415e-05, - -0.003952810075134039, - -0.01737578585743904, - -0.00015421917487401515, - 0.0012194222072139382, - -0.00018090127559844404, - -0.00042647725786082447, - 0.00012334177154116333 - ], - [ - -2.956846401502844e-05, - -0.0013855225406587124, - -0.00012129446986364201, - 0.1332160234451294, - -0.00024490474606864154, - -0.007315828464925289, - 0.00033297244226559997, - -0.000795092957559973, - -0.007938209921121597, - 0.208413764834404, - -0.00019127204723190516, - -0.00020650937221944332 - ], - [ - -0.0020483459811657667, - -0.0003764357534237206, - -0.0033135139383375645, - -0.009666135534644127, - -0.00031723169377073646, - -0.005141589790582657, - 0.31717124581336975, - 0.0028427678626030684, - 0.0004723234742414206, - -0.0011529687326401472, - 0.2726709246635437, - -0.003175639547407627 - ], - [ - -0.00043929810635745525, - 5.7089622714556754e-05, - -0.0020629793871194124, - 0.020066648721694946, - -0.007871017791330814, - 0.011316264048218727, - 0.003056862158700824, - 0.06856372952461243, - -0.002747517777606845, - -0.009279227815568447, - 0.000506624230183661, - -0.0013159140944480896 - ], - [ - -0.012957162223756313, - -0.0030454176012426615, - -0.01792328804731369, - -0.0043589151464402676, - -0.0011521632550284266, - 0.0004999117809347808, - -0.0031131464056670666, - 0.019585633650422096, - 4.34632929682266e-05, - 0.01297028549015522, - -0.007695754989981651, - -0.0009146086522378027 - ], - [ - 0.004100752994418144, - -0.020459463819861412, - -0.035875942558050156, - 0.014656225219368935, - 0.0008441276149824262, - 0.0017804511589929461, - -0.01804223284125328, - 0.003519016318023205, - 0.008253024891018867, - -0.0017665562918409705, - 0.044167667627334595, - 0.006474285386502743 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, 
- "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Layer Attribution" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - 
} - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 
0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 
0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can repeat the above analysis but for each layer (this is equivalent to the differences between adjacent residual streams)\n", + "\n", + "Note: Annoying terminology overload - layer k of a transformer means the kth **transformer block**, but each block consists of an **attention layer** (to move information around) *and* an **MLP layer** (to process information). \n", + "\n", + "We see that only attention layers matter, which makes sense! The IOI task is about moving information around (ie moving the correct name and not the incorrect name), and less about processing it. 
And again we note that attention layer 9 improves things a lot, while attention 10 and attention 11 *decrease* performance" ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hovertemplate": "%{hovertext}

x=%{x}
y=%{y}", + "hovertext": [ + "embed", + "pos_embed", + "0_attn_out", + "0_mlp_out", + "1_attn_out", + "1_mlp_out", + "2_attn_out", + "2_mlp_out", + "3_attn_out", + "3_mlp_out", + "4_attn_out", + "4_mlp_out", + "5_attn_out", + "5_mlp_out", + "6_attn_out", + "6_mlp_out", + "7_attn_out", + "7_mlp_out", + "8_attn_out", + "8_mlp_out", + "9_attn_out", + "9_mlp_out", + "10_attn_out", + "10_mlp_out", + "11_attn_out", + "11_mlp_out" + ], + "legendgroup": "", + "line": { + "color": "#636efa", + "dash": "solid" + }, + "marker": { + "symbol": "circle" + }, + "mode": "lines", + "name": "", + "orientation": "v", + "showlegend": false, + "type": "scatter", + "x": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25 + ], + "xaxis": "x", + "y": [ + -0.00028366726473905146, + 0.00029660604195669293, + -0.0066563040018081665, + -0.0008816685294732451, + -0.0015505650080740452, + 0.00033882574643939734, + 0.00005131529178470373, + 0.0022051138803362846, + -0.0014595506945624948, + -0.0017218313878402114, + -0.005434143822640181, + 0.0009052485693246126, + -0.0057394010946154594, + 0.010805649682879448, + -0.018173698335886, + -0.002557049971073866, + 0.27958065271377563, + 0.0008325176313519478, + 0.19949400424957275, + 0.00991708692163229, + 4.565483093261719, + 0.11744903028011322, + -0.4123360514640808, + 0.15649384260177612, + -1.4416757822036743, + 0.10648896545171738 + ], + "yaxis": "y" + } + ], + "layout": { + "legend": { + "tracegroupgap": 0 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + 
"solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + 
"#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": 
"" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], 
+ [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + 
"ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Logit Difference From Each Layer" + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "title": { + "text": "x" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "title": { + "text": "y" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "per_layer_residual, labels = cache.decompose_resid(\n", + " layer=-1, pos_slice=-1, return_labels=True\n", + ")\n", + "per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)\n", + "line(per_layer_logit_diffs, hover_name=labels, title=\"Logit Difference From Each Layer\")" ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - 
[ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Head Attribution" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, 
- "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Logit Difference From Patched Head Value" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can further break down the output of each attention layer into the sum of the outputs of each attention head. Each attention layer consists of 12 heads, which each act independently and additively.\n", + "\n", + "
Decomposing attention output into sums of heads \n", + "The standard way to compute the output of an attention layer is by concatenating the mixed values of each head, and multiplying by a big output weight matrix. But as described in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) this is equivalent to splitting the output weight matrix into a per-head output (here `model.blocks[k].attn.W_O`) and adding them up (including an overall bias term for the entire layer)\n", + "
\n", + "\n", + "We see that only a few heads really matter - heads L9H6 and L9H9 contribute a lot positively (explaining why attention layer 9 is so important), while heads L10H7 and L11H10 contribute a lot negatively (explaining why attention layer 10 and layer 11 are actively harmful). These correspond to (some of) the name movers and negative name movers discussed in the paper. There are also several heads that matter positively or negatively but less strongly (other name movers and backup name movers).\n", + "\n", + "There are a few meta observations worth making here - our model has 144 heads, yet we could localise this behaviour to a handful of specific heads, using straightforward, general techniques. This supports the claim in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) that attention heads are the right level of abstraction to understand attention. It is also really surprising that there are *negative* heads - e.g. L10H7 makes the incorrect logit 7x *more* likely. I'm not sure what's going on there, though the paper discusses some possibilities." + ] }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tried to stack head results when they weren't cached. Computing head results now\n" + ] + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.0020563392899930477, + -0.0005101899732835591, + 0.0004685786843765527, + 0.00012512074317783117, + -0.0006028738571330905, + -0.0002429460291750729, + -0.0023189077619463205, + -0.002758360467851162, + 0.000564602785743773, + 0.0009697531932033598, + -0.0002504526637494564, + 0.000004737317794933915 + ], + [ + -0.0010070882271975279, + 0.00039470894262194633, + -0.00154874159488827, + 0.0014034928753972054, + -0.0012653048615902662, + -0.0011358022456988692, + -0.00281596090644598, + -0.0029645217582583427, + 0.0029190476052463055, + 0.0025743592996150255, + 0.00036239007022231817, + 0.0017548729665577412 + ], + [ + 0.0005569400964304805, + -0.001126631861552596, + -0.0017353934235870838, + -0.0014514457434415817, + -0.00028735760133713484, + 0.0017211002996191382, + 0.0026658899150788784, + 0.00311466702260077, + 0.0005667927907779813, + -0.003666515462100506, + -0.0018847601022571325, + 0.000007039372576400638 + ], + [ + -0.0007264417363330722, + 0.00011364505917299539, + 0.0014301587361842394, + 0.0007490540738217533, + 0.0020184689201414585, + 0.0007436950691044331, + -0.00046178390039131045, + -0.0039057559333741665, + 0.0011406694538891315, + -0.00004022853681817651, + -0.0013293239753693342, + -0.0017636751290410757 + ], + [ + -0.0028280913829803467, + 0.00033634810824878514, + -0.0014248639345169067, + -0.003777273464947939, + 0.0015998880844563246, + 0.0002989505883306265, + -0.000804675742983818, + 0.002038792008534074, + -0.0015593919670209289, + -0.0006436670082621276, + 0.0011168173514306545, + -0.00035012533771805465 + ], + [ + 0.0011338205076754093, + 0.0011259170714765787, + -0.002516670385375619, + -0.0014790185960009694, + 0.0003878737334161997, + -0.00006408110493794084, + -0.0005096744280308485, + -0.0008840755908749998, + 0.0006398351397365332, + -0.0010097370250150561, + -0.006759158335626125, + 0.0033667823299765587 + ], + [ + 
-0.01514742337167263, + -0.0021350777242332697, + 0.002593174111098051, + -0.00042678468162193894, + -0.005558924749493599, + 0.0026658528950065374, + 0.006411008536815643, + -0.003826778382062912, + -0.0003843410813715309, + -0.0016430341638624668, + -0.0013344454346224666, + -0.0000920506427064538 + ], + [ + -0.00009476230479776859, + -0.0057889921590685844, + -0.0006383581785485148, + 0.13493388891220093, + -0.001768707763403654, + -0.018917907029390335, + 0.003873429261147976, + -0.0021450775675475597, + -0.010327338241040707, + 0.18325845897197723, + -0.0007747983909212053, + -0.00104526337236166 + ], + [ + -0.003833949100226164, + -0.0008046097937040031, + -0.012673400342464447, + 0.00804573018103838, + 0.003604492638260126, + -0.009398287162184715, + -0.08272082358598709, + 0.003555194940418005, + -0.018404025584459305, + 0.0017587244510650635, + 0.2896133363246918, + 0.022854052484035492 + ], + [ + 0.08595258742570877, + -0.0006932877004146576, + 0.06817055493593216, + 0.013111240230500698, + -0.021098043769598007, + 0.05112447217106819, + 1.3844914436340332, + 0.045836858451366425, + -0.03830280900001526, + 2.985445976257324, + 0.0019662054255604744, + -0.008030137047171593 + ], + [ + 0.5608693957328796, + 0.17083050310611725, + -0.03361757844686508, + 0.05821544677019119, + -0.0024530249647796154, + 0.0018771197646856308, + 0.28827205300331116, + -1.8986485004425049, + -0.0015286931302398443, + -0.035129792988300323, + 0.4802178740501404, + -0.0009115453576669097 + ], + [ + 0.016075748950242996, + -0.03986122086644173, + -0.3879126012325287, + 0.011123123578727245, + -0.005477819126099348, + -0.0025129620917141438, + -0.08056175708770752, + 0.007518616039305925, + 0.0430111438035965, + -0.040082238614559174, + -0.9702364802360535, + 0.011862239800393581 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + 
"rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + 
], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + 
[ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 
0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + 
"#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Logit Difference From Each Head" + }, + "xaxis": { + "anchor": 
"y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "per_head_residual, labels = cache.stack_head_results(\n", + " layer=-1, pos_slice=-1, return_labels=True\n", + ")\n", + "per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)\n", + "per_head_logit_diffs = einops.rearrange(\n", + " per_head_logit_diffs,\n", + " \"(layer head_index) -> layer head_index\",\n", + " layer=model.cfg.n_layers,\n", + " head_index=model.cfg.n_heads,\n", + ")\n", + "imshow(\n", + " per_head_logit_diffs,\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", + " title=\"Logit Difference From Each Head\",\n", + ")" + ] }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "imshow(\n", - " patched_head_v_diff,\n", - " title=\"Logit Difference From Patched Head Value\",\n", - " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "But it's very easy to interpret if we plot a scatter plot against patching head outputs. Here we see that the earlier heads (L5H5, L6H9, L3H0) and late name movers (L9H9, L10H7, L11H10) don't matter at all now, while the mid-late heads (L8H6, L8H10, L7H9) do. \n", - "\n", - "Meta lesson: Plot things early, often and in diverse ways as you explore a model's internals!" 
- ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "hovertemplate": "%{hovertext}

Value Patch=%{x}
Output Patch=%{y}
Layer=%{marker.color}", - "hovertext": [ - "L0H0", - "L0H1", - "L0H2", - "L0H3", - "L0H4", - "L0H5", - "L0H6", - "L0H7", - "L0H8", - "L0H9", - "L0H10", - "L0H11", - "L1H0", - "L1H1", - "L1H2", - "L1H3", - "L1H4", - "L1H5", - "L1H6", - "L1H7", - "L1H8", - "L1H9", - "L1H10", - "L1H11", - "L2H0", - "L2H1", - "L2H2", - "L2H3", - "L2H4", - "L2H5", - "L2H6", - "L2H7", - "L2H8", - "L2H9", - "L2H10", - "L2H11", - "L3H0", - "L3H1", - "L3H2", - "L3H3", - "L3H4", - "L3H5", - "L3H6", - "L3H7", - "L3H8", - "L3H9", - "L3H10", - "L3H11", - "L4H0", - "L4H1", - "L4H2", - "L4H3", - "L4H4", - "L4H5", - "L4H6", - "L4H7", - "L4H8", - "L4H9", - "L4H10", - "L4H11", - "L5H0", - "L5H1", - "L5H2", - "L5H3", - "L5H4", - "L5H5", - "L5H6", - "L5H7", - "L5H8", - "L5H9", - "L5H10", - "L5H11", - "L6H0", - "L6H1", - "L6H2", - "L6H3", - "L6H4", - "L6H5", - "L6H6", - "L6H7", - "L6H8", - "L6H9", - "L6H10", - "L6H11", - "L7H0", - "L7H1", - "L7H2", - "L7H3", - "L7H4", - "L7H5", - "L7H6", - "L7H7", - "L7H8", - "L7H9", - "L7H10", - "L7H11", - "L8H0", - "L8H1", - "L8H2", - "L8H3", - "L8H4", - "L8H5", - "L8H6", - "L8H7", - "L8H8", - "L8H9", - "L8H10", - "L8H11", - "L9H0", - "L9H1", - "L9H2", - "L9H3", - "L9H4", - "L9H5", - "L9H6", - "L9H7", - "L9H8", - "L9H9", - "L9H10", - "L9H11", - "L10H0", - "L10H1", - "L10H2", - "L10H3", - "L10H4", - "L10H5", - "L10H6", - "L10H7", - "L10H8", - "L10H9", - "L10H10", - "L10H11", - "L11H0", - "L11H1", - "L11H2", - "L11H3", - "L11H4", - "L11H5", - "L11H6", - "L11H7", - "L11H8", - "L11H9", - "L11H10", - "L11H11" - ], - "legendgroup": "", - "marker": { - "color": [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 
6, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11 - ], - "coloraxis": "coloraxis", - "symbol": "circle" - }, - "mode": "markers", - "name": "", - "orientation": "v", - "showlegend": false, - "type": "scatter", - "x": [ - -0.00019892427371814847, - 0.005339574534446001, - 0.0006527548539452255, - 0.003504416672512889, - -0.00898387935012579, - 0.0034814265090972185, - -0.0008631910313852131, - -3.406582254683599e-05, - 0.0005166929331608117, - 0.00044255363172851503, - -0.0039068968035280704, - -0.0001880836207419634, - -0.0004399022145662457, - -0.00044510437874123454, - -6.73597096465528e-05, - 7.242763240355998e-05, - -3.6549441574607044e-05, - -0.0019323208834975958, - -0.0001572397886775434, - 1.6143509128596634e-05, - 0.00020593880617525429, - 0.000336798548232764, - 0.0003515324497129768, - -0.0005669358652085066, - 0.00021013410878367722, - -0.0007199132232926786, - 0.0004868560063187033, - -0.0005974104860797524, - -0.0005921411793678999, - -0.0005443819100037217, - -0.000227552984142676, - -0.0004809825913980603, - 0.00020570388005580753, - 0.001183376181870699, - -0.0003574058646336198, - -0.0009104468626901507, - 0.0010395278222858906, - -0.00012042184971505776, - -7.762980385450646e-05, - -0.0007275318494066596, - -0.001310007064603269, - -0.0023108376190066338, - 0.010987084358930588, - -5.0712766096694395e-05, - 0.00014314358122646809, - 0.00015069512301124632, - -7.957642083056271e-05, - -2.0238119759596884e-05, - -0.0005373673629947007, - -0.0008137872209772468, - -0.00013334336108528078, - 0.030609702691435814, - -0.007185807917267084, - 0.000148916311445646, - 0.0013340713921934366, - -0.01142292469739914, - -0.0005336419562809169, - 0.0005126654868945479, - 
0.00037344868178479373, - 0.0029547319281846285, - 8.22278525447473e-06, - 6.477540864580078e-06, - 0.0015973682748153806, - 0.00034015480196103454, - -0.0012577504385262728, - -5.450531898532063e-05, - 0.0006331544718705118, - -0.00027081489679403603, - 7.427356467815116e-05, - -0.006704355590045452, - 0.003175975289195776, - -0.0017300404142588377, - 0.04863045737147331, - 0.015314852818846703, - -0.0004648726317100227, - -0.00011676354915834963, - -4.930314753437415e-05, - -0.003952810075134039, - -0.01737578585743904, - -0.00015421917487401515, - 0.0012194222072139382, - -0.00018090127559844404, - -0.00042647725786082447, - 0.00012334177154116333, - -2.956846401502844e-05, - -0.0013855225406587124, - -0.00012129446986364201, - 0.1332160234451294, - -0.00024490474606864154, - -0.007315828464925289, - 0.00033297244226559997, - -0.000795092957559973, - -0.007938209921121597, - 0.208413764834404, - -0.00019127204723190516, - -0.00020650937221944332, - -0.0020483459811657667, - -0.0003764357534237206, - -0.0033135139383375645, - -0.009666135534644127, - -0.00031723169377073646, - -0.005141589790582657, - 0.31717124581336975, - 0.0028427678626030684, - 0.0004723234742414206, - -0.0011529687326401472, - 0.2726709246635437, - -0.003175639547407627, - -0.00043929810635745525, - 5.7089622714556754e-05, - -0.0020629793871194124, - 0.020066648721694946, - -0.007871017791330814, - 0.011316264048218727, - 0.003056862158700824, - 0.06856372952461243, - -0.002747517777606845, - -0.009279227815568447, - 0.000506624230183661, - -0.0013159140944480896, - -0.012957162223756313, - -0.0030454176012426615, - -0.01792328804731369, - -0.0043589151464402676, - -0.0011521632550284266, - 0.0004999117809347808, - -0.0031131464056670666, - 0.019585633650422096, - 4.34632929682266e-05, - 0.01297028549015522, - -0.007695754989981651, - -0.0009146086522378027, - 0.004100752994418144, - -0.020459463819861412, - -0.035875942558050156, - 0.014656225219368935, - 0.0008441276149824262, - 
0.0017804511589929461, - -0.01804223284125328, - 0.003519016318023205, - 0.008253024891018867, - -0.0017665562918409705, - 0.044167667627334595, - 0.006474285386502743 - ], - "xaxis": "x", - "y": [ - 0.0009487751522101462, - 0.016124747693538666, - 0.0018548924708738923, - 0.0034389030188322067, - -0.00982347596436739, - 0.011058605276048183, - -0.004063969012349844, - -0.0015792781487107277, - -0.0012082795146852732, - 0.003828897839412093, - -0.004256919026374817, - -0.0011422622483223677, - -0.0010771177476271987, - -0.00037898647133260965, - 2.5171791548928013e-06, - -0.00026067905128002167, - -0.00014146546891424805, - 0.0038321535103023052, - -0.0004293300735298544, - -0.00142992555629462, - -0.0009228314156644046, - 0.0006944393389858305, - 0.00043302192352712154, - -0.0035714071709662676, - -0.0004967569257132709, - 0.0008057993836700916, - 0.0005424688570201397, - -0.0005309234256856143, - -0.0007159864180721343, - -0.0010389237431809306, - -0.0009490771917626262, - -8.649027586216107e-05, - 0.0002766547549981624, - 0.0021084228064864874, - -0.0001975146442418918, - -0.0016405630158260465, - 0.1162627637386322, - 0.0002507446042727679, - -0.0014675153652206063, - -0.00039680811460129917, - 0.018962211906909943, - -0.00018764731066767126, - 0.011170871555805206, - -0.0013301445869728923, - -0.0007356539717875421, - -0.00030253134900704026, - -0.00014683544577565044, - -0.00022228369198273867, - -0.001650598249398172, - 0.0002927311579696834, - -0.00143563118763268, - 0.03084198758006096, - -0.007432155776768923, - -0.00028236035723239183, - 0.006017433945089579, - -0.011007187888026237, - -0.001266107545234263, - 0.0014901700196787715, - -0.0001800622121663764, - 0.002944394713267684, - -0.004211106337606907, - 0.0029597999528050423, - 0.002045023487880826, - 0.0013397098518908024, - -0.0012190865818411112, - 0.34349915385246277, - 0.0005632104002870619, - -0.0001262281439267099, - -0.00515326950699091, - 0.016240738332271576, - 0.01709030382335186, - 
-0.004175194539129734, - 0.039775289595127106, - 0.015226684510707855, - -0.0010229480685666203, - 0.0008072761120274663, - -0.004935584031045437, - -0.002123525831848383, - -0.014274083077907562, - 0.0013746818294748664, - 0.0014838266652077436, - 0.1302703619003296, - -0.00033616088330745697, - 0.0012919505825266242, - 0.00037177055492065847, - 0.019514480605721474, - 0.00022255218937061727, - 0.124249167740345, - -0.00040352059295400977, - -0.007652895525097847, - 0.0013010123511776328, - -0.0011253133416175842, - -0.007449474185705185, - 0.19224143028259277, - -0.003275118535384536, - -0.0005017912480980158, - -0.001007912098430097, - 3.091096004936844e-05, - -0.0008595998515374959, - 0.012359987013041973, - -0.0004041247011628002, - -0.004328910261392593, - 0.3185553252696991, - 0.002330605871975422, - 0.0021182901691645384, - 0.0001405928487656638, - 0.2779357433319092, - 0.005738262087106705, - 0.0058898297138512135, - -0.0009689796715974808, - 0.00912561360746622, - 0.020675739273428917, - -0.03700518235564232, - 0.014263041317462921, - -0.04828466475009918, - 0.05834139883518219, - 0.0006514795240946114, - 0.26360899209976196, - 0.0004918567719869316, - -0.00261044898070395, - 0.08374208211898804, - 0.020676210522651672, - -0.003743582172319293, - 0.01085072010755539, - -0.001096583902835846, - 0.00047430366976186633, - 0.04818058758974075, - -0.4799128472805023, - 0.00018429107149131596, - 0.011861988343298435, - 0.06088569387793541, - 0.0008461413672193885, - 0.005328264087438583, - -0.011493473313748837, - -0.11350836604833603, - 0.006329597905278206, - 0.00031669469899497926, - -0.0011600167490541935, - -0.022669579833745956, - 0.004070379305630922, - 0.0073160636238753796, - -0.00834545586258173, - -0.27817651629447937, - 0.0036344374530017376 - ], - "yaxis": "y" - } - ], - "layout": { - "coloraxis": { - "colorbar": { - "title": { - "text": "Layer" - } - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 
0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Attention Analysis" + ] }, - "legend": { - "tracegroupgap": 0 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Attention heads are particularly easy to study because we can look directly at their attention patterns and study which positions they move information from and to. This is particularly easy here as we're looking at the direct effect on the logits so we need only look at the attention patterns from the final token. \n", + "\n", + "We use Alan Cooney's circuitsvis library to visualize the attention patterns! We visualize the top 3 positive and negative heads by direct logit attribution, and show these for the first prompt (as an illustration).\n", + "\n", + "
Interpreting Attention Patterns \n", + "An easy mistake to make when looking at attention patterns is thinking that they must convey information about the token looked at (maybe accounting for the context of the token). But actually, all we can confidently say is that the head moves information from the *residual stream position* corresponding to that input token. Especially later on in the model, there may be components in the residual stream that have nothing to do with the input token! E.g. the period at the end of a sentence may contain summary information for that sentence, and the head may solely move that, rather than caring about whether it ends in \".\", \"!\" or \"?\"\n", + "
" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 
0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - 
"parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": 
{ - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def visualize_attention_patterns(\n", + " heads: Union[List[int], int, Float[torch.Tensor, \"heads\"]],\n", + " local_cache: ActivationCache,\n", + " local_tokens: torch.Tensor,\n", + " title: Optional[str] = \"\",\n", + " max_width: Optional[int] = 700,\n", + ") -> str:\n", + " # If a single head is given, convert to a list\n", + " if isinstance(heads, int):\n", + " heads = [heads]\n", + "\n", + " # Create the plotting data\n", + " labels: List[str] = []\n", + " patterns: List[Float[torch.Tensor, \"dest_pos src_pos\"]] = []\n", + "\n", + " # Assume we have a single batch item\n", + " batch_index = 0\n", + "\n", + " for head in heads:\n", + " # Set the label\n", + " layer = head // model.cfg.n_heads\n", + " head_index = head % model.cfg.n_heads\n", + " labels.append(f\"L{layer}H{head_index}\")\n", + "\n", + " # Get the attention patterns for the head\n", + " # Attention patterns have shape [batch, head_index, query_pos, key_pos]\n", + " patterns.append(local_cache[\"attn\", layer][batch_index, head_index])\n", + "\n", + " # Convert the tokens to strings (for the axis labels)\n", + " str_tokens = model.to_str_tokens(local_tokens)\n", + 
"\n", + " # Combine the patterns into a single tensor\n", + " patterns: Float[torch.Tensor, \"head_index dest_pos src_pos\"] = torch.stack(\n", + " patterns, dim=0\n", + " )\n", + "\n", + " # Circuitsvis Plot (note we get the code version so we can concatenate with the title)\n", + " plot = attention_heads(\n", + " attention=patterns, tokens=str_tokens, attention_head_names=labels\n", + " ).show_code()\n", + "\n", + " # Display the title\n", + " title_html = f\"

{title}


\"\n", + "\n", + " # Return the visualisation as raw code\n", + " return f\"
{title_html + plot}
\"" ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Inspecting the patterns, we can see that both types of name movers attend to the indirect object - this suggests they're simply copying the name attended to (with the OV circuit) and that the interesting part is the circuit behind the attention pattern that calculates *where* to move information from (the QK circuit)" ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

Top 3 Positive Logit Attribution Heads


\n", + "

Top 3 Negative Logit Attribution Heads


\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "top_k = 3\n", + "\n", + "top_positive_logit_attr_heads = torch.topk(\n", + " per_head_logit_diffs.flatten(), k=top_k\n", + ").indices\n", + "\n", + "positive_html = visualize_attention_patterns(\n", + " top_positive_logit_attr_heads,\n", + " cache,\n", + " tokens[0],\n", + " f\"Top {top_k} Positive Logit Attribution Heads\",\n", + ")\n", + "\n", + "top_negative_logit_attr_heads = torch.topk(\n", + " -per_head_logit_diffs.flatten(), k=top_k\n", + ").indices\n", + "\n", + "negative_html = visualize_attention_patterns(\n", + " top_negative_logit_attr_heads,\n", + " cache,\n", + " tokens[0],\n", + " title=f\"Top {top_k} Negative Logit Attribution Heads\",\n", + ")\n", + "\n", + "HTML(positive_html + negative_html)" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - 
"backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Scatter plot of output patching vs value patching" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Activation Patching" + ] }, - "xaxis": { - "anchor": "y", - "domain": [ - 0, - 1 - ], - "range": [ - -0.5, - 0.5 - ], - "title": { - "text": "Value Patch" - } + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**This section explains how to do activation patching conceptually by implementing it from scratch. To use it in practice with TransformerLens, see [this demonstration instead](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Activation_Patching_in_TL_Demo.ipynb)**.\n", + "\n", + "The obvious limitation to the techniques used above is that they only look at the very end of the circuit - the parts that directly affect the logits. Clearly this is not sufficient to understand the circuit! 
We want to understand how things compose together to produce this final output, and ideally to produce an end-to-end circuit fully explaining this behaviour. \n", + "\n", + "The technique we'll use to investigate this is called **activation patching**. This was first introduced in [David Bau and Kevin Meng's excellent ROME paper](https://rome.baulab.info/), there called causal tracing. \n", + "\n", + "The setup of activation patching is to take two runs of the model on two different inputs, the clean run and the corrupted run. The clean run outputs the correct answer and the corrupted run does not. The key idea is that we give the model the corrupted input, but then **intervene** on a specific activation and **patch** in the corresponding activation from the clean run (ie replace the corrupted activation with the clean activation), and then continue the run. And we then measure how much the output has updated towards the correct answer. \n", + "\n", + "We can then iterate over many possible activations and look at how much they affect the corrupted run. If patching in an activation significantly increases the probability of the correct answer, this allows us to *localise* which activations matter. \n", + "\n", + "The ability to localise is a key move in mechanistic interpretability - if the computation is diffuse and spread across the entire model, it is likely much harder to form a clean mechanistic story for what's going on. But if we can identify precisely which parts of the model matter, we can then zoom in and determine what they represent and how they connect up with each other, and ultimately reverse engineer the underlying circuit that they represent. 
\n", + "\n", + "Here's an animation from the ROME paper demonstrating this technique (they studied factual recall, and use stars to represent corruption applied to the subject of the sentence, but the same principles apply):\n", + "\n", + "![CT Animation](https://rome.baulab.info/images/small-ct-animation.gif)\n", + "\n", + "See also [the explanation in a mech interp explainer](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx) and [this piece](https://www.neelnanda.io/mechanistic-interpretability/attribution-patching#how-to-think-about-activation-patching) describing how to think about patching on a conceptual level" + ] }, - "yaxis": { - "anchor": "x", - "domain": [ - 0, - 1 - ], - "range": [ - -0.5, - 0.5 - ], - "title": { - "text": "Output Patch" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "head_labels = [\n", - " f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)\n", - "]\n", - "scatter(\n", - " x=utils.to_numpy(patched_head_v_diff.flatten()),\n", - " y=utils.to_numpy(patched_head_z_diff.flatten()),\n", - " xaxis=\"Value Patch\",\n", - " yaxis=\"Output Patch\",\n", - " caxis=\"Layer\",\n", - " hover_name=head_labels,\n", - " color=einops.repeat(\n", - " np.arange(model.cfg.n_layers), \"layer -> (layer head)\", head=model.cfg.n_heads\n", - " ),\n", - " range_x=(-0.5, 0.5),\n", - " range_y=(-0.5, 0.5),\n", - " title=\"Scatter plot of output patching vs value patching\",\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "When we patch in attention patterns, we see the opposite effect - early and late heads matter a lot, middle heads don't. 
(In fact, the sum of value patching and pattern patching is approx the same as output patching)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [], - "source": [ - "def patch_head_pattern(\n", - " corrupted_head_pattern: Float[torch.Tensor, \"batch head_index query_pos d_head\"],\n", - " hook,\n", - " head_index,\n", - " clean_cache,\n", - "):\n", - " corrupted_head_pattern[:, head_index, :, :] = clean_cache[hook.name][\n", - " :, head_index, :, :\n", - " ]\n", - " return corrupted_head_pattern\n", - "\n", - "\n", - "patched_head_attn_diff = torch.zeros(\n", - " model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32\n", - ")\n", - "for layer in range(model.cfg.n_layers):\n", - " for head_index in range(model.cfg.n_heads):\n", - " hook_fn = partial(patch_head_pattern, head_index=head_index, clean_cache=cache)\n", - " patched_logits = model.run_with_hooks(\n", - " corrupted_tokens,\n", - " fwd_hooks=[(utils.get_act_name(\"attn\", layer, \"attn\"), hook_fn)],\n", - " return_type=\"logits\",\n", - " )\n", - " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", - "\n", - " patched_head_attn_diff[layer, head_index] = normalize_patched_logit_diff(\n", - " patched_logit_diff\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0.0006401354330591857, - 0.005318799521774054, - 0.0011584057938307524, - -5.920405237702653e-05, - -0.00106671336106956, - 0.005079298280179501, - -0.0030818663071841, - -0.0020521720871329308, - -0.0014405983965843916, - 0.003492669900879264, - -0.002568227471783757, - -0.0009168237447738647 - ], - [ - -0.0007600873941555619, - 0.0001683824957581237, - 0.00012246915139257908, - -0.00034914951538667083, - 1.4901700524205808e-05, - 0.0050090523436665535, - -0.0002975976967718452, - -0.0014448943547904491, - -0.001099134678952396, - 0.00047447148244827986, - 5.195457561057992e-05, - -0.0034954219590872526 - ], - [ - -0.0007243098807521164, - 0.0017458146903663874, - -0.00015556166181340814, - 5.7626621128292754e-05, - -9.7398049547337e-05, - -0.0004238593974150717, - -0.0007917031762190163, - 0.00027222454082220793, - 0.00010179472155869007, - 0.0004223826399538666, - 0.00015193692524917424, - -0.0007437760941684246 - ], - [ - 0.11458104848861694, - 0.00021140948229003698, - -0.0009424989693798125, - 0.000429833511589095, - 0.02004295401275158, - 0.002104730810970068, - 7.628730963915586e-05, - -0.001543701975606382, - -0.0008484235731884837, - -0.0005819046637043357, - 0.00011921360419364646, - -1.899631206470076e-05 - ], - [ - -0.001127125695347786, - 0.001237143180333078, - -0.0012324444251134992, - -0.0005952289211563766, - -0.0007541133090853691, - -0.0005842540413141251, - 0.004813014063984156, - 0.00018187458044849336, - -0.0005361591465771198, - 0.0008579217828810215, - -0.0002985374303534627, - -1.144477391790133e-05 - ], - [ - -0.004241178277879953, - 0.0029509058222174644, - 0.0005218615406192839, - 0.0009535074350424111, - 0.0001622070267330855, - 0.34350839257240295, - -0.0003052163519896567, - 0.00010293584637111053, - -0.005300541408360004, - 0.024864863604307175, - 0.014383262023329735, - -0.0023285921197384596 - ], - [ - -0.0023893399629741907, - 
-0.002172795357182622, - -0.00047614958020858467, - 0.00043188079143874347, - -0.004675475414842367, - 0.0018583494238555431, - -0.0026542814448475838, - 0.0014367386465892196, - 0.00030326974228955805, - 0.13043038547039032, - 8.813483145786449e-05, - 0.0011766973184421659 - ], - [ - 0.00031847349600866437, - 0.02057075686752796, - 0.00031840638257563114, - -0.002512782346457243, - -0.0002628941729199141, - -0.00024718698114156723, - 0.0005524033331312239, - -0.00043131023994646966, - 0.00025715501396916807, - 0.008090951479971409, - -0.0030689111445099115, - -0.0004238593974150717 - ], - [ - 0.000976699055172503, - 0.00039251212729141116, - 0.0017534669023007154, - 0.022595642134547234, - -4.4805787183577195e-05, - 0.00014220383309293538, - 0.009584981948137283, - -0.0003157213795930147, - 0.0015271222218871117, - 0.0011813960736617446, - -0.010774029418826103, - 0.00936581939458847 - ], - [ - 0.006314125377684832, - -0.0010949057759717107, - 0.011662023141980171, - 0.0013481340138241649, - -0.02918696030974388, - 0.0038333951961249113, - -0.04409456625580788, - -0.005032042507082224, - 0.00482167350128293, - 0.2766477167606354, - -3.164933150401339e-05, - -0.0006618167390115559 - ], - [ - 0.0953889712691307, - 0.02506939135491848, - 0.014239178970456123, - 0.014754998497664928, - 9.890835644910112e-05, - -8.977938705356792e-05, - 0.05082912743091583, - -0.5051022171974182, - 0.00014696970174554735, - -0.0016026375815272331, - 0.06883199512958527, - 0.002327115274965763 - ], - [ - 0.0013425961369648576, - 0.009630928747355938, - -0.07776415348052979, - -0.007728713098913431, - -0.0005726079107262194, - -0.002957182005047798, - -0.0049475994892418385, - 0.00045916702947579324, - -0.0006328188464976847, - -0.006520198658108711, - -0.3204910457134247, - -0.002473111730068922 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - 
"rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The above was all fairly abstract, so let's zoom in and lay out a concrete example to understand Indirect Object Identification.\n", + "\n", + "Here our clean input will be eg \"After John and Mary went to the store, **John** gave a bottle of milk to\" and our corrupted input will be eg \"After John and Mary went to the store, **Mary** gave a bottle of milk to\". These prompts are identical except for the name of the indirect object, and so patching is a causal intervention which will allow us to understand precisely which parts of the network are identifying the indirect object. \n", + "\n", + "One natural thing to patch in is the residual stream at a specific layer and specific position. For example, the model is likely initially doing some processing on the second subject token to realise that it's a duplicate, but then uses attention to move that information to the \" to\" token. So patching in the residual stream at the \" to\" token will likely matter a lot in later layers but not at all in early layers.\n", + "\n", + "We can zoom in much further and patch in specific activations from specific layers. For example, we think that the output of head L9H9 on the final token is significant for directly connecting to the logits\n", + "\n", + "We can patch in specific activations, and can zoom in as far as seems reasonable. For example, if we patch in the output of head L9H9 on the final token, we would predict that it will significantly affect performance. \n", + "\n", + "Note that this technique does *not* tell us how the components of the circuit connect up, just what they are. \n", + "\n", + "
Technical details \n", + "The choice of clean and corrupted prompt has both pros and cons. By carefully setting up the counterfactual, that only differs in the second subject, we avoid detecting the parts of the model doing irrelevant computation like detecting that the indirect object task is relevant at all or that it should be outputting a name rather than an article or pronoun. Or even context like that John and Mary are names at all. \n", + "\n", + "However, it *also* bakes in some details that *are* relevant to the task. Such as finding the location of the second subject, and of the names in the first clause. Or that the name mover heads have learned to copy whatever they look at. \n", + "\n", + "Some of these could be patched by also changing up the order of the names in the original sentence - patching in \"After John and Mary went to the store, John gave a bottle of milk to\" vs \"After Mary and John went to the store, John gave a bottle of milk to\".\n", + "\n", + "In the ROME paper they take a different tack. Rather than carefully setting up counterfactuals between two different but related inputs, they **corrupt** the clean input by adding Gaussian noise to the token embedding for the subject. This is in some ways much lower effort (you don't need to set up a similar but different prompt) but can also introduce some issues, such as ways this noise might break things. In practice, you should take care about how you choose your counterfactuals and try out several. Try to reason beforehand about what they will and will not tell you, and compare the results between different counterfactuals.\n", + "\n", + "I discuss some of these limitations and how the authors solved them with much more refined usage of these techniques in our interview\n", + "
" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 
0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - 
"parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": 
{ - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Residual Stream" ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets begin by patching in the residual stream at the start of each layer and for each token position. 
" ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We first create a set of corrupted tokens - where we swap each pair of prompts to have the opposite answer." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Corrupted Average Logit Diff -3.55\n", + "Clean Average Logit Diff 3.55\n" + ] + } + ], + "source": [ + "corrupted_prompts = []\n", + "for i in range(0, len(prompts), 2):\n", + " corrupted_prompts.append(prompts[i + 1])\n", + " corrupted_prompts.append(prompts[i])\n", + "corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)\n", + "corrupted_logits, corrupted_cache = model.run_with_cache(\n", + " corrupted_tokens, return_type=\"logits\"\n", + ")\n", + "corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)\n", + "print(\"Corrupted Average Logit Diff\", round(corrupted_average_logit_diff.item(), 2))\n", + "print(\"Clean Average Logit Diff\", round(original_average_logit_diff.item(), 2))" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['<|endoftext|>When John and Mary went to the shops, Mary gave the bag to',\n", + " '<|endoftext|>When John and Mary went to the shops, John gave the bag to',\n", + " '<|endoftext|>When Tom and James went to the park, Tom gave the ball to',\n", + " '<|endoftext|>When Tom and James went to the park, James gave the ball to',\n", + " '<|endoftext|>When Dan 
and Sid went to the shops, Dan gave an apple to',\n", + " '<|endoftext|>When Dan and Sid went to the shops, Sid gave an apple to',\n", + " '<|endoftext|>After Martin and Amy went to the park, Martin gave a drink to',\n", + " '<|endoftext|>After Martin and Amy went to the park, Amy gave a drink to']" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.to_string(corrupted_tokens)" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": 
"#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Logit Difference From Patched Head Pattern" + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now intervene on the corrupted run and patch in the clean residual stream at a specific layer and position.\n", + "\n", + "We do the intervention using TransformerLens's `HookPoint` feature. We can design a hook function that takes in a specific activation and returns an edited copy, and temporarily add it in with `model.run_with_hooks`. 
" + ] }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "def patch_residual_component(\n", + " corrupted_residual_component: Float[torch.Tensor, \"batch pos d_model\"],\n", + " hook,\n", + " pos,\n", + " clean_cache,\n", + "):\n", + " corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]\n", + " return corrupted_residual_component\n", + "\n", + "\n", + "def normalize_patched_logit_diff(patched_logit_diff):\n", + " # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalise\n", + " # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance\n", + " return (patched_logit_diff - corrupted_average_logit_diff) / (\n", + " original_average_logit_diff - corrupted_average_logit_diff\n", + " )\n", + "\n", + "\n", + "patched_residual_stream_diff = torch.zeros(\n", + " model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32\n", + ")\n", + "for layer in range(model.cfg.n_layers):\n", + " for position in range(tokens.shape[1]):\n", + " hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)\n", + " patched_logits = model.run_with_hooks(\n", + " corrupted_tokens,\n", + " fwd_hooks=[(utils.get_act_name(\"resid_pre\", layer), hook_fn)],\n", + " return_type=\"logits\",\n", + " )\n", + " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", + "\n", + " patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(\n", + " patched_logit_diff\n", + " )" + ] }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": 
{}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "hovertemplate": "%{hovertext}

Attention Patch=%{x}
Output Patch=%{y}", - "hovertext": [ - "L0H0", - "L0H1", - "L0H2", - "L0H3", - "L0H4", - "L0H5", - "L0H6", - "L0H7", - "L0H8", - "L0H9", - "L0H10", - "L0H11", - "L1H0", - "L1H1", - "L1H2", - "L1H3", - "L1H4", - "L1H5", - "L1H6", - "L1H7", - "L1H8", - "L1H9", - "L1H10", - "L1H11", - "L2H0", - "L2H1", - "L2H2", - "L2H3", - "L2H4", - "L2H5", - "L2H6", - "L2H7", - "L2H8", - "L2H9", - "L2H10", - "L2H11", - "L3H0", - "L3H1", - "L3H2", - "L3H3", - "L3H4", - "L3H5", - "L3H6", - "L3H7", - "L3H8", - "L3H9", - "L3H10", - "L3H11", - "L4H0", - "L4H1", - "L4H2", - "L4H3", - "L4H4", - "L4H5", - "L4H6", - "L4H7", - "L4H8", - "L4H9", - "L4H10", - "L4H11", - "L5H0", - "L5H1", - "L5H2", - "L5H3", - "L5H4", - "L5H5", - "L5H6", - "L5H7", - "L5H8", - "L5H9", - "L5H10", - "L5H11", - "L6H0", - "L6H1", - "L6H2", - "L6H3", - "L6H4", - "L6H5", - "L6H6", - "L6H7", - "L6H8", - "L6H9", - "L6H10", - "L6H11", - "L7H0", - "L7H1", - "L7H2", - "L7H3", - "L7H4", - "L7H5", - "L7H6", - "L7H7", - "L7H8", - "L7H9", - "L7H10", - "L7H11", - "L8H0", - "L8H1", - "L8H2", - "L8H3", - "L8H4", - "L8H5", - "L8H6", - "L8H7", - "L8H8", - "L8H9", - "L8H10", - "L8H11", - "L9H0", - "L9H1", - "L9H2", - "L9H3", - "L9H4", - "L9H5", - "L9H6", - "L9H7", - "L9H8", - "L9H9", - "L9H10", - "L9H11", - "L10H0", - "L10H1", - "L10H2", - "L10H3", - "L10H4", - "L10H5", - "L10H6", - "L10H7", - "L10H8", - "L10H9", - "L10H10", - "L10H11", - "L11H0", - "L11H1", - "L11H2", - "L11H3", - "L11H4", - "L11H5", - "L11H6", - "L11H7", - "L11H8", - "L11H9", - "L11H10", - "L11H11" - ], - "legendgroup": "", - "marker": { - "color": "#636efa", - "symbol": "circle" - }, - "mode": "markers", - "name": "", - "orientation": "v", - "showlegend": false, - "type": "scatter", - "x": [ - 0.0006401354330591857, - 0.005318799521774054, - 0.0011584057938307524, - -5.920405237702653e-05, - -0.00106671336106956, - 0.005079298280179501, - -0.0030818663071841, - -0.0020521720871329308, - -0.0014405983965843916, - 0.003492669900879264, - -0.002568227471783757, - 
-0.0009168237447738647, - -0.0007600873941555619, - 0.0001683824957581237, - 0.00012246915139257908, - -0.00034914951538667083, - 1.4901700524205808e-05, - 0.0050090523436665535, - -0.0002975976967718452, - -0.0014448943547904491, - -0.001099134678952396, - 0.00047447148244827986, - 5.195457561057992e-05, - -0.0034954219590872526, - -0.0007243098807521164, - 0.0017458146903663874, - -0.00015556166181340814, - 5.7626621128292754e-05, - -9.7398049547337e-05, - -0.0004238593974150717, - -0.0007917031762190163, - 0.00027222454082220793, - 0.00010179472155869007, - 0.0004223826399538666, - 0.00015193692524917424, - -0.0007437760941684246, - 0.11458104848861694, - 0.00021140948229003698, - -0.0009424989693798125, - 0.000429833511589095, - 0.02004295401275158, - 0.002104730810970068, - 7.628730963915586e-05, - -0.001543701975606382, - -0.0008484235731884837, - -0.0005819046637043357, - 0.00011921360419364646, - -1.899631206470076e-05, - -0.001127125695347786, - 0.001237143180333078, - -0.0012324444251134992, - -0.0005952289211563766, - -0.0007541133090853691, - -0.0005842540413141251, - 0.004813014063984156, - 0.00018187458044849336, - -0.0005361591465771198, - 0.0008579217828810215, - -0.0002985374303534627, - -1.144477391790133e-05, - -0.004241178277879953, - 0.0029509058222174644, - 0.0005218615406192839, - 0.0009535074350424111, - 0.0001622070267330855, - 0.34350839257240295, - -0.0003052163519896567, - 0.00010293584637111053, - -0.005300541408360004, - 0.024864863604307175, - 0.014383262023329735, - -0.0023285921197384596, - -0.0023893399629741907, - -0.002172795357182622, - -0.00047614958020858467, - 0.00043188079143874347, - -0.004675475414842367, - 0.0018583494238555431, - -0.0026542814448475838, - 0.0014367386465892196, - 0.00030326974228955805, - 0.13043038547039032, - 8.813483145786449e-05, - 0.0011766973184421659, - 0.00031847349600866437, - 0.02057075686752796, - 0.00031840638257563114, - -0.002512782346457243, - -0.0002628941729199141, - 
-0.00024718698114156723, - 0.0005524033331312239, - -0.00043131023994646966, - 0.00025715501396916807, - 0.008090951479971409, - -0.0030689111445099115, - -0.0004238593974150717, - 0.000976699055172503, - 0.00039251212729141116, - 0.0017534669023007154, - 0.022595642134547234, - -4.4805787183577195e-05, - 0.00014220383309293538, - 0.009584981948137283, - -0.0003157213795930147, - 0.0015271222218871117, - 0.0011813960736617446, - -0.010774029418826103, - 0.00936581939458847, - 0.006314125377684832, - -0.0010949057759717107, - 0.011662023141980171, - 0.0013481340138241649, - -0.02918696030974388, - 0.0038333951961249113, - -0.04409456625580788, - -0.005032042507082224, - 0.00482167350128293, - 0.2766477167606354, - -3.164933150401339e-05, - -0.0006618167390115559, - 0.0953889712691307, - 0.02506939135491848, - 0.014239178970456123, - 0.014754998497664928, - 9.890835644910112e-05, - -8.977938705356792e-05, - 0.05082912743091583, - -0.5051022171974182, - 0.00014696970174554735, - -0.0016026375815272331, - 0.06883199512958527, - 0.002327115274965763, - 0.0013425961369648576, - 0.009630928747355938, - -0.07776415348052979, - -0.007728713098913431, - -0.0005726079107262194, - -0.002957182005047798, - -0.0049475994892418385, - 0.00045916702947579324, - -0.0006328188464976847, - -0.006520198658108711, - -0.3204910457134247, - -0.002473111730068922 - ], - "xaxis": "x", - "y": [ - 0.0009487751522101462, - 0.016124747693538666, - 0.0018548924708738923, - 0.0034389030188322067, - -0.00982347596436739, - 0.011058605276048183, - -0.004063969012349844, - -0.0015792781487107277, - -0.0012082795146852732, - 0.003828897839412093, - -0.004256919026374817, - -0.0011422622483223677, - -0.0010771177476271987, - -0.00037898647133260965, - 2.5171791548928013e-06, - -0.00026067905128002167, - -0.00014146546891424805, - 0.0038321535103023052, - -0.0004293300735298544, - -0.00142992555629462, - -0.0009228314156644046, - 0.0006944393389858305, - 0.00043302192352712154, - 
-0.0035714071709662676, - -0.0004967569257132709, - 0.0008057993836700916, - 0.0005424688570201397, - -0.0005309234256856143, - -0.0007159864180721343, - -0.0010389237431809306, - -0.0009490771917626262, - -8.649027586216107e-05, - 0.0002766547549981624, - 0.0021084228064864874, - -0.0001975146442418918, - -0.0016405630158260465, - 0.1162627637386322, - 0.0002507446042727679, - -0.0014675153652206063, - -0.00039680811460129917, - 0.018962211906909943, - -0.00018764731066767126, - 0.011170871555805206, - -0.0013301445869728923, - -0.0007356539717875421, - -0.00030253134900704026, - -0.00014683544577565044, - -0.00022228369198273867, - -0.001650598249398172, - 0.0002927311579696834, - -0.00143563118763268, - 0.03084198758006096, - -0.007432155776768923, - -0.00028236035723239183, - 0.006017433945089579, - -0.011007187888026237, - -0.001266107545234263, - 0.0014901700196787715, - -0.0001800622121663764, - 0.002944394713267684, - -0.004211106337606907, - 0.0029597999528050423, - 0.002045023487880826, - 0.0013397098518908024, - -0.0012190865818411112, - 0.34349915385246277, - 0.0005632104002870619, - -0.0001262281439267099, - -0.00515326950699091, - 0.016240738332271576, - 0.01709030382335186, - -0.004175194539129734, - 0.039775289595127106, - 0.015226684510707855, - -0.0010229480685666203, - 0.0008072761120274663, - -0.004935584031045437, - -0.002123525831848383, - -0.014274083077907562, - 0.0013746818294748664, - 0.0014838266652077436, - 0.1302703619003296, - -0.00033616088330745697, - 0.0012919505825266242, - 0.00037177055492065847, - 0.019514480605721474, - 0.00022255218937061727, - 0.124249167740345, - -0.00040352059295400977, - -0.007652895525097847, - 0.0013010123511776328, - -0.0011253133416175842, - -0.007449474185705185, - 0.19224143028259277, - -0.003275118535384536, - -0.0005017912480980158, - -0.001007912098430097, - 3.091096004936844e-05, - -0.0008595998515374959, - 0.012359987013041973, - -0.0004041247011628002, - -0.004328910261392593, - 
0.3185553252696991, - 0.002330605871975422, - 0.0021182901691645384, - 0.0001405928487656638, - 0.2779357433319092, - 0.005738262087106705, - 0.0058898297138512135, - -0.0009689796715974808, - 0.00912561360746622, - 0.020675739273428917, - -0.03700518235564232, - 0.014263041317462921, - -0.04828466475009918, - 0.05834139883518219, - 0.0006514795240946114, - 0.26360899209976196, - 0.0004918567719869316, - -0.00261044898070395, - 0.08374208211898804, - 0.020676210522651672, - -0.003743582172319293, - 0.01085072010755539, - -0.001096583902835846, - 0.00047430366976186633, - 0.04818058758974075, - -0.4799128472805023, - 0.00018429107149131596, - 0.011861988343298435, - 0.06088569387793541, - 0.0008461413672193885, - 0.005328264087438583, - -0.011493473313748837, - -0.11350836604833603, - 0.006329597905278206, - 0.00031669469899497926, - -0.0011600167490541935, - -0.022669579833745956, - 0.004070379305630922, - 0.0073160636238753796, - -0.00834545586258173, - -0.27817651629447937, - 0.0036344374530017376 - ], - "yaxis": "y" - } - ], - "layout": { - "legend": { - "tracegroupgap": 0 + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can immediately see that, exactly as predicted, originally all relevant computation happens on the second subject token, and at layers 7 and 8, the information is moved to the final token. Moving the residual stream at the correct position near *exactly* recovers performance!\n", + "\n", + "For reference, tokens and their index from the first prompt are on the x-axis. In an abuse of notation, note that the difference here is averaged over *all* 8 prompts, while the labels only come from the *first* prompt. 
\n", + "\n", + "To be easier to interpret, we normalise the logit difference, by subtracting the corrupted logit difference, and dividing by the total improvement from clean to corrupted to normalise\n", + "0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": 
{ - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 
0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 
0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Position: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "x": [ + "<|endoftext|>_0", + "When_1", + " John_2", + " and_3", + " Mary_4", + " went_5", + " to_6", + " the_7", + " shops_8", + ",_9", + " John_10", + " gave_11", + " the_12", + " bag_13", + " to_14" + ], + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1.000650405883789, + -0.0002469856117386371, + 0.00000976665523921838, + -0.00036458822432905436, + -0.000048967522161547095 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1.001051902770996, + -0.000027621845219982788, + -0.000019768245692830533, + -0.0004596704675350338, + -0.0005947590689174831 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1.0002663135528564, + 0.0008680911851115525, + 0.0005157867562957108, + -0.0009929431835189462, + -0.0008658089209347963 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.994907796382904, + 0.005429857410490513, + 0.0016050540143623948, + -0.0006193603039719164, + -0.0016324409516528249 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.9675672054290771, + 0.03134213387966156, + 0.0028418952133506536, + -0.0012302964460104704, + -0.000985861523076892 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.967520534992218, + 0.03100077249109745, + 0.0017823305679485202, + -0.00048668819363228977, + -0.0006467136554419994 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.9228319525718689, + 0.05134531855583191, + 0.004728672094643116, + 0.0009345446596853435, + 0.017046840861439705 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.6565483808517456, + 0.02385685034096241, + 0.002357019344344735, + -0.000017183941963594407, + 0.3186916410923004 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.027302566915750504, + 0.03142499923706055, + 0.0018202561186626554, + 0.0007990868762135506, + 
0.9383866190910339 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.026841485872864723, + 0.02098155952990055, + 0.0012512058019638062, + 0.00032317222212441266, + 1.0048279762268066 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.005687985569238663, + 0.014263377524912357, + 0.00048709093243815005, + -0.00008977938705356792, + 0.9914212226867676 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 
0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + 
], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" 
+ } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + 
"#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + 
"caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Logit Difference From Patched Residual Stream" + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Position" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "prompt_position_labels = [\n", + " f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(tokens[0]))\n", + "]\n", + "imshow(\n", + " patched_residual_stream_diff,\n", + " x=prompt_position_labels,\n", + " title=\"Logit Difference From Patched Residual Stream\",\n", + " labels={\"x\": \"Position\", \"y\": \"Layer\"},\n", + ")" ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Layers" ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - 
"#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can apply exactly the same idea, but this time patching in attention or MLP layers. These are also residual components with identical shapes to the residual stream terms, so we can reuse the same hooks." ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": 
"white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Scatter plot of output patching vs attention patching" + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "patched_attn_diff = torch.zeros(\n", + " model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32\n", + ")\n", + "patched_mlp_diff = torch.zeros(\n", + " model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32\n", + ")\n", + "for layer in range(model.cfg.n_layers):\n", + " for position in range(tokens.shape[1]):\n", + " hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)\n", + " patched_attn_logits = model.run_with_hooks(\n", + " corrupted_tokens,\n", + " fwd_hooks=[(utils.get_act_name(\"attn_out\", layer), hook_fn)],\n", + " return_type=\"logits\",\n", + " )\n", + " patched_attn_logit_diff = logits_to_ave_logit_diff(\n", + " patched_attn_logits, answer_tokens\n", + " )\n", + " patched_mlp_logits = model.run_with_hooks(\n", + " corrupted_tokens,\n", + " fwd_hooks=[(utils.get_act_name(\"mlp_out\", layer), hook_fn)],\n", + " return_type=\"logits\",\n", + " )\n", + " patched_mlp_logit_diff = logits_to_ave_logit_diff(\n", + " patched_mlp_logits, answer_tokens\n", + " )\n", + "\n", + " patched_attn_diff[layer, position] = normalize_patched_logit_diff(\n", + " patched_attn_logit_diff\n", + " )\n", + " patched_mlp_diff[layer, position] = normalize_patched_logit_diff(\n", + " patched_mlp_logit_diff\n", 
+ " )" + ] }, - "xaxis": { - "anchor": "y", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Attention Patch" - } + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that several attention layers are significant but that, matching the residual stream results, early layers matter on the second subject token, and later layers matter on the final token, and layers essentially don't matter on any other token. Extremely localised! As with direct logit attribution, layer 9 is positive and layers 10 and 11 are not, suggesting that the late layers only matter for direct logit effects, but we also see that layers 7 and 8 matter significantly. Presumably these are the heads that move information about which name is duplicated from the second subject token to the final token." + ] }, - "yaxis": { - "anchor": "x", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Output Patch" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "imshow(\n", - " patched_head_attn_diff,\n", - " title=\"Logit Difference From Patched Head Pattern\",\n", - " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", - ")\n", - "head_labels = [\n", - " f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)\n", - "]\n", - "scatter(\n", - " x=utils.to_numpy(patched_head_attn_diff.flatten()),\n", - " y=utils.to_numpy(patched_head_z_diff.flatten()),\n", - " hover_name=head_labels,\n", - " xaxis=\"Attention Patch\",\n", - " yaxis=\"Output Patch\",\n", - " title=\"Scatter plot of output patching vs attention patching\",\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Consolidating Understanding\n", - "\n", - "OK, let's zoom out and reconsolidate. At a high-level, we find that all the action is on the second subject token until layer 7 and then transitions to the final token. 
And that attention layers matter a lot, MLP layers not so much (apart from MLP0, likely as an extended embedding).\n", - "\n", - "We've further localised important behaviour to several categories of heads. We've found 3 categories of heads that matter a lot - early heads (L5H5, L6H9, L3H0) whose output matters on the second subject and whose behaviour is determined by their attention patterns, mid-late heads (L8H6, L8H10, L7H9, L7H3) whose output matters on the final token and whose behaviour is determined by their value vectors, and late heads (L9H9, L10H7, L11H10) whose output matters on the final token and whose behaviour is determined by their attention patterns.\n", - "\n", - "A natural speculation is that early heads detect both that the second subject is a repeated token and *which* is repeated (ie the \" John\" token is repeated), middle heads compose with this and move this duplicated token information from the second subject token to the final token, and the late heads compose with this to *inhibit* their attention to the duplicated token, and then attend to the correct indirect object name and copy that directly to the logits." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Visualizing Attention Patterns\n", - "\n", - "We can validate this by looking at the attention patterns of these heads! Let's take the top 10 heads by output patching (in absolute value) and split it into early, middle and late.\n", - "\n", - "We see that middle heads attend from the final token to the second subject, and late heads attend from the final token to the indirect object, which is completely consistent with the above speculation! But weirdly, while *one* early head attends from the second subject to its first copy, the other two mysteriously attend to the word *after* the first copy." - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "

Top Early Heads


\n", - "

Top Middle Heads


\n", - "

Top Late Heads


\n", - "
" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "top_k = 10\n", - "top_heads_by_output_patch = torch.topk(\n", - " patched_head_z_diff.abs().flatten(), k=top_k\n", - ").indices\n", - "first_mid_layer = 7\n", - "first_late_layer = 9\n", - "early_heads = top_heads_by_output_patch[\n", - " top_heads_by_output_patch < model.cfg.n_heads * first_mid_layer\n", - "]\n", - "mid_heads = top_heads_by_output_patch[\n", - " torch.logical_and(\n", - " model.cfg.n_heads * first_mid_layer <= top_heads_by_output_patch,\n", - " top_heads_by_output_patch < model.cfg.n_heads * first_late_layer,\n", - " )\n", - "]\n", - "late_heads = top_heads_by_output_patch[\n", - " model.cfg.n_heads * first_late_layer <= top_heads_by_output_patch\n", - "]\n", - "\n", - "early = visualize_attention_patterns(\n", - " early_heads, cache, tokens[0], title=f\"Top Early Heads\"\n", - ")\n", - "mid = visualize_attention_patterns(\n", - " mid_heads, cache, tokens[0], title=f\"Top Middle Heads\"\n", - ")\n", - "late = visualize_attention_patterns(\n", - " late_heads, cache, tokens[0], title=f\"Top Late Heads\"\n", - ")\n", - "\n", - "HTML(early + mid + late)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Comparing to the Paper\n", - "\n", - "We can now refer to the (far, far more rigorous and detailed) analysis in the paper to compare our results! Here's the diagram they give of their results. \n", - "\n", - "![IOI1](https://pbs.twimg.com/media/FghGkTAWAAAmkhm.jpg)\n", - "\n", - "(Head 1.2 in their notation is L1H2 in my notation etc. And note - in the [latest version of the paper](https://arxiv.org/pdf/2211.00593.pdf) they add 9.0 as a backup name mover, and remove 11.3)\n", - "\n", - "The heads form three categories corresponding to the early, middle and late categories we found and we did fairly well! 
Definitely not perfect, but with some fairly generic techniques and some a priori reasoning, we found the broad strokes of the circuit and what it looks like. We focused on the most important heads, so we didn't find all relevant heads in each category (especially not the heads in brackets, which are more minor), but this serves as a good base for doing more rigorous and involved analysis, especially for finding the *complete* circuit (ie all of the parts of the model which participate in this behaviour) rather than just a partial and suggestive circuit. Go check out [their paper](https://arxiv.org/abs/2211.00593) or [our interview](https://www.youtube.com/watch?v=gzwj0jWbvbo) to learn more about what they did and what they found!\n", - "\n", - "Breaking down their categories:\n", - "\n", - "* Early: The duplicate token heads, previous token heads and induction heads. These serve the purpose of detecting that the second subject is duplicated and which earlier name is the duplicate.\n", - " * We found a direct duplicate token head which behaves exactly as expected, L3H0. Heads L5H0 and L6H9 are induction heads, which explains why they don't attend directly to the earlier copy of John!\n", - " * Note that the duplicate token heads and induction heads do not compose with each other - both directly add to the S-Inhibition heads. The diagram is somewhat misleading.\n", - "* Middle: They call these S-Inhibition heads - they copy the information about the duplicate token from the second subject to the to token, and their output is used to *inhibit* the attention paid from the name movers to the first subject copy. We found all these heads, and had a decent guess for what they did.\n", - " * In either case they attend to the second subject, so the patch that mattered was their value vectors!\n", - "* Late: They call these name movers, and we found some of them. 
They attend from the final token to the indirect object name and copy that to the logits, using the S-Inhibition heads to inhibit attention to the first copy of the subject token.\n", - " * We did find their surprising result of *negative* name movers - name movers that inhibit the correct answer!\n", - " * They have an entire category of heads we missed called backup name movers - we'll get to these later.\n", - "\n", - "So, now, let's dig into the two anomalies we missed - induction heads and backup name mover heads" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Bonus: Exploring Anomalies" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Early Heads are Induction Heads(?!)\n", - "\n", - "A really weird observation is that some of the early heads detecting duplicated tokens are induction heads, not just direct duplicate token heads. This is very weird! What's up with that? \n", - "\n", - "First off, what's an induction head? An induction head is an important type of attention head that can detect and continue repeated sequences. It is the second head in a two head induction circuit, which looks for previous copies of the current token and attends to the token *after* it, and then copies that to the current position and predicts that it will come next. They're enough of a big deal that [we wrote a whole paper on them](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html).\n", - "\n", - "![Move image demo](https://pbs.twimg.com/media/FNWAzXjVEAEOGRe.jpg)\n", - "\n", - "Second, why is it surprising that they come up here? It's surprising because it feels like overkill. The model doesn't care about *what* token comes after the first copy of the subject, just that it's duplicated. And it already has simpler duplicate token heads. 
My best guess is that it just already had induction heads around and that, in addition to their main function, they *also* only activate on duplicated tokens. So it was useful to repurpose this existing machinery. \n", - "\n", - "This suggests that as we look for circuits in larger models life may get more and more complicated, as components in simpler circuits get repurposed and built upon. " - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can verify that these are induction heads by running the model on repeated text and plotting the heads." - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [], - "source": [ - "example_text = \"Research in mechanistic interpretability seeks to explain behaviors of machine learning models in terms of their internal components.\"\n", - "example_repeated_text = example_text + example_text\n", - "example_repeated_tokens = model.to_tokens(example_repeated_text, prepend_bos=True)\n", - "example_repeated_logits, example_repeated_cache = model.run_with_cache(\n", - " example_repeated_tokens\n", - ")\n", - "induction_head_labels = [81, 65]" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "

Induction Heads


\n", - "
" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "code = visualize_attention_patterns(\n", - " induction_head_labels,\n", - " example_repeated_cache,\n", - " example_repeated_tokens,\n", - " title=\"Induction Heads\",\n", - " max_width=800,\n", - ")\n", - "HTML(code)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Implications\n", - "\n", - "One implication of this is that it's useful to categories heads according to whether they occur in\n", - "simpler circuits, so that as we look for more complex circuits we can easily look for them. This is\n", - "easy to do here! An interesting fact about induction heads is that they work on a sequence of\n", - "repeated random tokens - notable for being wildly off distribution from the natural language GPT-2\n", - "was trained on. Being able to predict a model's behaviour off distribution is a good mark of success\n", - "for mechanistic interpretability! This is a good sanity check for whether a head is an induction\n", - "head or not. \n", - "\n", - "We can characterise an induction head by just giving a sequence of random tokens repeated once, and\n", - "measuring the average attention paid from the second copy of a token to the token after the first\n", - "copy. At the same time, we can also measure the average attention paid from the second copy of a\n", - "token to the first copy of the token, which is the attention that the induction head would pay if it\n", - "were a duplicate token head, and the average attention paid to the previous token to find previous\n", - "token heads.\n", - "\n", - "Note that this is a superficial study of whether something is an induction head - we totally ignore\n", - "the question of whether it actually does boost the correct token or whether it composes with a\n", - "single previous head and how. 
In particular, we sometimes get anti-induction heads which suppress\n", - "the induction-y token (no clue why!), and this technique will find those too . But given the\n", - "previous rigorous analysis, we can be pretty confident that this picks up on some true signal about\n", - "induction heads." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "
Technical Implementation Details \n", - "We can do this again by using hooks, this time just to access the attention patterns rather than to intervene on them. \n", - "\n", - "Our hook function acts on the attention pattern activation. This has the name\n", - "\"blocks.{layer}.{layer_type}.hook_{activation_name}\" in general, here it's\n", - "\"blocks.{layer}.attn.hook_attn\". And it has shape [batch, head_index, query_pos, token_pos]. Our\n", - "hook function takes in the attention pattern activation, calculates the score for the relevant type\n", - "of head, and write it to an external cache.\n", - "\n", - "We add in hooks using `model.run_with_hooks(tokens, fwd_hooks=[(names_filter, hook_fn)])` to\n", - "temporarily add in the hooks and run the model, getting the resulting output. Previously\n", - "names_filter was the name of the activation, but here it's a boolean function mapping activation\n", - "names to whether we want to hook them or not. Here it's just whether the name ends with hook_attn.\n", - "hook_fn must take in the two inputs activation (the activation tensor) and hook (the HookPoint\n", - "object, which contains the name of the activation and some metadata such as the current layer).\n", - "\n", - "Internally our hooks use the function `tensor.diagonal`, this takes the diagonal between two\n", - "dimensions, and allows an arbitrary offset - offset by 1 to get previous tokens, seq_len to get\n", - "duplicate tokens (the distance to earlier copies) and seq_len-1 to get induction heads (the distance\n", - "to the token *after* earlier copies). Different offsets give a different length of output tensor,\n", - "and we can now just average to get a score in [0, 1] for each head\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[0.0390, 0.0000, 0.0310],\n", - " [0.1890, 0.1720, 0.0680],\n", - " [0.1570, 0.0210, 0.4820]])\n", - "tensor([[0.0030, 0.1320, 0.0050],\n", - " [0.0000, 0.0000, 0.0020],\n", - " [0.0020, 0.0090, 0.0000]])\n", - "tensor([[0.0040, 0.0000, 0.0040],\n", - " [0.0010, 0.0000, 0.0020],\n", - " [0.0020, 0.0090, 0.0020]])\n" - ] - } - ], - "source": [ - "seq_len = 100\n", - "batch_size = 2\n", - "\n", - "prev_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)\n", - "\n", - "\n", - "def prev_token_hook(pattern, hook):\n", - " layer = hook.layer()\n", - " diagonal = pattern.diagonal(offset=1, dim1=-1, dim2=-2)\n", - " # print(diagonal)\n", - " # print(pattern)\n", - " prev_token_scores[layer] = einops.reduce(\n", - " diagonal, \"batch head_index diagonal -> head_index\", \"mean\"\n", - " )\n", - "\n", - "\n", - "duplicate_token_scores = torch.zeros(\n", - " (model.cfg.n_layers, model.cfg.n_heads), device=device\n", - ")\n", - "\n", - "\n", - "def duplicate_token_hook(pattern, hook):\n", - " layer = hook.layer()\n", - " diagonal = pattern.diagonal(offset=seq_len, dim1=-1, dim2=-2)\n", - " duplicate_token_scores[layer] = einops.reduce(\n", - " diagonal, \"batch head_index diagonal -> head_index\", \"mean\"\n", - " )\n", - "\n", - "\n", - "induction_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)\n", - "\n", - "\n", - "def induction_hook(pattern, hook):\n", - " layer = hook.layer()\n", - " diagonal = pattern.diagonal(offset=seq_len - 1, dim1=-1, dim2=-2)\n", - " induction_scores[layer] = einops.reduce(\n", - " diagonal, \"batch head_index diagonal -> head_index\", \"mean\"\n", - " )\n", - "\n", - "\n", - "torch.manual_seed(0)\n", - "original_tokens = torch.randint(\n", - " 100, 20000, size=(batch_size, seq_len), device=\"cpu\"\n", - ").to(device)\n", - 
"repeated_tokens = einops.repeat(\n", - " original_tokens, \"batch seq_len -> batch (2 seq_len)\"\n", - ").to(device)\n", - "\n", - "pattern_filter = lambda act_name: act_name.endswith(\"hook_pattern\")\n", - "\n", - "loss = model.run_with_hooks(\n", - " repeated_tokens,\n", - " return_type=\"loss\",\n", - " fwd_hooks=[\n", - " (pattern_filter, prev_token_hook),\n", - " (pattern_filter, duplicate_token_hook),\n", - " (pattern_filter, induction_hook),\n", - " ],\n", - ")\n", - "print(torch.round(utils.get_corner(prev_token_scores).detach().cpu(), decimals=3))\n", - "print(torch.round(utils.get_corner(duplicate_token_scores).detach().cpu(), decimals=3))\n", - "print(torch.round(utils.get_corner(induction_scores).detach().cpu(), decimals=3))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can now plot the head scores, and instantly see that the relevant early heads are induction heads or duplicate token heads (though also that there's a lot of induction heads that are *not* use - I have no idea why!). " - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0.039069853723049164, - 0.0004489101702347398, - 0.03133601322770119, - 0.007519590202718973, - 0.034592196345329285, - 0.00036230171099305153, - 0.034512776881456375, - 0.19740213453769684, - 0.038447845727205276, - 0.04053792357444763, - 0.027628764510154724, - 0.02496313862502575 - ], - [ - 0.1890650987625122, - 0.17219914495944977, - 0.06807752698659897, - 0.04494515433907509, - 0.07908554375171661, - 0.03096739575266838, - 0.028282109647989273, - 0.03644327446818352, - 0.026936717331409454, - 0.018826229497790337, - 0.045100897550582886, - 0.0065726665779948235 - ], - [ - 0.15745528042316437, - 0.020724520087242126, - 0.4817989468574524, - 0.2991352379322052, - 0.10764895379543304, - 0.33004048466682434, - 0.0997551754117012, - 0.04926132410764694, - 0.25493940711021423, - 0.3606453835964203, - 0.1257179230451584, - 0.07931824028491974 - ], - [ - 0.005844001192599535, - 0.15787364542484283, - 0.4189082086086273, - 0.30129021406173706, - 0.014345049858093262, - 0.032344333827495575, - 0.3312888443470001, - 0.5285974144935608, - 0.34242063760757446, - 0.101837158203125, - 0.10516070574522018, - 0.2233113795518875 - ], - [ - 0.10626544803380966, - 0.11930850893259048, - 0.022880680859088898, - 0.22826944291591644, - 0.020003994926810265, - 0.10010036826133728, - 0.1739213615655899, - 0.17407020926475525, - 0.02587701380252838, - 0.10249985754489899, - 0.009514841251075268, - 0.9921423196792603 - ], - [ - 0.019766658544540405, - 0.00528325280174613, - 0.16648508608341217, - 0.12087740004062653, - 0.16500000655651093, - 0.00803269725292921, - 0.41770195960998535, - 0.025827765464782715, - 0.04802601411938667, - 0.016231779009103775, - 0.03110172413289547, - 0.024261215701699257 - ], - [ - 0.2172909826040268, - 0.039100028574466705, - 0.01804858259856701, - 0.059900715947151184, - 0.032934583723545074, - 0.0873451679944992, - 0.026895340532064438, - 0.0943947583436966, - 
0.49925994873046875, - 0.006240115500986576, - 0.027026718482375145, - 0.1278565675020218 - ], - [ - 0.2511657178401947, - 0.01330868061631918, - 0.006663354113698006, - 0.037430502474308014, - 0.02331537753343582, - 0.01740722358226776, - 0.022067422047257423, - 0.022141192108392715, - 0.04502448812127113, - 0.0208425372838974, - 0.008310739882290363, - 0.017167754471302032 - ], - [ - 0.020890623331069946, - 0.016537941992282867, - 0.02158307284116745, - 0.0150058064609766, - 0.02421221323311329, - 0.10198988765478134, - 0.029100384563207626, - 0.22793792188167572, - 0.02781485579907894, - 0.0179410632699728, - 0.024828944355249405, - 0.03806235268712044 - ], - [ - 0.02607586607336998, - 0.015407431870698929, - 0.02044427953660488, - 0.14558182656764984, - 0.01247025839984417, - 0.017151640728116035, - 0.013311829417943954, - 0.024451706558465958, - 0.018111787736415863, - 0.01319331955164671, - 0.0357399508357048, - 0.01879822090268135 - ], - [ - 0.02147812582552433, - 0.018419174477458, - 0.018183622509241104, - 0.02172141708433628, - 0.0315677747130394, - 0.034705750644207, - 0.017550116404891014, - 0.011417553760111332, - 0.01579565554857254, - 0.04592214897274971, - 0.01621554046869278, - 0.03039470687508583 - ], - [ - 0.03320508822798729, - 0.0175714660435915, - 0.015131079591810703, - 0.04148406535387039, - 0.015181189402937889, - 0.01758997142314911, - 0.015148494392633438, - 0.01767607219517231, - 0.06622709333896637, - 0.018451133742928505, - 0.01700744964182377, - 0.029749270528554916 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] + 
"cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Position: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "x": [ + "<|endoftext|>_0", + "When_1", + " John_2", + " and_3", + " Mary_4", + " went_5", + " to_6", + " the_7", + " shops_8", + ",_9", + " John_10", + " gave_11", + " the_12", + " bag_13", + " to_14" + ], + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.035456884652376175, + -0.0002469856117386371, + 0.00000976665523921838, + -0.00036458822432905436, + -0.000048967522161547095 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0029848709236830473, + 0.00007950929284561425, + 0.000020842242520302534, + 0.00008088535105343908, + -0.0005967392353340983 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0019131568260490894, + 0.0006668510613963008, + 0.00039482791908085346, + -0.0007051457650959492, + -0.00027282864903099835 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.1546323299407959, + 0.0038019807543605566, + 0.0005171628436073661, + -0.00011964991426793858, + -0.0005599213181994855 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.005406397394835949, + 0.019581740722060204, + 0.001007509301416576, + -0.0002424211270408705, + 0.0007936497568152845 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.3520970046520233, + 0.0010525835677981377, + 0.00022436455765273422, + 0.00013367898645810783, + 0.00008172441448550671 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.11986024677753448, + 0.021243548020720482, + 0.002727783052250743, + 0.0013409851817414165, + 0.01797366514801979 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.013310473412275314, + 0.011509180068969727, + 0.00037542887730523944, + -0.00004094611358596012, + 0.29760244488716125 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0015009435592219234, + 0.017351653426885605, + 0.0005848917062394321, + 0.0010122752282768488, + 0.5697318911552429 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 
0, + 0, + -0.00012901381705887616, + 0.00630143890157342, + 0.00014156615361571312, + 0.00031229801243171096, + 0.27152299880981445 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0009373303619213402, + 0.00008669164526509121, + 0.00033243544748984277, + 9.73309283835988e-7, + -0.1929796040058136 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.40617984533309937 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, 
+ "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 
0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" 
+ } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + 
"#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + 
"bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Logit Difference From Patched Attention Layer" + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Position" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow(\n", + " patched_attn_diff,\n", + " x=prompt_position_labels,\n", + " title=\"Logit Difference From Patched Attention Layer\",\n", + " labels={\"x\": \"Position\", \"y\": \"Layer\"},\n", + ")" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": 
"#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - 
"histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { 
- "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In contrast, the 
MLP layers do not matter much. This makes sense, since this is more a task about moving information than about processing it, and the MLP layers specialise in processing information.\n", + "\n", + "The one exception is MLP 0, which matters a lot, but I think this is misleading and just a generally true statement about MLP 0 rather than being about the circuit on this task.\n", + "\n", + "
My takes on MLP0 \n", + "It's often observed on GPT-2 Small that MLP0 matters a lot, and that ablating it utterly destroys performance. My current best guess is that the first MLP layer is essentially acting as an extension of the embedding (for whatever reason) and that when later layers want to access the input tokens they mostly read in the output of the first MLP layer, rather than the token embeddings. Within this frame, the first attention layer doesn't do much. \n", + "\n", + "In this framing, it makes sense that MLP0 matters on the second subject token, because that's the one position with a different input token!\n", + "\n", + "I'm not entirely sure why this happens, but I would guess that it's because the embedding and unembedding matrices in GPT-2 Small are the same. This is pretty unprincipled, as the tasks of embedding and unembedding tokens are not inverses, but this is common practice, and plausibly models want to dedicate some parameters to overcoming this. \n", + "\n", + "I only have suggestive evidence of this, and would love to see someone look into this properly!\n", + "
" ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Position: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "x": [ + "<|endoftext|>_0", + "When_1", + " John_2", + " and_3", + " Mary_4", + " went_5", + " to_6", + " the_7", + " shops_8", + ",_9", + " John_10", + " gave_11", + " the_12", + " bag_13", + " to_14" + ], + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.8507890701293945, + -0.00027843358111567795, + -0.00007293107046280056, + -0.00047373308916576207, + 0.000040039929444901645 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.008863994851708412, + 0.000222149450564757, + 0.00014938619278836995, + -0.00004853121208725497, + 0.000304041663184762 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.013550343923270702, + 0.0000586334899708163, + -0.0003296833310741931, + -0.0006382559076882899, + 0.0007730424986220896 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.0019468198297545314, + 0.0004995090421289206, + 0.00017318192112725228, + 0.00016871812113095075, + 0.00040764876757748425 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.019787074998021126, + 0.004128609783947468, + -0.0000486990247736685, + -0.00017019486404024065, + 0.0007914346642792225 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.09652391821146011, + -0.0018826150335371494, + -0.0004844730719923973, + 0.0007094081956893206, + -0.00018335132335778326 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.015900013968348503, + -0.0008501688134856522, + 0.00012337534280959517, + 0.000027521158699528314, + -0.007238299585878849 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.010360540822148323, + 0.0031509376130998135, + 0.0005309234256856143, + 0.0002361114020459354, + 0.008496351540088654 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.012533102184534073, + 0.00002201692586822901, + -0.00035374757135286927, + 0.00008615465048933402, + -0.021631328389048576 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, 
+ 0, + 0, + 0, + 0, + -0.00033465056912973523, + 0.0008094912045635283, + 0.000016244195649051107, + 0.00012924875773023814, + 0.03162466362118721 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.0013599144294857979, + -0.00019499746849760413, + -0.00009934466652339324, + -0.00014217027637641877, + 0.028764141723513603 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.02044912613928318 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 
0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + 
], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + 
"ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 
0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", 
+ "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Logit Difference From Patched MLP Layer" + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Position" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow(\n", + " patched_mlp_diff,\n", + " x=prompt_position_labels,\n", + " title=\"Logit Difference From Patched MLP Layer\",\n", + " labels={\"x\": \"Position\", \"y\": \"Layer\"},\n", + ")" ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Heads\n", + "\n", + "We can refine the above analysis by patching in individual heads! 
This is somewhat more annoying, because there are now three dimensions (head_index, position and layer), so for now lets patch in a head's output across all positions.\n", + "\n", + "The easiest way to do this is to patch in the activation `z`, the \"mixed value\" of the attention head. That is, the average of all previous values weighted by the attention pattern, ie the activation that is then multiplied by `W_O`, the output weights. " ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", 
- "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Previous Token Scores" + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "def patch_head_vector(\n", + " corrupted_head_vector: Float[torch.Tensor, \"batch pos head_index d_head\"],\n", + " hook,\n", + " head_index,\n", + " clean_cache,\n", + "):\n", + " corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][\n", + " :, :, head_index, :\n", + " ]\n", + " return corrupted_head_vector\n", + "\n", + "\n", + "patched_head_z_diff = torch.zeros(\n", + " model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32\n", + ")\n", + "for layer in range(model.cfg.n_layers):\n", + " for head_index in range(model.cfg.n_heads):\n", + " hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)\n", + " patched_logits = model.run_with_hooks(\n", + " corrupted_tokens,\n", + " fwd_hooks=[(utils.get_act_name(\"z\", layer, \"attn\"), hook_fn)],\n", + " return_type=\"logits\",\n", + " )\n", + " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", + "\n", + " patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(\n", + " patched_logit_diff\n", + " )" + ] }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now see that, in addition to the name mover heads identified before, in 
mid-late layers the heads L8H6, L8H10, L7H9 matter and are presumably responsible for moving information from the second subject to the final token. And heads L5H5, L6H9, L3H0 also matter a lot, and are presumably involved in detecting duplicated tokens." + ] }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0.0031923248898237944, - 0.13236315548419952, - 0.005006915424019098, - 1.0427449524286203e-05, - 0.0013110184809193015, - 0.7034568786621094, - 0.00426204688847065, - 0.00016496369789820164, - 0.002474633976817131, - 0.0008572910446673632, - 0.01889149099588394, - 0.008690938353538513 - ], - [ - 0.0002916341181844473, - 0.00013782267342321575, - 0.0015036173863336444, - 0.005392482969909906, - 0.0018583914497867227, - 0.009062949568033218, - 0.012414448894560337, - 0.0022405502386391163, - 0.005135662388056517, - 0.005220627877861261, - 0.005546474829316139, - 0.02975049614906311 - ], - [ - 0.0024816279765218496, - 0.009442180395126343, - 0.0003456332196947187, - 0.0002591445227153599, - 0.0052116685546934605, - 0.000570951378904283, - 0.0015209749108180404, - 0.006313100922852755, - 0.001560864970088005, - 0.0004215767839923501, - 0.00015359291865024716, - 0.005160381551831961 - ], - [ - 0.6775657534599304, - 0.002840448170900345, - 0.0007841526530683041, - 0.00471264636144042, - 0.006322895642369986, - 0.006206681486219168, - 0.0005474805948324502, - 0.00037829449865967035, - 0.0020155368838459253, - 0.007952751591801643, - 0.003576782764866948, - 0.002608788898214698 - ], - [ - 0.00860405620187521, - 0.0070286463014781475, - 0.007598803844302893, - 0.003442801535129547, - 0.016561277210712433, - 0.0059797209687530994, - 0.004869826138019562, - 0.0007624455611221492, - 0.006062133703380823, - 0.007536627352237701, - 0.012022900395095348, - 1.055422134237094e-12 - ], - [ - 0.00950299296528101, - 0.00856209360063076, - 0.004162600729614496, - 0.003008665982633829, - 0.006847422569990158, - 0.004358117934316397, - 0.007669268175959587, - 0.009584215469658375, - 0.0076188258826732635, - 0.0043280418030917645, - 0.041402824223041534, - 0.00976183544844389 - ], - [ - 0.004456141032278538, - 0.008873268961906433, - 0.007405205629765987, - 0.0062249391339719296, - 
0.00731915095821023, - 0.005623893812298775, - 0.017349667847156525, - 0.005529467947781086, - 0.002920132130384445, - 0.008636755868792534, - 0.006222263444215059, - 0.00835894700139761 - ], - [ - 0.003699858672916889, - 0.04107949137687683, - 0.04148268699645996, - 0.009313640184700489, - 0.009097025729715824, - 0.008774377405643463, - 0.007298537530004978, - 0.023312218487262726, - 0.008843323215842247, - 0.00987986009567976, - 0.017598601058125496, - 0.006039854139089584 - ], - [ - 0.008986304514110088, - 0.028667239472270012, - 0.008891218341886997, - 0.010114557109773159, - 0.009737391024827957, - 0.007611637003719807, - 0.009763265959918499, - 0.005155472084879875, - 0.009276345372200012, - 0.011895839124917984, - 0.010411946102976799, - 0.007498950231820345 - ], - [ - 0.024409977719187737, - 0.011438451707363129, - 0.02003096230328083, - 0.0051185814663767815, - 0.015081286430358887, - 0.012334450148046017, - 0.015452565625309944, - 0.008602450601756573, - 0.014702522195875645, - 0.020766200497746468, - 0.009192758239805698, - 0.005703347735106945 - ], - [ - 0.017897022888064384, - 0.013280633836984634, - 0.006755237001925707, - 0.012744844891130924, - 0.008020960725843906, - 0.007722244597971439, - 0.017341373488307, - 0.0074546560645103455, - 0.007832515984773636, - 0.00825214572250843, - 0.013642766512930393, - 0.012807483784854412 - ], - [ - 0.004923742264509201, - 0.007951060310006142, - 0.007947920821607113, - 0.004564082249999046, - 0.010363400913774967, - 0.009582078084349632, - 0.0102877551689744, - 0.00832072552293539, - 0.0025700009427964687, - 0.012810997664928436, - 0.008063871413469315, - 0.006558285094797611 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, 
- "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.0009487751522101462, + 0.016124747693538666, + 0.0018548924708738923, + 0.0034389030188322067, + -0.00982347596436739, + 0.011058605276048183, + -0.004063969012349844, + -0.0015792781487107277, + -0.0012082795146852732, + 0.003828897839412093, + -0.004256919026374817, + -0.0011422622483223677 + ], + [ + -0.0010771177476271987, + -0.00037898647133260965, + 0.0000025171791548928013, + -0.00026067905128002167, + -0.00014146546891424805, + 0.0038321535103023052, + -0.0004293300735298544, + -0.00142992555629462, + -0.0009228314156644046, + 0.0006944393389858305, + 0.00043302192352712154, + -0.0035714071709662676 + ], + [ + -0.0004967569257132709, + 0.0008057993836700916, + 0.0005424688570201397, + -0.0005309234256856143, + -0.0007159864180721343, + -0.0010389237431809306, + -0.0009490771917626262, + -0.00008649027586216107, + 0.0002766547549981624, + 0.0021084228064864874, + -0.0001975146442418918, + -0.0016405630158260465 + ], + [ + 0.1162627637386322, + 0.0002507446042727679, + -0.0014675153652206063, + -0.00039680811460129917, + 0.018962211906909943, + -0.00018764731066767126, + 0.011170871555805206, + -0.0013301445869728923, + -0.0007356539717875421, + -0.00030253134900704026, + -0.00014683544577565044, + -0.00022228369198273867 + ], + [ + -0.001650598249398172, + 0.0002927311579696834, + -0.00143563118763268, + 0.03084198758006096, + -0.007432155776768923, + -0.00028236035723239183, + 0.006017433945089579, + -0.011007187888026237, + -0.001266107545234263, + 0.0014901700196787715, + -0.0001800622121663764, + 0.002944394713267684 + ], + [ + -0.004211106337606907, + 0.0029597999528050423, + 0.002045023487880826, + 0.0013397098518908024, + -0.0012190865818411112, + 0.34349915385246277, + 0.0005632104002870619, + -0.0001262281439267099, + -0.00515326950699091, + 0.016240738332271576, + 0.01709030382335186, + -0.004175194539129734 + ], + [ + 0.039775289595127106, + 
0.015226684510707855, + -0.0010229480685666203, + 0.0008072761120274663, + -0.004935584031045437, + -0.002123525831848383, + -0.014274083077907562, + 0.0013746818294748664, + 0.0014838266652077436, + 0.1302703619003296, + -0.00033616088330745697, + 0.0012919505825266242 + ], + [ + 0.00037177055492065847, + 0.019514480605721474, + 0.00022255218937061727, + 0.124249167740345, + -0.00040352059295400977, + -0.007652895525097847, + 0.0013010123511776328, + -0.0011253133416175842, + -0.007449474185705185, + 0.19224143028259277, + -0.003275118535384536, + -0.0005017912480980158 + ], + [ + -0.001007912098430097, + 0.00003091096004936844, + -0.0008595998515374959, + 0.012359987013041973, + -0.0004041247011628002, + -0.004328910261392593, + 0.3185553252696991, + 0.002330605871975422, + 0.0021182901691645384, + 0.0001405928487656638, + 0.2779357433319092, + 0.005738262087106705 + ], + [ + 0.0058898297138512135, + -0.0009689796715974808, + 0.00912561360746622, + 0.020675739273428917, + -0.03700518235564232, + 0.014263041317462921, + -0.04828466475009918, + 0.05834139883518219, + 0.0006514795240946114, + 0.26360899209976196, + 0.0004918567719869316, + -0.00261044898070395 + ], + [ + 0.08374208211898804, + 0.020676210522651672, + -0.003743582172319293, + 0.01085072010755539, + -0.001096583902835846, + 0.00047430366976186633, + 0.04818058758974075, + -0.4799128472805023, + 0.00018429107149131596, + 0.011861988343298435, + 0.06088569387793541, + 0.0008461413672193885 + ], + [ + 0.005328264087438583, + -0.011493473313748837, + -0.11350836604833603, + 0.006329597905278206, + 0.00031669469899497926, + -0.0011600167490541935, + -0.022669579833745956, + 0.004070379305630922, + 0.0073160636238753796, + -0.00834545586258173, + -0.27817651629447937, + 0.0036344374530017376 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + 
[ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 
0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 
0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 
0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + 
"#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Logit Difference From Patched Head Output" + }, + "xaxis": { + 
"anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow(\n", + " patched_head_z_diff,\n", + " title=\"Logit Difference From Patched Head Output\",\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", + ")" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": 
"contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 
0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - 
"outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Decomposing Heads" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Decomposing attention layers into patching in individual heads has already helped us localise the behaviour a lot. But we can understand it further by decomposing heads. An attention head consists of two semi-independent operations - calculating *where* to move information from and to (represented by the attention pattern and implemented via the QK-circuit) and calculating *what* information to move (represented by the value vectors and implemented by the OV circuit). 
We can disentangle which of these is important by patching in just the attention pattern *or* the value vectors. (See [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) or [my walkthrough video](https://www.youtube.com/watch?v=KV5gbOmHbjU) for more on this decomposition. If you're not familiar with the details of how attention is implemented, I recommend checking out [my clean transformer implementation](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/clean-transformer-demo/Clean_Transformer_Demo.ipynb#scrollTo=3Pb0NYbZ900e) to see how the code works.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First let's patch in the value vectors, to measure when figuring out what to move is important. This has the same shape as z ([batch, pos, head_index, d_head]) so we can reuse the same hook." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "patched_head_v_diff = torch.zeros(\n", + " model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32\n", + ")\n", + "for layer in range(model.cfg.n_layers):\n", + " for head_index in range(model.cfg.n_heads):\n", + " hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)\n", + " patched_logits = model.run_with_hooks(\n", + " corrupted_tokens,\n", + " fwd_hooks=[(utils.get_act_name(\"v\", layer, \"attn\"), hook_fn)],\n", + " return_type=\"logits\",\n", + " )\n", + " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", + "\n", + " patched_head_v_diff[layer, head_index] = normalize_patched_logit_diff(\n", + " patched_logit_diff\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can plot this as a heatmap and it's initially hard to interpret." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.00019892427371814847, + 0.005339574534446001, + 0.0006527548539452255, + 0.003504416672512889, + -0.00898387935012579, + 0.0034814265090972185, + -0.0008631910313852131, + -0.00003406582254683599, + 0.0005166929331608117, + 0.00044255363172851503, + -0.0039068968035280704, + -0.0001880836207419634 + ], + [ + -0.0004399022145662457, + -0.00044510437874123454, + -0.0000673597096465528, + 0.00007242763240355998, + -0.000036549441574607044, + -0.0019323208834975958, + -0.0001572397886775434, + 0.000016143509128596634, + 0.00020593880617525429, + 0.000336798548232764, + 0.0003515324497129768, + -0.0005669358652085066 + ], + [ + 0.00021013410878367722, + -0.0007199132232926786, + 0.0004868560063187033, + -0.0005974104860797524, + -0.0005921411793678999, + -0.0005443819100037217, + -0.000227552984142676, + -0.0004809825913980603, + 0.00020570388005580753, + 0.001183376181870699, + -0.0003574058646336198, + -0.0009104468626901507 + ], + [ + 0.0010395278222858906, + -0.00012042184971505776, + -0.00007762980385450646, + -0.0007275318494066596, + -0.001310007064603269, + -0.0023108376190066338, + 0.010987084358930588, + -0.000050712766096694395, + 0.00014314358122646809, + 0.00015069512301124632, + -0.00007957642083056271, + -0.000020238119759596884 + ], + [ + -0.0005373673629947007, + -0.0008137872209772468, + -0.00013334336108528078, + 0.030609702691435814, + -0.007185807917267084, + 0.000148916311445646, + 0.0013340713921934366, + -0.01142292469739914, + -0.0005336419562809169, + 0.0005126654868945479, + 0.00037344868178479373, + 0.0029547319281846285 + ], + [ + 0.00000822278525447473, + 0.000006477540864580078, + 0.0015973682748153806, + 0.00034015480196103454, + -0.0012577504385262728, + -0.00005450531898532063, + 0.0006331544718705118, + -0.00027081489679403603, + 0.00007427356467815116, + -0.006704355590045452, + 0.003175975289195776, + -0.0017300404142588377 + ], + [ + 
0.04863045737147331, + 0.015314852818846703, + -0.0004648726317100227, + -0.00011676354915834963, + -0.00004930314753437415, + -0.003952810075134039, + -0.01737578585743904, + -0.00015421917487401515, + 0.0012194222072139382, + -0.00018090127559844404, + -0.00042647725786082447, + 0.00012334177154116333 + ], + [ + -0.00002956846401502844, + -0.0013855225406587124, + -0.00012129446986364201, + 0.1332160234451294, + -0.00024490474606864154, + -0.007315828464925289, + 0.00033297244226559997, + -0.000795092957559973, + -0.007938209921121597, + 0.208413764834404, + -0.00019127204723190516, + -0.00020650937221944332 + ], + [ + -0.0020483459811657667, + -0.0003764357534237206, + -0.0033135139383375645, + -0.009666135534644127, + -0.00031723169377073646, + -0.005141589790582657, + 0.31717124581336975, + 0.0028427678626030684, + 0.0004723234742414206, + -0.0011529687326401472, + 0.2726709246635437, + -0.003175639547407627 + ], + [ + -0.00043929810635745525, + 0.000057089622714556754, + -0.0020629793871194124, + 0.020066648721694946, + -0.007871017791330814, + 0.011316264048218727, + 0.003056862158700824, + 0.06856372952461243, + -0.002747517777606845, + -0.009279227815568447, + 0.000506624230183661, + -0.0013159140944480896 + ], + [ + -0.012957162223756313, + -0.0030454176012426615, + -0.01792328804731369, + -0.0043589151464402676, + -0.0011521632550284266, + 0.0004999117809347808, + -0.0031131464056670666, + 0.019585633650422096, + 0.0000434632929682266, + 0.01297028549015522, + -0.007695754989981651, + -0.0009146086522378027 + ], + [ + 0.004100752994418144, + -0.020459463819861412, + -0.035875942558050156, + 0.014656225219368935, + 0.0008441276149824262, + 0.0017804511589929461, + -0.01804223284125328, + 0.003519016318023205, + 0.008253024891018867, + -0.0017665562918409705, + 0.044167667627334595, + 0.006474285386502743 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 
0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, 
+ "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ 
+ 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 
0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + 
"#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Logit Difference From 
Patched Head Value" + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow(\n", + " patched_head_v_diff,\n", + " title=\"Logit Difference From Patched Head Value\",\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "But it's very easy to interpret if we plot a scatter plot against patching head outputs. Here we see that the earlier heads (L5H5, L6H9, L3H0) and late name movers (L9H9, L10H7, L11H10) don't matter at all now, while the mid-late heads (L8H6, L8H10, L7H9) do. \n", + "\n", + "Meta lesson: Plot things early, often and in diverse ways as you explore a model's internals!" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hovertemplate": "%{hovertext}

Value Patch=%{x}
Output Patch=%{y}
Layer=%{marker.color}", + "hovertext": [ + "L0H0", + "L0H1", + "L0H2", + "L0H3", + "L0H4", + "L0H5", + "L0H6", + "L0H7", + "L0H8", + "L0H9", + "L0H10", + "L0H11", + "L1H0", + "L1H1", + "L1H2", + "L1H3", + "L1H4", + "L1H5", + "L1H6", + "L1H7", + "L1H8", + "L1H9", + "L1H10", + "L1H11", + "L2H0", + "L2H1", + "L2H2", + "L2H3", + "L2H4", + "L2H5", + "L2H6", + "L2H7", + "L2H8", + "L2H9", + "L2H10", + "L2H11", + "L3H0", + "L3H1", + "L3H2", + "L3H3", + "L3H4", + "L3H5", + "L3H6", + "L3H7", + "L3H8", + "L3H9", + "L3H10", + "L3H11", + "L4H0", + "L4H1", + "L4H2", + "L4H3", + "L4H4", + "L4H5", + "L4H6", + "L4H7", + "L4H8", + "L4H9", + "L4H10", + "L4H11", + "L5H0", + "L5H1", + "L5H2", + "L5H3", + "L5H4", + "L5H5", + "L5H6", + "L5H7", + "L5H8", + "L5H9", + "L5H10", + "L5H11", + "L6H0", + "L6H1", + "L6H2", + "L6H3", + "L6H4", + "L6H5", + "L6H6", + "L6H7", + "L6H8", + "L6H9", + "L6H10", + "L6H11", + "L7H0", + "L7H1", + "L7H2", + "L7H3", + "L7H4", + "L7H5", + "L7H6", + "L7H7", + "L7H8", + "L7H9", + "L7H10", + "L7H11", + "L8H0", + "L8H1", + "L8H2", + "L8H3", + "L8H4", + "L8H5", + "L8H6", + "L8H7", + "L8H8", + "L8H9", + "L8H10", + "L8H11", + "L9H0", + "L9H1", + "L9H2", + "L9H3", + "L9H4", + "L9H5", + "L9H6", + "L9H7", + "L9H8", + "L9H9", + "L9H10", + "L9H11", + "L10H0", + "L10H1", + "L10H2", + "L10H3", + "L10H4", + "L10H5", + "L10H6", + "L10H7", + "L10H8", + "L10H9", + "L10H10", + "L10H11", + "L11H0", + "L11H1", + "L11H2", + "L11H3", + "L11H4", + "L11H5", + "L11H6", + "L11H7", + "L11H8", + "L11H9", + "L11H10", + "L11H11" + ], + "legendgroup": "", + "marker": { + "color": [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 
6, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11 + ], + "coloraxis": "coloraxis", + "symbol": "circle" + }, + "mode": "markers", + "name": "", + "orientation": "v", + "showlegend": false, + "type": "scatter", + "x": [ + -0.00019892427371814847, + 0.005339574534446001, + 0.0006527548539452255, + 0.003504416672512889, + -0.00898387935012579, + 0.0034814265090972185, + -0.0008631910313852131, + -0.00003406582254683599, + 0.0005166929331608117, + 0.00044255363172851503, + -0.0039068968035280704, + -0.0001880836207419634, + -0.0004399022145662457, + -0.00044510437874123454, + -0.0000673597096465528, + 0.00007242763240355998, + -0.000036549441574607044, + -0.0019323208834975958, + -0.0001572397886775434, + 0.000016143509128596634, + 0.00020593880617525429, + 0.000336798548232764, + 0.0003515324497129768, + -0.0005669358652085066, + 0.00021013410878367722, + -0.0007199132232926786, + 0.0004868560063187033, + -0.0005974104860797524, + -0.0005921411793678999, + -0.0005443819100037217, + -0.000227552984142676, + -0.0004809825913980603, + 0.00020570388005580753, + 0.001183376181870699, + -0.0003574058646336198, + -0.0009104468626901507, + 0.0010395278222858906, + -0.00012042184971505776, + -0.00007762980385450646, + -0.0007275318494066596, + -0.001310007064603269, + -0.0023108376190066338, + 0.010987084358930588, + -0.000050712766096694395, + 0.00014314358122646809, + 0.00015069512301124632, + -0.00007957642083056271, + -0.000020238119759596884, + -0.0005373673629947007, + -0.0008137872209772468, + -0.00013334336108528078, + 0.030609702691435814, + -0.007185807917267084, + 0.000148916311445646, + 0.0013340713921934366, + -0.01142292469739914, + -0.0005336419562809169, + 0.0005126654868945479, + 
0.00037344868178479373, + 0.0029547319281846285, + 0.00000822278525447473, + 0.000006477540864580078, + 0.0015973682748153806, + 0.00034015480196103454, + -0.0012577504385262728, + -0.00005450531898532063, + 0.0006331544718705118, + -0.00027081489679403603, + 0.00007427356467815116, + -0.006704355590045452, + 0.003175975289195776, + -0.0017300404142588377, + 0.04863045737147331, + 0.015314852818846703, + -0.0004648726317100227, + -0.00011676354915834963, + -0.00004930314753437415, + -0.003952810075134039, + -0.01737578585743904, + -0.00015421917487401515, + 0.0012194222072139382, + -0.00018090127559844404, + -0.00042647725786082447, + 0.00012334177154116333, + -0.00002956846401502844, + -0.0013855225406587124, + -0.00012129446986364201, + 0.1332160234451294, + -0.00024490474606864154, + -0.007315828464925289, + 0.00033297244226559997, + -0.000795092957559973, + -0.007938209921121597, + 0.208413764834404, + -0.00019127204723190516, + -0.00020650937221944332, + -0.0020483459811657667, + -0.0003764357534237206, + -0.0033135139383375645, + -0.009666135534644127, + -0.00031723169377073646, + -0.005141589790582657, + 0.31717124581336975, + 0.0028427678626030684, + 0.0004723234742414206, + -0.0011529687326401472, + 0.2726709246635437, + -0.003175639547407627, + -0.00043929810635745525, + 0.000057089622714556754, + -0.0020629793871194124, + 0.020066648721694946, + -0.007871017791330814, + 0.011316264048218727, + 0.003056862158700824, + 0.06856372952461243, + -0.002747517777606845, + -0.009279227815568447, + 0.000506624230183661, + -0.0013159140944480896, + -0.012957162223756313, + -0.0030454176012426615, + -0.01792328804731369, + -0.0043589151464402676, + -0.0011521632550284266, + 0.0004999117809347808, + -0.0031131464056670666, + 0.019585633650422096, + 0.0000434632929682266, + 0.01297028549015522, + -0.007695754989981651, + -0.0009146086522378027, + 0.004100752994418144, + -0.020459463819861412, + -0.035875942558050156, + 0.014656225219368935, + 0.0008441276149824262, + 
0.0017804511589929461, + -0.01804223284125328, + 0.003519016318023205, + 0.008253024891018867, + -0.0017665562918409705, + 0.044167667627334595, + 0.006474285386502743 + ], + "xaxis": "x", + "y": [ + 0.0009487751522101462, + 0.016124747693538666, + 0.0018548924708738923, + 0.0034389030188322067, + -0.00982347596436739, + 0.011058605276048183, + -0.004063969012349844, + -0.0015792781487107277, + -0.0012082795146852732, + 0.003828897839412093, + -0.004256919026374817, + -0.0011422622483223677, + -0.0010771177476271987, + -0.00037898647133260965, + 0.0000025171791548928013, + -0.00026067905128002167, + -0.00014146546891424805, + 0.0038321535103023052, + -0.0004293300735298544, + -0.00142992555629462, + -0.0009228314156644046, + 0.0006944393389858305, + 0.00043302192352712154, + -0.0035714071709662676, + -0.0004967569257132709, + 0.0008057993836700916, + 0.0005424688570201397, + -0.0005309234256856143, + -0.0007159864180721343, + -0.0010389237431809306, + -0.0009490771917626262, + -0.00008649027586216107, + 0.0002766547549981624, + 0.0021084228064864874, + -0.0001975146442418918, + -0.0016405630158260465, + 0.1162627637386322, + 0.0002507446042727679, + -0.0014675153652206063, + -0.00039680811460129917, + 0.018962211906909943, + -0.00018764731066767126, + 0.011170871555805206, + -0.0013301445869728923, + -0.0007356539717875421, + -0.00030253134900704026, + -0.00014683544577565044, + -0.00022228369198273867, + -0.001650598249398172, + 0.0002927311579696834, + -0.00143563118763268, + 0.03084198758006096, + -0.007432155776768923, + -0.00028236035723239183, + 0.006017433945089579, + -0.011007187888026237, + -0.001266107545234263, + 0.0014901700196787715, + -0.0001800622121663764, + 0.002944394713267684, + -0.004211106337606907, + 0.0029597999528050423, + 0.002045023487880826, + 0.0013397098518908024, + -0.0012190865818411112, + 0.34349915385246277, + 0.0005632104002870619, + -0.0001262281439267099, + -0.00515326950699091, + 0.016240738332271576, + 0.01709030382335186, + 
-0.004175194539129734, + 0.039775289595127106, + 0.015226684510707855, + -0.0010229480685666203, + 0.0008072761120274663, + -0.004935584031045437, + -0.002123525831848383, + -0.014274083077907562, + 0.0013746818294748664, + 0.0014838266652077436, + 0.1302703619003296, + -0.00033616088330745697, + 0.0012919505825266242, + 0.00037177055492065847, + 0.019514480605721474, + 0.00022255218937061727, + 0.124249167740345, + -0.00040352059295400977, + -0.007652895525097847, + 0.0013010123511776328, + -0.0011253133416175842, + -0.007449474185705185, + 0.19224143028259277, + -0.003275118535384536, + -0.0005017912480980158, + -0.001007912098430097, + 0.00003091096004936844, + -0.0008595998515374959, + 0.012359987013041973, + -0.0004041247011628002, + -0.004328910261392593, + 0.3185553252696991, + 0.002330605871975422, + 0.0021182901691645384, + 0.0001405928487656638, + 0.2779357433319092, + 0.005738262087106705, + 0.0058898297138512135, + -0.0009689796715974808, + 0.00912561360746622, + 0.020675739273428917, + -0.03700518235564232, + 0.014263041317462921, + -0.04828466475009918, + 0.05834139883518219, + 0.0006514795240946114, + 0.26360899209976196, + 0.0004918567719869316, + -0.00261044898070395, + 0.08374208211898804, + 0.020676210522651672, + -0.003743582172319293, + 0.01085072010755539, + -0.001096583902835846, + 0.00047430366976186633, + 0.04818058758974075, + -0.4799128472805023, + 0.00018429107149131596, + 0.011861988343298435, + 0.06088569387793541, + 0.0008461413672193885, + 0.005328264087438583, + -0.011493473313748837, + -0.11350836604833603, + 0.006329597905278206, + 0.00031669469899497926, + -0.0011600167490541935, + -0.022669579833745956, + 0.004070379305630922, + 0.0073160636238753796, + -0.00834545586258173, + -0.27817651629447937, + 0.0036344374530017376 + ], + "yaxis": "y" + } + ], + "layout": { + "coloraxis": { + "colorbar": { + "title": { + "text": "Layer" + } + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 
0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "legend": { + "tracegroupgap": 0 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + 
"colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + 
"#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + 
], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + 
"#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, 
+ "title": { + "text": "Scatter plot of output patching vs value patching" + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "range": [ + -0.5, + 0.5 + ], + "title": { + "text": "Value Patch" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "range": [ + -0.5, + 0.5 + ], + "title": { + "text": "Output Patch" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "head_labels = [\n", + " f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)\n", + "]\n", + "scatter(\n", + " x=utils.to_numpy(patched_head_v_diff.flatten()),\n", + " y=utils.to_numpy(patched_head_z_diff.flatten()),\n", + " xaxis=\"Value Patch\",\n", + " yaxis=\"Output Patch\",\n", + " caxis=\"Layer\",\n", + " hover_name=head_labels,\n", + " color=einops.repeat(\n", + " np.arange(model.cfg.n_layers), \"layer -> (layer head)\", head=model.cfg.n_heads\n", + " ),\n", + " range_x=(-0.5, 0.5),\n", + " range_y=(-0.5, 0.5),\n", + " title=\"Scatter plot of output patching vs value patching\",\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When we patch in attention patterns, we see the opposite effect - early and late heads matter a lot, middle heads don't. 
(In fact, the sum of value patching and pattern patching is approx the same as output patching)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "def patch_head_pattern(\n", + " corrupted_head_pattern: Float[torch.Tensor, \"batch head_index query_pos d_head\"],\n", + " hook,\n", + " head_index,\n", + " clean_cache,\n", + "):\n", + " corrupted_head_pattern[:, head_index, :, :] = clean_cache[hook.name][\n", + " :, head_index, :, :\n", + " ]\n", + " return corrupted_head_pattern\n", + "\n", + "\n", + "patched_head_attn_diff = torch.zeros(\n", + " model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32\n", + ")\n", + "for layer in range(model.cfg.n_layers):\n", + " for head_index in range(model.cfg.n_heads):\n", + " hook_fn = partial(patch_head_pattern, head_index=head_index, clean_cache=cache)\n", + " patched_logits = model.run_with_hooks(\n", + " corrupted_tokens,\n", + " fwd_hooks=[(utils.get_act_name(\"attn\", layer, \"attn\"), hook_fn)],\n", + " return_type=\"logits\",\n", + " )\n", + " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", + "\n", + " patched_head_attn_diff[layer, head_index] = normalize_patched_logit_diff(\n", + " patched_logit_diff\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.0006401354330591857, + 0.005318799521774054, + 0.0011584057938307524, + -0.00005920405237702653, + -0.00106671336106956, + 0.005079298280179501, + -0.0030818663071841, + -0.0020521720871329308, + -0.0014405983965843916, + 0.003492669900879264, + -0.002568227471783757, + -0.0009168237447738647 + ], + [ + -0.0007600873941555619, + 0.0001683824957581237, + 0.00012246915139257908, + -0.00034914951538667083, + 0.000014901700524205808, + 0.0050090523436665535, + -0.0002975976967718452, + -0.0014448943547904491, + -0.001099134678952396, + 0.00047447148244827986, + 0.00005195457561057992, + -0.0034954219590872526 + ], + [ + -0.0007243098807521164, + 0.0017458146903663874, + -0.00015556166181340814, + 0.000057626621128292754, + -0.000097398049547337, + -0.0004238593974150717, + -0.0007917031762190163, + 0.00027222454082220793, + 0.00010179472155869007, + 0.0004223826399538666, + 0.00015193692524917424, + -0.0007437760941684246 + ], + [ + 0.11458104848861694, + 0.00021140948229003698, + -0.0009424989693798125, + 0.000429833511589095, + 0.02004295401275158, + 0.002104730810970068, + 0.00007628730963915586, + -0.001543701975606382, + -0.0008484235731884837, + -0.0005819046637043357, + 0.00011921360419364646, + -0.00001899631206470076 + ], + [ + -0.001127125695347786, + 0.001237143180333078, + -0.0012324444251134992, + -0.0005952289211563766, + -0.0007541133090853691, + -0.0005842540413141251, + 0.004813014063984156, + 0.00018187458044849336, + -0.0005361591465771198, + 0.0008579217828810215, + -0.0002985374303534627, + -0.00001144477391790133 + ], + [ + -0.004241178277879953, + 0.0029509058222174644, + 0.0005218615406192839, + 0.0009535074350424111, + 0.0001622070267330855, + 0.34350839257240295, + -0.0003052163519896567, + 0.00010293584637111053, + -0.005300541408360004, + 0.024864863604307175, + 0.014383262023329735, + -0.0023285921197384596 + ], + [ + -0.0023893399629741907, + 
-0.002172795357182622, + -0.00047614958020858467, + 0.00043188079143874347, + -0.004675475414842367, + 0.0018583494238555431, + -0.0026542814448475838, + 0.0014367386465892196, + 0.00030326974228955805, + 0.13043038547039032, + 0.00008813483145786449, + 0.0011766973184421659 + ], + [ + 0.00031847349600866437, + 0.02057075686752796, + 0.00031840638257563114, + -0.002512782346457243, + -0.0002628941729199141, + -0.00024718698114156723, + 0.0005524033331312239, + -0.00043131023994646966, + 0.00025715501396916807, + 0.008090951479971409, + -0.0030689111445099115, + -0.0004238593974150717 + ], + [ + 0.000976699055172503, + 0.00039251212729141116, + 0.0017534669023007154, + 0.022595642134547234, + -0.000044805787183577195, + 0.00014220383309293538, + 0.009584981948137283, + -0.0003157213795930147, + 0.0015271222218871117, + 0.0011813960736617446, + -0.010774029418826103, + 0.00936581939458847 + ], + [ + 0.006314125377684832, + -0.0010949057759717107, + 0.011662023141980171, + 0.0013481340138241649, + -0.02918696030974388, + 0.0038333951961249113, + -0.04409456625580788, + -0.005032042507082224, + 0.00482167350128293, + 0.2766477167606354, + -0.00003164933150401339, + -0.0006618167390115559 + ], + [ + 0.0953889712691307, + 0.02506939135491848, + 0.014239178970456123, + 0.014754998497664928, + 0.00009890835644910112, + -0.00008977938705356792, + 0.05082912743091583, + -0.5051022171974182, + 0.00014696970174554735, + -0.0016026375815272331, + 0.06883199512958527, + 0.002327115274965763 + ], + [ + 0.0013425961369648576, + 0.009630928747355938, + -0.07776415348052979, + -0.007728713098913431, + -0.0005726079107262194, + -0.002957182005047798, + -0.0049475994892418385, + 0.00045916702947579324, + -0.0006328188464976847, + -0.006520198658108711, + -0.3204910457134247, + -0.002473111730068922 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 
0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + 
"#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + 
"#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + 
], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + 
"#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Logit Difference From Patched Head Pattern" + }, + 
"xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hovertemplate": "%{hovertext}

Attention Patch=%{x}
Output Patch=%{y}", + "hovertext": [ + "L0H0", + "L0H1", + "L0H2", + "L0H3", + "L0H4", + "L0H5", + "L0H6", + "L0H7", + "L0H8", + "L0H9", + "L0H10", + "L0H11", + "L1H0", + "L1H1", + "L1H2", + "L1H3", + "L1H4", + "L1H5", + "L1H6", + "L1H7", + "L1H8", + "L1H9", + "L1H10", + "L1H11", + "L2H0", + "L2H1", + "L2H2", + "L2H3", + "L2H4", + "L2H5", + "L2H6", + "L2H7", + "L2H8", + "L2H9", + "L2H10", + "L2H11", + "L3H0", + "L3H1", + "L3H2", + "L3H3", + "L3H4", + "L3H5", + "L3H6", + "L3H7", + "L3H8", + "L3H9", + "L3H10", + "L3H11", + "L4H0", + "L4H1", + "L4H2", + "L4H3", + "L4H4", + "L4H5", + "L4H6", + "L4H7", + "L4H8", + "L4H9", + "L4H10", + "L4H11", + "L5H0", + "L5H1", + "L5H2", + "L5H3", + "L5H4", + "L5H5", + "L5H6", + "L5H7", + "L5H8", + "L5H9", + "L5H10", + "L5H11", + "L6H0", + "L6H1", + "L6H2", + "L6H3", + "L6H4", + "L6H5", + "L6H6", + "L6H7", + "L6H8", + "L6H9", + "L6H10", + "L6H11", + "L7H0", + "L7H1", + "L7H2", + "L7H3", + "L7H4", + "L7H5", + "L7H6", + "L7H7", + "L7H8", + "L7H9", + "L7H10", + "L7H11", + "L8H0", + "L8H1", + "L8H2", + "L8H3", + "L8H4", + "L8H5", + "L8H6", + "L8H7", + "L8H8", + "L8H9", + "L8H10", + "L8H11", + "L9H0", + "L9H1", + "L9H2", + "L9H3", + "L9H4", + "L9H5", + "L9H6", + "L9H7", + "L9H8", + "L9H9", + "L9H10", + "L9H11", + "L10H0", + "L10H1", + "L10H2", + "L10H3", + "L10H4", + "L10H5", + "L10H6", + "L10H7", + "L10H8", + "L10H9", + "L10H10", + "L10H11", + "L11H0", + "L11H1", + "L11H2", + "L11H3", + "L11H4", + "L11H5", + "L11H6", + "L11H7", + "L11H8", + "L11H9", + "L11H10", + "L11H11" + ], + "legendgroup": "", + "marker": { + "color": "#636efa", + "symbol": "circle" + }, + "mode": "markers", + "name": "", + "orientation": "v", + "showlegend": false, + "type": "scatter", + "x": [ + 0.0006401354330591857, + 0.005318799521774054, + 0.0011584057938307524, + -0.00005920405237702653, + -0.00106671336106956, + 0.005079298280179501, + -0.0030818663071841, + -0.0020521720871329308, + -0.0014405983965843916, + 0.003492669900879264, + -0.002568227471783757, + 
-0.0009168237447738647, + -0.0007600873941555619, + 0.0001683824957581237, + 0.00012246915139257908, + -0.00034914951538667083, + 0.000014901700524205808, + 0.0050090523436665535, + -0.0002975976967718452, + -0.0014448943547904491, + -0.001099134678952396, + 0.00047447148244827986, + 0.00005195457561057992, + -0.0034954219590872526, + -0.0007243098807521164, + 0.0017458146903663874, + -0.00015556166181340814, + 0.000057626621128292754, + -0.000097398049547337, + -0.0004238593974150717, + -0.0007917031762190163, + 0.00027222454082220793, + 0.00010179472155869007, + 0.0004223826399538666, + 0.00015193692524917424, + -0.0007437760941684246, + 0.11458104848861694, + 0.00021140948229003698, + -0.0009424989693798125, + 0.000429833511589095, + 0.02004295401275158, + 0.002104730810970068, + 0.00007628730963915586, + -0.001543701975606382, + -0.0008484235731884837, + -0.0005819046637043357, + 0.00011921360419364646, + -0.00001899631206470076, + -0.001127125695347786, + 0.001237143180333078, + -0.0012324444251134992, + -0.0005952289211563766, + -0.0007541133090853691, + -0.0005842540413141251, + 0.004813014063984156, + 0.00018187458044849336, + -0.0005361591465771198, + 0.0008579217828810215, + -0.0002985374303534627, + -0.00001144477391790133, + -0.004241178277879953, + 0.0029509058222174644, + 0.0005218615406192839, + 0.0009535074350424111, + 0.0001622070267330855, + 0.34350839257240295, + -0.0003052163519896567, + 0.00010293584637111053, + -0.005300541408360004, + 0.024864863604307175, + 0.014383262023329735, + -0.0023285921197384596, + -0.0023893399629741907, + -0.002172795357182622, + -0.00047614958020858467, + 0.00043188079143874347, + -0.004675475414842367, + 0.0018583494238555431, + -0.0026542814448475838, + 0.0014367386465892196, + 0.00030326974228955805, + 0.13043038547039032, + 0.00008813483145786449, + 0.0011766973184421659, + 0.00031847349600866437, + 0.02057075686752796, + 0.00031840638257563114, + -0.002512782346457243, + -0.0002628941729199141, + 
-0.00024718698114156723, + 0.0005524033331312239, + -0.00043131023994646966, + 0.00025715501396916807, + 0.008090951479971409, + -0.0030689111445099115, + -0.0004238593974150717, + 0.000976699055172503, + 0.00039251212729141116, + 0.0017534669023007154, + 0.022595642134547234, + -0.000044805787183577195, + 0.00014220383309293538, + 0.009584981948137283, + -0.0003157213795930147, + 0.0015271222218871117, + 0.0011813960736617446, + -0.010774029418826103, + 0.00936581939458847, + 0.006314125377684832, + -0.0010949057759717107, + 0.011662023141980171, + 0.0013481340138241649, + -0.02918696030974388, + 0.0038333951961249113, + -0.04409456625580788, + -0.005032042507082224, + 0.00482167350128293, + 0.2766477167606354, + -0.00003164933150401339, + -0.0006618167390115559, + 0.0953889712691307, + 0.02506939135491848, + 0.014239178970456123, + 0.014754998497664928, + 0.00009890835644910112, + -0.00008977938705356792, + 0.05082912743091583, + -0.5051022171974182, + 0.00014696970174554735, + -0.0016026375815272331, + 0.06883199512958527, + 0.002327115274965763, + 0.0013425961369648576, + 0.009630928747355938, + -0.07776415348052979, + -0.007728713098913431, + -0.0005726079107262194, + -0.002957182005047798, + -0.0049475994892418385, + 0.00045916702947579324, + -0.0006328188464976847, + -0.006520198658108711, + -0.3204910457134247, + -0.002473111730068922 + ], + "xaxis": "x", + "y": [ + 0.0009487751522101462, + 0.016124747693538666, + 0.0018548924708738923, + 0.0034389030188322067, + -0.00982347596436739, + 0.011058605276048183, + -0.004063969012349844, + -0.0015792781487107277, + -0.0012082795146852732, + 0.003828897839412093, + -0.004256919026374817, + -0.0011422622483223677, + -0.0010771177476271987, + -0.00037898647133260965, + 0.0000025171791548928013, + -0.00026067905128002167, + -0.00014146546891424805, + 0.0038321535103023052, + -0.0004293300735298544, + -0.00142992555629462, + -0.0009228314156644046, + 0.0006944393389858305, + 0.00043302192352712154, + 
-0.0035714071709662676, + -0.0004967569257132709, + 0.0008057993836700916, + 0.0005424688570201397, + -0.0005309234256856143, + -0.0007159864180721343, + -0.0010389237431809306, + -0.0009490771917626262, + -0.00008649027586216107, + 0.0002766547549981624, + 0.0021084228064864874, + -0.0001975146442418918, + -0.0016405630158260465, + 0.1162627637386322, + 0.0002507446042727679, + -0.0014675153652206063, + -0.00039680811460129917, + 0.018962211906909943, + -0.00018764731066767126, + 0.011170871555805206, + -0.0013301445869728923, + -0.0007356539717875421, + -0.00030253134900704026, + -0.00014683544577565044, + -0.00022228369198273867, + -0.001650598249398172, + 0.0002927311579696834, + -0.00143563118763268, + 0.03084198758006096, + -0.007432155776768923, + -0.00028236035723239183, + 0.006017433945089579, + -0.011007187888026237, + -0.001266107545234263, + 0.0014901700196787715, + -0.0001800622121663764, + 0.002944394713267684, + -0.004211106337606907, + 0.0029597999528050423, + 0.002045023487880826, + 0.0013397098518908024, + -0.0012190865818411112, + 0.34349915385246277, + 0.0005632104002870619, + -0.0001262281439267099, + -0.00515326950699091, + 0.016240738332271576, + 0.01709030382335186, + -0.004175194539129734, + 0.039775289595127106, + 0.015226684510707855, + -0.0010229480685666203, + 0.0008072761120274663, + -0.004935584031045437, + -0.002123525831848383, + -0.014274083077907562, + 0.0013746818294748664, + 0.0014838266652077436, + 0.1302703619003296, + -0.00033616088330745697, + 0.0012919505825266242, + 0.00037177055492065847, + 0.019514480605721474, + 0.00022255218937061727, + 0.124249167740345, + -0.00040352059295400977, + -0.007652895525097847, + 0.0013010123511776328, + -0.0011253133416175842, + -0.007449474185705185, + 0.19224143028259277, + -0.003275118535384536, + -0.0005017912480980158, + -0.001007912098430097, + 0.00003091096004936844, + -0.0008595998515374959, + 0.012359987013041973, + -0.0004041247011628002, + -0.004328910261392593, + 
0.3185553252696991, + 0.002330605871975422, + 0.0021182901691645384, + 0.0001405928487656638, + 0.2779357433319092, + 0.005738262087106705, + 0.0058898297138512135, + -0.0009689796715974808, + 0.00912561360746622, + 0.020675739273428917, + -0.03700518235564232, + 0.014263041317462921, + -0.04828466475009918, + 0.05834139883518219, + 0.0006514795240946114, + 0.26360899209976196, + 0.0004918567719869316, + -0.00261044898070395, + 0.08374208211898804, + 0.020676210522651672, + -0.003743582172319293, + 0.01085072010755539, + -0.001096583902835846, + 0.00047430366976186633, + 0.04818058758974075, + -0.4799128472805023, + 0.00018429107149131596, + 0.011861988343298435, + 0.06088569387793541, + 0.0008461413672193885, + 0.005328264087438583, + -0.011493473313748837, + -0.11350836604833603, + 0.006329597905278206, + 0.00031669469899497926, + -0.0011600167490541935, + -0.022669579833745956, + 0.004070379305630922, + 0.0073160636238753796, + -0.00834545586258173, + -0.27817651629447937, + 0.0036344374530017376 + ], + "yaxis": "y" + } + ], + "layout": { + "legend": { + "tracegroupgap": 0 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + 
"outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ 
+ [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": 
[ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" 
+ ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + 
"color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Scatter plot of output patching vs attention patching" + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Attention Patch" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Output Patch" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow(\n", + " patched_head_attn_diff,\n", + " title=\"Logit Difference From Patched Head Pattern\",\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", + ")\n", + "head_labels = [\n", + " f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)\n", + "]\n", + "scatter(\n", + " x=utils.to_numpy(patched_head_attn_diff.flatten()),\n", + " y=utils.to_numpy(patched_head_z_diff.flatten()),\n", + " hover_name=head_labels,\n", + " xaxis=\"Attention Patch\",\n", + " yaxis=\"Output Patch\",\n", + " title=\"Scatter plot of output patching vs attention patching\",\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Consolidating Understanding\n", + "\n", + "OK, let's zoom out and reconsolidate. 
At a high-level, we find that all the action is on the second subject token until layer 7 and then transitions to the final token. And that attention layers matter a lot, MLP layers not so much (apart from MLP0, likely as an extended embedding).\n", + "\n", + "We've further localised important behaviour to several categories of heads. We've found 3 categories of heads that matter a lot - early heads (L5H5, L6H9, L3H0) whose output matters on the second subject and whose behaviour is determined by their attention patterns, mid-late heads (L8H6, L8H10, L7H9, L7H3) whose output matters on the final token and whose behaviour is determined by their value vectors, and late heads (L9H9, L10H7, L11H10) whose output matters on the final token and whose behaviour is determined by their attention patterns.\n", + "\n", + "A natural speculation is that early heads detect both that the second subject is a repeated token and *which* is repeated (ie the \" John\" token is repeated), middle heads compose with this and move this duplicated token information from the second subject token to the final token, and the late heads compose with this to *inhibit* their attention to the duplicated token, and then attend to the correct indirect object name and copy that directly to the logits." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualizing Attention Patterns\n", + "\n", + "We can validate this by looking at the attention patterns of these heads! Let's take the top 10 heads by output patching (in absolute value) and split it into early, middle and late.\n", + "\n", + "We see that middle heads attend from the final token to the second subject, and late heads attend from the final token to the indirect object, which is completely consistent with the above speculation! But weirdly, while *one* early head attends from the second subject to its first copy, the other two mysteriously attend to the word *after* the first copy." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

Top Early Heads


\n", + "

Top Middle Heads


\n", + "

Top Late Heads


\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "top_k = 10\n", + "top_heads_by_output_patch = torch.topk(\n", + " patched_head_z_diff.abs().flatten(), k=top_k\n", + ").indices\n", + "first_mid_layer = 7\n", + "first_late_layer = 9\n", + "early_heads = top_heads_by_output_patch[\n", + " top_heads_by_output_patch < model.cfg.n_heads * first_mid_layer\n", + "]\n", + "mid_heads = top_heads_by_output_patch[\n", + " torch.logical_and(\n", + " model.cfg.n_heads * first_mid_layer <= top_heads_by_output_patch,\n", + " top_heads_by_output_patch < model.cfg.n_heads * first_late_layer,\n", + " )\n", + "]\n", + "late_heads = top_heads_by_output_patch[\n", + " model.cfg.n_heads * first_late_layer <= top_heads_by_output_patch\n", + "]\n", + "\n", + "early = visualize_attention_patterns(\n", + " early_heads, cache, tokens[0], title=f\"Top Early Heads\"\n", + ")\n", + "mid = visualize_attention_patterns(\n", + " mid_heads, cache, tokens[0], title=f\"Top Middle Heads\"\n", + ")\n", + "late = visualize_attention_patterns(\n", + " late_heads, cache, tokens[0], title=f\"Top Late Heads\"\n", + ")\n", + "\n", + "HTML(early + mid + late)" ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Comparing to the Paper\n", + "\n", + "We can now refer to the (far, far more rigorous and detailed) analysis in the paper to compare our results! Here's the diagram they give of their results. 
\n", + "\n", + "![IOI1](https://pbs.twimg.com/media/FghGkTAWAAAmkhm.jpg)\n", + "\n", + "(Head 1.2 in their notation is L1H2 in my notation etc. And note - in the [latest version of the paper](https://arxiv.org/pdf/2211.00593.pdf) they add 9.0 as a backup name mover, and remove 11.3)\n", + "\n", + "The heads form three categories corresponding to the early, middle and late categories we found and we did fairly well! Definitely not perfect, but with some fairly generic techniques and some a priori reasoning, we found the broad strokes of the circuit and what it looks like. We focused on the most important heads, so we didn't find all relevant heads in each category (especially not the heads in brackets, which are more minor), but this serves as a good base for doing more rigorous and involved analysis, especially for finding the *complete* circuit (ie all of the parts of the model which participate in this behaviour) rather than just a partial and suggestive circuit. Go check out [their paper](https://arxiv.org/abs/2211.00593) or [our interview](https://www.youtube.com/watch?v=gzwj0jWbvbo) to learn more about what they did and what they found!\n", + "\n", + "Breaking down their categories:\n", + "\n", + "* Early: The duplicate token heads, previous token heads and induction heads. These serve the purpose of detecting that the second subject is duplicated and which earlier name is the duplicate.\n", + " * We found a direct duplicate token head which behaves exactly as expected, L3H0. Heads L5H5 and L6H9 are induction heads, which explains why they don't attend directly to the earlier copy of John!\n", + " * Note that the duplicate token heads and induction heads do not compose with each other - both directly add to the S-Inhibition heads.
The diagram is somewhat misleading.\n", + "* Middle: They call these S-Inhibition heads - they copy the information about the duplicate token from the second subject to the final \" to\" token, and their output is used to *inhibit* the attention paid from the name movers to the first subject copy. We found all these heads, and had a decent guess for what they did.\n", + " * In either case they attend to the second subject, so the patch that mattered was their value vectors!\n", + "* Late: They call these name movers, and we found some of them. They attend from the final token to the indirect object name and copy that to the logits, using the S-Inhibition heads to inhibit attention to the first copy of the subject token.\n", + " * We did find their surprising result of *negative* name movers - name movers that inhibit the correct answer!\n", + " * They have an entire category of heads we missed called backup name movers - we'll get to these later.\n", + "\n", + "So, now, let's dig into the two anomalies we missed - induction heads and backup name mover heads" ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Bonus: Exploring Anomalies" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode":
"closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Duplicate Token Scores" + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Early Heads are Induction Heads(?!)\n", + "\n", + "A really weird observation is that some of the early heads detecting duplicated tokens are induction heads, not just direct duplicate token heads. This is very weird! 
What's up with that? \n", + "\n", + "First off, what's an induction head? An induction head is an important type of attention head that can detect and continue repeated sequences. It is the second head in a two head induction circuit, which looks for previous copies of the current token and attends to the token *after* it, and then copies that to the current position and predicts that it will come next. They're enough of a big deal that [we wrote a whole paper on them](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html).\n", + "\n", + "![Move image demo](https://pbs.twimg.com/media/FNWAzXjVEAEOGRe.jpg)\n", + "\n", + "Second, why is it surprising that they come up here? It's surprising because it feels like overkill. The model doesn't care about *what* token comes after the first copy of the subject, just that it's duplicated. And it already has simpler duplicate token heads. My best guess is that it just already had induction heads around and that, in addition to their main function, they *also* only activate on duplicated tokens. So it was useful to repurpose this existing machinery. \n", + "\n", + "This suggests that as we look for circuits in larger models life may get more and more complicated, as components in simpler circuits get repurposed and built upon. " + ] }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can verify that these are induction heads by running the model on repeated text and plotting the heads." 
+ ] }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0.004035575315356255, - 3.85937346436549e-05, - 0.003946058917790651, - 1.7428524756724073e-07, - 5.9896130551351234e-05, - 4.0836803236743435e-05, - 0.0035017586778849363, - 0.00024610417312942445, - 0.0031679815147072077, - 0.0030104012694209814, - 0.002093541668727994, - 0.008525434881448746 - ], - [ - 0.000526473973877728, - 0.00015670718858018517, - 0.001507942914031446, - 0.005595325026661158, - 0.0018401180859655142, - 0.0038875630125403404, - 0.005349153187125921, - 0.004649169277399778, - 0.005880181211978197, - 0.007283917628228664, - 0.005552186165004969, - 0.00012677280756179243 - ], - [ - 0.0022015420254319906, - 0.008784863166511059, - 0.002159146359190345, - 0.0010447809472680092, - 0.005142326466739178, - 0.002251626690849662, - 0.0008376616751775146, - 0.006352409720420837, - 0.002618127502501011, - 0.0010309136705473065, - 0.00015219187480397522, - 0.005351166240870953 - ], - [ - 0.007752244360744953, - 0.0030915802344679832, - 0.001362923881970346, - 0.004341960418969393, - 0.011233060620725155, - 0.006535551976412535, - 0.000906877510715276, - 0.0006078600417822599, - 0.002819513902068138, - 0.005254077725112438, - 0.004195652436465025, - 0.00255418848246336 - ], - [ - 0.007342735771089792, - 0.004788339603692293, - 0.007458819076418877, - 0.0033073313534259796, - 0.007871866226196289, - 0.004219769034534693, - 0.004172054585069418, - 0.0005154653917998075, - 0.008124975487589836, - 0.0068268910981714725, - 0.008085492067039013, - 3.761376626831847e-11 - ], - [ - 0.4337766170501709, - 0.9306095838546753, - 0.006382268853485584, - 0.0034730439074337482, - 0.005500996019691229, - 0.9255973696708679, - 0.00538142304867506, - 0.007857315242290497, - 0.00863779615610838, - 0.01576443389058113, - 0.012188379652798176, - 0.008265726268291473 - ], - [ - 0.002507298020645976, - 0.008432027883827686, - 0.008623305708169937, - 0.007653353735804558, - 
0.01105806790292263, - 0.005525435321033001, - 0.017205175012350082, - 0.004794349893927574, - 0.0040976013988256454, - 0.9257788062095642, - 0.020375633612275124, - 0.006313954945653677 - ], - [ - 0.005555536597967148, - 0.18942977488040924, - 0.8509925007820129, - 0.008273146115243435, - 0.008239664137363434, - 0.00864996388554573, - 0.02832852303981781, - 0.08996275067329407, - 0.006617339327931404, - 0.009413909167051315, - 0.9037814736366272, - 0.03037159889936447 - ], - [ - 0.00735454261302948, - 0.3791317641735077, - 0.005602709017693996, - 0.025401461869478226, - 0.008504674769937992, - 0.00623108958825469, - 0.11892436444759369, - 0.005114651285111904, - 0.013350939378142357, - 0.01576736941933632, - 0.025843923911452293, - 0.008429747074842453 - ], - [ - 0.2398916333913803, - 0.14378757774829865, - 0.09330663084983826, - 0.005819779820740223, - 0.07744801044464111, - 0.01644793339073658, - 0.4442836344242096, - 0.011141352355480194, - 0.03619001433253288, - 0.472646564245224, - 0.00803996529430151, - 0.030953049659729004 - ], - [ - 0.3606555163860321, - 0.48201146721839905, - 0.022851115092635155, - 0.1264195442199707, - 0.04125598818063736, - 0.0072374604642391205, - 0.2877156138420105, - 0.3897320628166199, - 0.030060900375247, - 0.006112942937761545, - 0.1655488908290863, - 0.22245149314403534 - ], - [ - 0.007408542558550835, - 0.033737149089574814, - 0.02041277289390564, - 0.002755412133410573, - 0.02518630214035511, - 0.07808877527713776, - 0.033082809299230576, - 0.046440087258815765, - 0.0032543439883738756, - 0.2744256258010864, - 0.3800230026245117, - 0.009483495727181435 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - 
"rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "example_text = \"Research in mechanistic interpretability seeks to explain behaviors of machine learning models in terms of their internal components.\"\n", + "example_repeated_text = example_text + example_text\n", + "example_repeated_tokens = model.to_tokens(example_repeated_text, prepend_bos=True)\n", + "example_repeated_logits, example_repeated_cache = model.run_with_cache(\n", + " example_repeated_tokens\n", + ")\n", + "induction_head_labels = [81, 65]" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - 
], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - 
"histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - 
"colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

Induction Heads


\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "code = visualize_attention_patterns(\n", + " induction_head_labels,\n", + " example_repeated_cache,\n", + " example_repeated_tokens,\n", + " title=\"Induction Heads\",\n", + " max_width=800,\n", + ")\n", + "HTML(code)" ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Implications\n", + "\n", + "One implication of this is that it's useful to categorise heads according to whether they occur in\n", + "simpler circuits, so that as we look for more complex circuits we can easily look for them. This is\n", + "easy to do here! An interesting fact about induction heads is that they work on a sequence of\n", + "repeated random tokens - notable for being wildly off distribution from the natural language GPT-2\n", + "was trained on. Being able to predict a model's behaviour off distribution is a good mark of success\n", + "for mechanistic interpretability! This is a good sanity check for whether a head is an induction\n", + "head or not. \n", + "\n", + "We can characterise an induction head by just giving a sequence of random tokens repeated once, and\n", + "measuring the average attention paid from the second copy of a token to the token after the first\n", + "copy. 
At the same time, we can also measure the average attention paid from the second copy of a\n", + "token to the first copy of the token, which is the attention that the induction head would pay if it\n", + "were a duplicate token head, and the average attention paid to the previous token to find previous\n", + "token heads.\n", + "\n", + "Note that this is a superficial study of whether something is an induction head - we totally ignore\n", + "the question of whether it actually does boost the correct token or whether it composes with a\n", + "single previous head and how. In particular, we sometimes get anti-induction heads which suppress\n", + "the induction-y token (no clue why!), and this technique will find those too. But given the\n", + "previous rigorous analysis, we can be pretty confident that this picks up on some true signal about\n", + "induction heads." ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
Technical Implementation Details \n", + "We can do this again by using hooks, this time just to access the attention patterns rather than to intervene on them. \n", + "\n", + "Our hook function acts on the attention pattern activation. This has the name\n", + "\"blocks.{layer}.{layer_type}.hook_{activation_name}\" in general, here it's\n", + "\"blocks.{layer}.attn.hook_pattern\". And it has shape [batch, head_index, query_pos, token_pos]. Our\n", + "hook function takes in the attention pattern activation, calculates the score for the relevant type\n", + "of head, and writes it to an external cache.\n", + "\n", + "We add in hooks using `model.run_with_hooks(tokens, fwd_hooks=[(names_filter, hook_fn)])` to\n", + "temporarily add in the hooks and run the model, getting the resulting output. Previously\n", + "names_filter was the name of the activation, but here it's a boolean function mapping activation\n", + "names to whether we want to hook them or not. Here it's just whether the name ends with hook_pattern.\n", + "hook_fn must take in the two inputs activation (the activation tensor) and hook (the HookPoint\n", + "object, which contains the name of the activation and some metadata such as the current layer).\n", + "\n", + "Internally our hooks use the function `tensor.diagonal`, this takes the diagonal between two\n", + "dimensions, and allows an arbitrary offset - offset by 1 to get previous tokens, seq_len to get\n", + "duplicate tokens (the distance to earlier copies) and seq_len-1 to get induction heads (the distance\n", + "to the token *after* earlier copies). Different offsets give a different length of output tensor,\n", + "and we can now just average to get a score in [0, 1] for each head.\n", + "
" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - 
"zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Induction Head Scores" + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[0.0390, 0.0000, 0.0310],\n", + " [0.1890, 0.1720, 0.0680],\n", + " [0.1570, 0.0210, 0.4820]])\n", + "tensor([[0.0030, 0.1320, 0.0050],\n", + " [0.0000, 0.0000, 0.0020],\n", + " [0.0020, 0.0090, 0.0000]])\n", + "tensor([[0.0040, 0.0000, 0.0040],\n", + " [0.0010, 0.0000, 0.0020],\n", + " [0.0020, 0.0090, 0.0020]])\n" + ] + } + ], + "source": [ + "seq_len = 100\n", + "batch_size = 2\n", + "\n", + "prev_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)\n", + "\n", + "\n", + "def prev_token_hook(pattern, hook):\n", + " layer = hook.layer()\n", + " diagonal = pattern.diagonal(offset=1, dim1=-1, dim2=-2)\n", + " # print(diagonal)\n", + " # print(pattern)\n", + " prev_token_scores[layer] = einops.reduce(\n", + " diagonal, \"batch head_index diagonal -> head_index\", \"mean\"\n", + " )\n", + "\n", + "\n", + "duplicate_token_scores = torch.zeros(\n", + " (model.cfg.n_layers, model.cfg.n_heads), device=device\n", + ")\n", + "\n", + "\n", + "def duplicate_token_hook(pattern, hook):\n", + " layer = hook.layer()\n", + " diagonal = pattern.diagonal(offset=seq_len, dim1=-1, dim2=-2)\n", + " duplicate_token_scores[layer] = einops.reduce(\n", + " diagonal, \"batch head_index diagonal -> head_index\", \"mean\"\n", + " )\n", + "\n", + "\n", + "induction_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)\n", + "\n", + "\n", + "def induction_hook(pattern, hook):\n", + " layer = hook.layer()\n", + " diagonal = pattern.diagonal(offset=seq_len - 1, dim1=-1, dim2=-2)\n", + " induction_scores[layer] = einops.reduce(\n", + " diagonal, \"batch head_index diagonal -> head_index\", \"mean\"\n", + " )\n", + "\n", + "\n", + "torch.manual_seed(0)\n", + "original_tokens = 
torch.randint(\n", + " 100, 20000, size=(batch_size, seq_len), device=\"cpu\"\n", + ").to(device)\n", + "repeated_tokens = einops.repeat(\n", + " original_tokens, \"batch seq_len -> batch (2 seq_len)\"\n", + ").to(device)\n", + "\n", + "pattern_filter = lambda act_name: act_name.endswith(\"hook_pattern\")\n", + "\n", + "loss = model.run_with_hooks(\n", + " repeated_tokens,\n", + " return_type=\"loss\",\n", + " fwd_hooks=[\n", + " (pattern_filter, prev_token_hook),\n", + " (pattern_filter, duplicate_token_hook),\n", + " (pattern_filter, induction_hook),\n", + " ],\n", + ")\n", + "print(torch.round(utils.get_corner(prev_token_scores).detach().cpu(), decimals=3))\n", + "print(torch.round(utils.get_corner(duplicate_token_scores).detach().cpu(), decimals=3))\n", + "print(torch.round(utils.get_corner(induction_scores).detach().cpu(), decimals=3))" + ] }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now plot the head scores, and instantly see that the relevant early heads are induction heads or duplicate token heads (though also that there's a lot of induction heads that are *not* used - I have no idea why!). 
" + ] }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "imshow(\n", - " prev_token_scores, labels={\"x\": \"Head\", \"y\": \"Layer\"}, title=\"Previous Token Scores\"\n", - ")\n", - "imshow(\n", - " duplicate_token_scores,\n", - " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", - " title=\"Duplicate Token Scores\",\n", - ")\n", - "imshow(\n", - " induction_scores, labels={\"x\": \"Head\", \"y\": \"Layer\"}, title=\"Induction Head Scores\"\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The above suggests that it would be a useful bit of infrastructure to have a \"wiki\" for the heads of a model, giving their scores according to some metrics re head functions, like the ones we've seen here. TransformerLens makes this easy to make, as just changing the name input to `HookedTransformer.from_pretrained` gives a different model but in the same architecture, so the same code should work. If you want to make this, I'd love to see it! \n", - "\n", - "As a proof of concept, [I made a mosaic of all induction heads across the 40 models then in TransformerLens](https://www.neelnanda.io/mosaic).\n", - "\n", - "![induction scores as proof of concept](https://firebasestorage.googleapis.com/v0/b/firescript-577a2.appspot.com/o/imgs%2Fapp%2FNeelNanda%2F5vtuFmdzt_.png?alt=media&token=4d613de4-9d14-48d6-ba9d-e591c562d429)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Backup Name Mover Heads" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Another fascinating anomaly is that of the **backup name mover heads**. A standard technique to apply when interpreting model internals is ablations, or knock-out. 
If we run the model but intervene to set a specific head to zero, what happens? If the model is robust to this intervention, then naively we can be confident that the head is not doing anything important, and conversely if the model is much worse at the task this suggests that head was important. There are several conceptual flaws with this approach, making the evidence only suggestive, eg that the average output of the head may be far from zero and so the knockout may send it far from expected activations, breaking internals on *any* task. But it's still an easy technique to apply to give some data.\n", - "\n", - "But a wild finding in the paper is that models have **built in redundancy**. If we knock out one of the name movers, then there are some backup name movers in later layers that *change their behaviour* and do (some of) the job of the original name mover head. This means that naive knock-out will significantly underestimate the importance of the name movers.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's test this! Let's ablate the most important name mover (head L9H9) on just the final token using a custom ablation hook and then cache all new activations and compared performance. We focus on the final position because we want to specifically ablate the direct logit effect. When we do this, we see that naively, removing the top name mover should reduce the logit diff massively, from 3.55 to 0.57. **But actually, it only goes down to 2.99!**\n", - "\n", - "
Implementation Details \n", - "Ablating heads is really easy in TransformerLens! We can just define a hook on the z activation in the relevant attention layer (recall, z is the mixed values, and comes immediately before multiplying by the output weights $W_O$). z has a head_index axis, so we can set the component for the relevant head and for position -1 to zero, and return it. (Technically we could just edit in place without returning it, but by convention we always return an edited activation). \n", - "\n", - "We now want to compare all internal activations with a hook, which is hard to do with the nice `run_with_hooks` API. So we can directly access the hook on the z activation with `model.blocks[layer].attn.hook_z` and call its `add_hook` method. This adds in the hook to the *global state* of the model. We can now use run_with_cache, and don't need to care about the global state, because run_with_cache internally adds a bunch of caching hooks, and then removes all hooks after the run, *including* the previously added ablation hook. This can be disabled with the reset_hooks_end flag, but here it's useful! \n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Top Name Mover to ablate: L9H9\n", - "Original logit diff: 3.55\n", - "Post ablation logit diff: 2.92\n", - "Direct Logit Attribution of top name mover head: 2.99\n", - "Naive prediction of post ablation logit diff: 0.57\n" - ] - } - ], - "source": [ - "top_name_mover = per_head_logit_diffs.flatten().argmax().item()\n", - "top_name_mover_layer = top_name_mover // model.cfg.n_heads\n", - "top_name_mover_head = top_name_mover % model.cfg.n_heads\n", - "print(f\"Top Name Mover to ablate: L{top_name_mover_layer}H{top_name_mover_head}\")\n", - "\n", - "\n", - "def ablate_top_head_hook(z: Float[torch.Tensor, \"batch pos head_index d_head\"], hook):\n", - " z[:, -1, top_name_mover_head, :] = 0\n", - " return z\n", - "\n", - "\n", - "# Adds a hook into global model state\n", - "model.blocks[top_name_mover_layer].attn.hook_z.add_hook(ablate_top_head_hook)\n", - "# Runs the model, temporarily adds caching hooks and then removes *all* hooks after running, including the ablation hook.\n", - "ablated_logits, ablated_cache = model.run_with_cache(tokens)\n", - "print(f\"Original logit diff: {original_average_logit_diff:.2f}\")\n", - "print(\n", - " f\"Post ablation logit diff: {logits_to_ave_logit_diff(ablated_logits, answer_tokens).item():.2f}\"\n", - ")\n", - "print(\n", - " f\"Direct Logit Attribution of top name mover head: {per_head_logit_diffs.flatten()[top_name_mover].item():.2f}\"\n", - ")\n", - "print(\n", - " f\"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item():.2f}\"\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "So what's up with this? As before, we can look at the direct logit attribution of each head to see what's going on. 
It's easiest to interpret if plotted as a scatter plot against the initial per head logit difference.\n", - "\n", - "And we can see a *really* big difference in a few heads! (Hover to see labels) In particular the negative name mover L10H7 decreases its negative effect a lot, adding +1 to the logit diff, and the backup name mover L10H10 adjusts its effect to be more positive, adding +0.8 to the logit diff (with several other marginal changes). (And obviously the ablated head has gone down to zero!)" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tried to stack head results when they weren't cached. Computing head results now\n" - ] - }, - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - -0.002156503964215517, - -0.0004650682385545224, - 0.00024167183437384665, - 0.0002806585980579257, - -0.0004162999684922397, - -0.0004892416181974113, - -0.002620948012918234, - -0.002935677068307996, - 0.00042561208829283714, - 0.0005418329383246601, - 0.00023754138965159655, - -7.48957390896976e-05 - ], - [ - -0.000658505829051137, - 0.0004060641804244369, - -0.0009330413886345923, - 0.0008937822422012687, - -0.0009785268921405077, - -0.000533820129930973, - -0.0027988189831376076, - -0.004214101936668158, - 0.002578593324869871, - 0.0024506838526576757, - 0.0005351756699383259, - 0.0012349633034318686 - ], - [ - 0.0009405204327777028, - -0.0011168691562488675, - -0.0011541967978700995, - -0.0015697095077484846, - -0.0005699327448382974, - 0.001451514894142747, - 0.002439911477267742, - 0.003158293664455414, - 0.000923738582059741, - -0.003578126197680831, - -0.0010650777257978916, - -0.0003558753523975611 - ], - [ - -0.0005624951445497572, - -1.1960582924075425e-05, - 0.0011531109921634197, - 0.0007360265008173883, - 0.0016493839211761951, - 0.0008800819050520658, - -0.0006905529880896211, - -0.003031972097232938, - 0.0008080147090367973, - 0.00010368914809077978, - -0.0005807994166389108, - -0.0011067037703469396 - ], - [ - -0.0026375530287623405, - 0.0002691895351745188, - -0.0016417437000200152, - -0.003406986128538847, - 0.0017449699807912111, - 0.00046454701805487275, - -0.0007899806369096041, - 0.0018328562146052718, - -0.00086324627045542, - -0.0003978293389081955, - 0.0007879206677898765, - -0.00012048585631418973 - ], - [ - 0.0008688560919836164, - 0.0009473530226387084, - -0.0022812988609075546, - -0.0011803123634308577, - 0.0002407809515716508, - -0.0004318578285165131, - -0.0003728170122485608, - -0.000738416681997478, - 0.0008113418589346111, - -0.00040444196201860905, - -0.007074396125972271, - 0.003946478478610516 - ], - [ - -0.014917617663741112, - 
-0.0022801742888987064, - 0.0022679336834698915, - -8.302251808345318e-05, - -0.004980948753654957, - 0.0027670026756823063, - 0.006266288459300995, - -0.003485947148874402, - -0.0013348984066396952, - -0.0017918883822858334, - -0.0012231896398589015, - 0.00040514359716326 - ], - [ - -0.0002460568503011018, - -0.005790225230157375, - -0.0004975841729901731, - 0.142182856798172, - -0.0014961492270231247, - -0.019006317481398582, - 0.003133433870971203, - -0.001858205534517765, - -0.011305196210741997, - 0.1922595500946045, - -0.0011892566690221429, - -0.0010282933944836259 - ], - [ - -0.0038003993686288595, - -0.0008570950012654066, - -0.013956742361187935, - 0.00828910805284977, - 0.004315475933253765, - -0.009073829278349876, - -0.08315148949623108, - 0.0034569751005619764, - -0.01805492490530014, - 0.002178061753511429, - 0.29780513048171997, - 0.02409379370510578 - ], - [ - 0.08904723823070526, - -0.0007931794971227646, - 0.07247699797153473, - 0.015016308054327965, - -0.02120928093791008, - 0.05205465108156204, - 1.4411165714263916, - 0.04743674397468567, - -0.03229031339287758, - 0, - 0.0019993737805634737, - -0.00807223655283451 - ], - [ - 0.8600788116455078, - 0.3260062038898468, - 0.16344408690929413, - 0.07133537530899048, - -0.00444837287068367, - 0.000681330740917474, - 0.36613449454307556, - -0.7105098962783813, - -0.002031375654041767, - -0.032143525779247284, - 1.2294330596923828, - 0.0018453558441251516 - ], - [ - 0.016877274960279465, - -0.001730365096591413, - -0.5010868310928345, - 0.02749764919281006, - -0.0059662917628884315, - -0.004944110754877329, - -0.08855228126049042, - 0.006622308399528265, - 0.044124361127614975, - -0.02726735547184944, - -1.134916067123413, - 0.02287953346967697 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 
0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.039069853723049164, + 0.0004489101702347398, + 0.03133601322770119, + 0.007519590202718973, + 0.034592196345329285, + 0.00036230171099305153, + 0.034512776881456375, + 0.19740213453769684, + 0.038447845727205276, + 0.04053792357444763, + 0.027628764510154724, + 0.02496313862502575 + ], + [ + 0.1890650987625122, + 0.17219914495944977, + 0.06807752698659897, + 0.04494515433907509, + 0.07908554375171661, + 0.03096739575266838, + 0.028282109647989273, + 0.03644327446818352, + 0.026936717331409454, + 0.018826229497790337, + 0.045100897550582886, + 0.0065726665779948235 + ], + [ + 0.15745528042316437, + 0.020724520087242126, + 0.4817989468574524, + 0.2991352379322052, + 0.10764895379543304, + 0.33004048466682434, + 0.0997551754117012, + 0.04926132410764694, + 0.25493940711021423, + 0.3606453835964203, + 0.1257179230451584, + 0.07931824028491974 + ], + [ + 0.005844001192599535, + 0.15787364542484283, + 0.4189082086086273, + 0.30129021406173706, + 0.014345049858093262, + 0.032344333827495575, + 0.3312888443470001, + 0.5285974144935608, + 0.34242063760757446, + 0.101837158203125, + 0.10516070574522018, + 0.2233113795518875 + ], + [ + 0.10626544803380966, + 0.11930850893259048, + 0.022880680859088898, + 0.22826944291591644, + 0.020003994926810265, + 0.10010036826133728, + 0.1739213615655899, + 0.17407020926475525, + 0.02587701380252838, + 0.10249985754489899, + 0.009514841251075268, + 0.9921423196792603 + ], + [ + 0.019766658544540405, + 0.00528325280174613, + 0.16648508608341217, + 0.12087740004062653, + 0.16500000655651093, + 0.00803269725292921, + 0.41770195960998535, + 0.025827765464782715, + 0.04802601411938667, + 0.016231779009103775, + 0.03110172413289547, + 0.024261215701699257 + ], + [ + 0.2172909826040268, + 0.039100028574466705, + 0.01804858259856701, + 0.059900715947151184, + 0.032934583723545074, + 0.0873451679944992, + 0.026895340532064438, + 0.0943947583436966, + 
0.49925994873046875, + 0.006240115500986576, + 0.027026718482375145, + 0.1278565675020218 + ], + [ + 0.2511657178401947, + 0.01330868061631918, + 0.006663354113698006, + 0.037430502474308014, + 0.02331537753343582, + 0.01740722358226776, + 0.022067422047257423, + 0.022141192108392715, + 0.04502448812127113, + 0.0208425372838974, + 0.008310739882290363, + 0.017167754471302032 + ], + [ + 0.020890623331069946, + 0.016537941992282867, + 0.02158307284116745, + 0.0150058064609766, + 0.02421221323311329, + 0.10198988765478134, + 0.029100384563207626, + 0.22793792188167572, + 0.02781485579907894, + 0.0179410632699728, + 0.024828944355249405, + 0.03806235268712044 + ], + [ + 0.02607586607336998, + 0.015407431870698929, + 0.02044427953660488, + 0.14558182656764984, + 0.01247025839984417, + 0.017151640728116035, + 0.013311829417943954, + 0.024451706558465958, + 0.018111787736415863, + 0.01319331955164671, + 0.0357399508357048, + 0.01879822090268135 + ], + [ + 0.02147812582552433, + 0.018419174477458, + 0.018183622509241104, + 0.02172141708433628, + 0.0315677747130394, + 0.034705750644207, + 0.017550116404891014, + 0.011417553760111332, + 0.01579565554857254, + 0.04592214897274971, + 0.01621554046869278, + 0.03039470687508583 + ], + [ + 0.03320508822798729, + 0.0175714660435915, + 0.015131079591810703, + 0.04148406535387039, + 0.015181189402937889, + 0.01758997142314911, + 0.015148494392633438, + 0.01767607219517231, + 0.06622709333896637, + 0.018451133742928505, + 0.01700744964182377, + 0.029749270528554916 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, 
+ "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 
0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + 
"parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": 
{ + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": 
"closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Previous Token Scores" + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + 
}, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.0031923248898237944, + 0.13236315548419952, + 0.005006915424019098, + 0.000010427449524286203, + 0.0013110184809193015, + 0.7034568786621094, + 0.00426204688847065, + 0.00016496369789820164, + 0.002474633976817131, + 0.0008572910446673632, + 0.01889149099588394, + 0.008690938353538513 + ], + [ + 0.0002916341181844473, + 0.00013782267342321575, + 0.0015036173863336444, + 0.005392482969909906, + 0.0018583914497867227, + 0.009062949568033218, + 0.012414448894560337, + 0.0022405502386391163, + 0.005135662388056517, + 0.005220627877861261, + 0.005546474829316139, + 0.02975049614906311 + ], + [ + 0.0024816279765218496, + 0.009442180395126343, + 0.0003456332196947187, + 0.0002591445227153599, + 0.0052116685546934605, + 0.000570951378904283, + 0.0015209749108180404, + 0.006313100922852755, + 0.001560864970088005, + 0.0004215767839923501, + 0.00015359291865024716, + 0.005160381551831961 + ], + [ + 0.6775657534599304, + 0.002840448170900345, + 0.0007841526530683041, + 0.00471264636144042, + 0.006322895642369986, + 0.006206681486219168, + 0.0005474805948324502, + 0.00037829449865967035, + 0.0020155368838459253, + 0.007952751591801643, + 0.003576782764866948, + 0.002608788898214698 + ], + [ + 0.00860405620187521, + 0.0070286463014781475, + 0.007598803844302893, + 0.003442801535129547, + 0.016561277210712433, + 0.0059797209687530994, + 0.004869826138019562, + 0.0007624455611221492, + 0.006062133703380823, + 0.007536627352237701, + 0.012022900395095348, + 1.055422134237094e-12 + ], + [ + 0.00950299296528101, + 0.00856209360063076, + 0.004162600729614496, + 0.003008665982633829, + 0.006847422569990158, + 0.004358117934316397, + 0.007669268175959587, + 0.009584215469658375, + 0.0076188258826732635, + 0.0043280418030917645, + 0.041402824223041534, + 0.00976183544844389 + ], + [ + 0.004456141032278538, + 0.008873268961906433, + 0.007405205629765987, + 0.0062249391339719296, + 
0.00731915095821023, + 0.005623893812298775, + 0.017349667847156525, + 0.005529467947781086, + 0.002920132130384445, + 0.008636755868792534, + 0.006222263444215059, + 0.00835894700139761 + ], + [ + 0.003699858672916889, + 0.04107949137687683, + 0.04148268699645996, + 0.009313640184700489, + 0.009097025729715824, + 0.008774377405643463, + 0.007298537530004978, + 0.023312218487262726, + 0.008843323215842247, + 0.00987986009567976, + 0.017598601058125496, + 0.006039854139089584 + ], + [ + 0.008986304514110088, + 0.028667239472270012, + 0.008891218341886997, + 0.010114557109773159, + 0.009737391024827957, + 0.007611637003719807, + 0.009763265959918499, + 0.005155472084879875, + 0.009276345372200012, + 0.011895839124917984, + 0.010411946102976799, + 0.007498950231820345 + ], + [ + 0.024409977719187737, + 0.011438451707363129, + 0.02003096230328083, + 0.0051185814663767815, + 0.015081286430358887, + 0.012334450148046017, + 0.015452565625309944, + 0.008602450601756573, + 0.014702522195875645, + 0.020766200497746468, + 0.009192758239805698, + 0.005703347735106945 + ], + [ + 0.017897022888064384, + 0.013280633836984634, + 0.006755237001925707, + 0.012744844891130924, + 0.008020960725843906, + 0.007722244597971439, + 0.017341373488307, + 0.0074546560645103455, + 0.007832515984773636, + 0.00825214572250843, + 0.013642766512930393, + 0.012807483784854412 + ], + [ + 0.004923742264509201, + 0.007951060310006142, + 0.007947920821607113, + 0.004564082249999046, + 0.010363400913774967, + 0.009582078084349632, + 0.0102877551689744, + 0.00832072552293539, + 0.0025700009427964687, + 0.012810997664928436, + 0.008063871413469315, + 0.006558285094797611 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, 
+ "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 
0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": 
"histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 
0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": 
"#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Duplicate Token Scores" + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } + }, + "yaxis": { + 
"anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.004035575315356255, + 0.0000385937346436549, + 0.003946058917790651, + 1.7428524756724073e-7, + 0.000059896130551351234, + 0.000040836803236743435, + 0.0035017586778849363, + 0.00024610417312942445, + 0.0031679815147072077, + 0.0030104012694209814, + 0.002093541668727994, + 0.008525434881448746 + ], + [ + 0.000526473973877728, + 0.00015670718858018517, + 0.001507942914031446, + 0.005595325026661158, + 0.0018401180859655142, + 0.0038875630125403404, + 0.005349153187125921, + 0.004649169277399778, + 0.005880181211978197, + 0.007283917628228664, + 0.005552186165004969, + 0.00012677280756179243 + ], + [ + 0.0022015420254319906, + 0.008784863166511059, + 0.002159146359190345, + 0.0010447809472680092, + 0.005142326466739178, + 0.002251626690849662, + 0.0008376616751775146, + 0.006352409720420837, + 0.002618127502501011, + 0.0010309136705473065, + 0.00015219187480397522, + 0.005351166240870953 + ], + [ + 0.007752244360744953, + 0.0030915802344679832, + 0.001362923881970346, + 0.004341960418969393, + 0.011233060620725155, + 0.006535551976412535, + 0.000906877510715276, + 0.0006078600417822599, + 0.002819513902068138, + 0.005254077725112438, + 0.004195652436465025, + 0.00255418848246336 + ], + [ + 0.007342735771089792, + 0.004788339603692293, + 0.007458819076418877, + 0.0033073313534259796, + 0.007871866226196289, + 0.004219769034534693, + 0.004172054585069418, + 0.0005154653917998075, + 0.008124975487589836, + 0.0068268910981714725, + 0.008085492067039013, + 3.761376626831847e-11 + ], + [ + 0.4337766170501709, + 0.9306095838546753, + 0.006382268853485584, + 0.0034730439074337482, + 0.005500996019691229, + 0.9255973696708679, + 0.00538142304867506, + 0.007857315242290497, + 0.00863779615610838, + 0.01576443389058113, + 0.012188379652798176, + 0.008265726268291473 + ], + [ + 0.002507298020645976, + 0.008432027883827686, + 0.008623305708169937, + 0.007653353735804558, + 
0.01105806790292263, + 0.005525435321033001, + 0.017205175012350082, + 0.004794349893927574, + 0.0040976013988256454, + 0.9257788062095642, + 0.020375633612275124, + 0.006313954945653677 + ], + [ + 0.005555536597967148, + 0.18942977488040924, + 0.8509925007820129, + 0.008273146115243435, + 0.008239664137363434, + 0.00864996388554573, + 0.02832852303981781, + 0.08996275067329407, + 0.006617339327931404, + 0.009413909167051315, + 0.9037814736366272, + 0.03037159889936447 + ], + [ + 0.00735454261302948, + 0.3791317641735077, + 0.005602709017693996, + 0.025401461869478226, + 0.008504674769937992, + 0.00623108958825469, + 0.11892436444759369, + 0.005114651285111904, + 0.013350939378142357, + 0.01576736941933632, + 0.025843923911452293, + 0.008429747074842453 + ], + [ + 0.2398916333913803, + 0.14378757774829865, + 0.09330663084983826, + 0.005819779820740223, + 0.07744801044464111, + 0.01644793339073658, + 0.4442836344242096, + 0.011141352355480194, + 0.03619001433253288, + 0.472646564245224, + 0.00803996529430151, + 0.030953049659729004 + ], + [ + 0.3606555163860321, + 0.48201146721839905, + 0.022851115092635155, + 0.1264195442199707, + 0.04125598818063736, + 0.0072374604642391205, + 0.2877156138420105, + 0.3897320628166199, + 0.030060900375247, + 0.006112942937761545, + 0.1655488908290863, + 0.22245149314403534 + ], + [ + 0.007408542558550835, + 0.033737149089574814, + 0.02041277289390564, + 0.002755412133410573, + 0.02518630214035511, + 0.07808877527713776, + 0.033082809299230576, + 0.046440087258815765, + 0.0032543439883738756, + 0.2744256258010864, + 0.3800230026245117, + 0.009483495727181435 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + 
"rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 
0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + 
"colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + 
"#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + 
"subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Induction Head Scores" + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": 
"domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow(\n", + " prev_token_scores, labels={\"x\": \"Head\", \"y\": \"Layer\"}, title=\"Previous Token Scores\"\n", + ")\n", + "imshow(\n", + " duplicate_token_scores,\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", + " title=\"Duplicate Token Scores\",\n", + ")\n", + "imshow(\n", + " induction_scores, labels={\"x\": \"Head\", \"y\": \"Layer\"}, title=\"Induction Head Scores\"\n", + ")" + ] }, - "margin": { - "t": 60 + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The above suggests that it would be a useful bit of infrastructure to have a \"wiki\" for the heads of a model, giving their scores according to some metrics re head functions, like the ones we've seen here. TransformerLens makes this easy to make, as just changing the name input to `HookedTransformer.from_pretrained` gives a different model but in the same architecture, so the same code should work. If you want to make this, I'd love to see it! 
\n", + "\n", + "As a proof of concept, [I made a mosaic of all induction heads across the 40 models then in TransformerLens](https://www.neelnanda.io/mosaic).\n", + "\n", + "![induction scores as proof of concept](https://firebasestorage.googleapis.com/v0/b/firescript-577a2.appspot.com/o/imgs%2Fapp%2FNeelNanda%2F5vtuFmdzt_.png?alt=media&token=4d613de4-9d14-48d6-ba9d-e591c562d429)" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - 
"heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 
0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 
0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Backup Name Mover Heads" ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Another fascinating anomaly is that of the **backup name mover heads**. A standard technique to apply when interpreting model internals is ablations, or knock-out. 
If we run the model but intervene to set a specific head to zero, what happens? If the model is robust to this intervention, then naively we can be confident that the head is not doing anything important, and conversely if the model is much worse at the task this suggests that head was important. There are several conceptual flaws with this approach, making the evidence only suggestive, e.g. that the average output of the head may be far from zero and so the knockout may send it far from expected activations, breaking internals on *any* task. But it's still an easy technique to apply to give some data.\n", + "\n", + "But a wild finding in the paper is that models have **built-in redundancy**. If we knock out one of the name movers, then there are some backup name movers in later layers that *change their behaviour* and do (some of) the job of the original name mover head. This means that naive knock-out will significantly underestimate the importance of the name movers.\n" ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's test this! Let's ablate the most important name mover (head L9H9) on just the final token using a custom ablation hook and then cache all new activations and compare performance. We focus on the final position because we want to specifically ablate the direct logit effect. When we do this, we see that naively, removing the top name mover should reduce the logit diff massively, from 3.55 to 0.57. **But actually, it only goes down to 2.92!**\n", + "\n", + "
Implementation Details \n", + "Ablating heads is really easy in TransformerLens! We can just define a hook on the z activation in the relevant attention layer (recall, z is the mixed values, and comes immediately before multiplying by the output weights $W_O$). z has a head_index axis, so we can set the component for the relevant head and for position -1 to zero, and return it. (Technically we could just edit in place without returning it, but by convention we always return an edited activation). \n", + "\n", + "We now want to compare all internal activations with a hook, which is hard to do with the nice `run_with_hooks` API. So we can directly access the hook on the z activation with `model.blocks[layer].attn.hook_z` and call its `add_hook` method. This adds in the hook to the *global state* of the model. We can now use run_with_cache, and don't need to care about the global state, because run_with_cache internally adds a bunch of caching hooks, and then removes all hooks after the run, *including* the previously added ablation hook. This can be disabled with the reset_hooks_end flag, but here it's useful! \n", + "
" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - 
"zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top Name Mover to ablate: L9H9\n", + "Original logit diff: 3.55\n", + "Post ablation logit diff: 2.92\n", + "Direct Logit Attribution of top name mover head: 2.99\n", + "Naive prediction of post ablation logit diff: 0.57\n" + ] + } + ], + "source": [ + "top_name_mover = per_head_logit_diffs.flatten().argmax().item()\n", + "top_name_mover_layer = top_name_mover // model.cfg.n_heads\n", + "top_name_mover_head = top_name_mover % model.cfg.n_heads\n", + "print(f\"Top Name Mover to ablate: L{top_name_mover_layer}H{top_name_mover_head}\")\n", + "\n", + "\n", + "def ablate_top_head_hook(z: Float[torch.Tensor, \"batch pos head_index d_head\"], hook):\n", + " z[:, -1, top_name_mover_head, :] = 0\n", + " return z\n", + "\n", + "\n", + "# Adds a hook into global model state\n", + "model.blocks[top_name_mover_layer].attn.hook_z.add_hook(ablate_top_head_hook)\n", + "# Runs the model, temporarily adds caching hooks and then removes *all* hooks after running, including the ablation hook.\n", + "ablated_logits, ablated_cache = model.run_with_cache(tokens)\n", + "print(f\"Original logit diff: {original_average_logit_diff:.2f}\")\n", + "print(\n", + " f\"Post ablation logit diff: {logits_to_ave_logit_diff(ablated_logits, answer_tokens).item():.2f}\"\n", + ")\n", + "print(\n", + " f\"Direct Logit Attribution of top name mover head: {per_head_logit_diffs.flatten()[top_name_mover].item():.2f}\"\n", + ")\n", + "print(\n", + " f\"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item():.2f}\"\n", + ")" + ] }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - 
"constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ { - "hovertemplate": "%{hovertext}

Ablated=%{x}
Original=%{y}", - "hovertext": [ - "L0H0", - "L0H1", - "L0H2", - "L0H3", - "L0H4", - "L0H5", - "L0H6", - "L0H7", - "L0H8", - "L0H9", - "L0H10", - "L0H11", - "L1H0", - "L1H1", - "L1H2", - "L1H3", - "L1H4", - "L1H5", - "L1H6", - "L1H7", - "L1H8", - "L1H9", - "L1H10", - "L1H11", - "L2H0", - "L2H1", - "L2H2", - "L2H3", - "L2H4", - "L2H5", - "L2H6", - "L2H7", - "L2H8", - "L2H9", - "L2H10", - "L2H11", - "L3H0", - "L3H1", - "L3H2", - "L3H3", - "L3H4", - "L3H5", - "L3H6", - "L3H7", - "L3H8", - "L3H9", - "L3H10", - "L3H11", - "L4H0", - "L4H1", - "L4H2", - "L4H3", - "L4H4", - "L4H5", - "L4H6", - "L4H7", - "L4H8", - "L4H9", - "L4H10", - "L4H11", - "L5H0", - "L5H1", - "L5H2", - "L5H3", - "L5H4", - "L5H5", - "L5H6", - "L5H7", - "L5H8", - "L5H9", - "L5H10", - "L5H11", - "L6H0", - "L6H1", - "L6H2", - "L6H3", - "L6H4", - "L6H5", - "L6H6", - "L6H7", - "L6H8", - "L6H9", - "L6H10", - "L6H11", - "L7H0", - "L7H1", - "L7H2", - "L7H3", - "L7H4", - "L7H5", - "L7H6", - "L7H7", - "L7H8", - "L7H9", - "L7H10", - "L7H11", - "L8H0", - "L8H1", - "L8H2", - "L8H3", - "L8H4", - "L8H5", - "L8H6", - "L8H7", - "L8H8", - "L8H9", - "L8H10", - "L8H11", - "L9H0", - "L9H1", - "L9H2", - "L9H3", - "L9H4", - "L9H5", - "L9H6", - "L9H7", - "L9H8", - "L9H9", - "L9H10", - "L9H11", - "L10H0", - "L10H1", - "L10H2", - "L10H3", - "L10H4", - "L10H5", - "L10H6", - "L10H7", - "L10H8", - "L10H9", - "L10H10", - "L10H11", - "L11H0", - "L11H1", - "L11H2", - "L11H3", - "L11H4", - "L11H5", - "L11H6", - "L11H7", - "L11H8", - "L11H9", - "L11H10", - "L11H11" - ], - "legendgroup": "", - "marker": { - "color": "#636efa", - "symbol": "circle" - }, - "mode": "markers", - "name": "", - "orientation": "v", - "showlegend": false, - "type": "scatter", - "x": [ - -0.002156503964215517, - -0.0004650682385545224, - 0.00024167183437384665, - 0.0002806585980579257, - -0.0004162999684922397, - -0.0004892416181974113, - -0.002620948012918234, - -0.002935677068307996, - 0.00042561208829283714, - 0.0005418329383246601, - 0.00023754138965159655, 
- -7.48957390896976e-05, - -0.000658505829051137, - 0.0004060641804244369, - -0.0009330413886345923, - 0.0008937822422012687, - -0.0009785268921405077, - -0.000533820129930973, - -0.0027988189831376076, - -0.004214101936668158, - 0.002578593324869871, - 0.0024506838526576757, - 0.0005351756699383259, - 0.0012349633034318686, - 0.0009405204327777028, - -0.0011168691562488675, - -0.0011541967978700995, - -0.0015697095077484846, - -0.0005699327448382974, - 0.001451514894142747, - 0.002439911477267742, - 0.003158293664455414, - 0.000923738582059741, - -0.003578126197680831, - -0.0010650777257978916, - -0.0003558753523975611, - -0.0005624951445497572, - -1.1960582924075425e-05, - 0.0011531109921634197, - 0.0007360265008173883, - 0.0016493839211761951, - 0.0008800819050520658, - -0.0006905529880896211, - -0.003031972097232938, - 0.0008080147090367973, - 0.00010368914809077978, - -0.0005807994166389108, - -0.0011067037703469396, - -0.0026375530287623405, - 0.0002691895351745188, - -0.0016417437000200152, - -0.003406986128538847, - 0.0017449699807912111, - 0.00046454701805487275, - -0.0007899806369096041, - 0.0018328562146052718, - -0.00086324627045542, - -0.0003978293389081955, - 0.0007879206677898765, - -0.00012048585631418973, - 0.0008688560919836164, - 0.0009473530226387084, - -0.0022812988609075546, - -0.0011803123634308577, - 0.0002407809515716508, - -0.0004318578285165131, - -0.0003728170122485608, - -0.000738416681997478, - 0.0008113418589346111, - -0.00040444196201860905, - -0.007074396125972271, - 0.003946478478610516, - -0.014917617663741112, - -0.0022801742888987064, - 0.0022679336834698915, - -8.302251808345318e-05, - -0.004980948753654957, - 0.0027670026756823063, - 0.006266288459300995, - -0.003485947148874402, - -0.0013348984066396952, - -0.0017918883822858334, - -0.0012231896398589015, - 0.00040514359716326, - -0.0002460568503011018, - -0.005790225230157375, - -0.0004975841729901731, - 0.142182856798172, - -0.0014961492270231247, - -0.019006317481398582, - 
0.003133433870971203, - -0.001858205534517765, - -0.011305196210741997, - 0.1922595500946045, - -0.0011892566690221429, - -0.0010282933944836259, - -0.0038003993686288595, - -0.0008570950012654066, - -0.013956742361187935, - 0.00828910805284977, - 0.004315475933253765, - -0.009073829278349876, - -0.08315148949623108, - 0.0034569751005619764, - -0.01805492490530014, - 0.002178061753511429, - 0.29780513048171997, - 0.02409379370510578, - 0.08904723823070526, - -0.0007931794971227646, - 0.07247699797153473, - 0.015016308054327965, - -0.02120928093791008, - 0.05205465108156204, - 1.4411165714263916, - 0.04743674397468567, - -0.03229031339287758, - 0, - 0.0019993737805634737, - -0.00807223655283451, - 0.8600788116455078, - 0.3260062038898468, - 0.16344408690929413, - 0.07133537530899048, - -0.00444837287068367, - 0.000681330740917474, - 0.36613449454307556, - -0.7105098962783813, - -0.002031375654041767, - -0.032143525779247284, - 1.2294330596923828, - 0.0018453558441251516, - 0.016877274960279465, - -0.001730365096591413, - -0.5010868310928345, - 0.02749764919281006, - -0.0059662917628884315, - -0.004944110754877329, - -0.08855228126049042, - 0.006622308399528265, - 0.044124361127614975, - -0.02726735547184944, - -1.134916067123413, - 0.02287953346967697 - ], - "xaxis": "x", - "y": [ - -0.0020563392899930477, - -0.0005101899732835591, - 0.0004685786843765527, - 0.00012512074317783117, - -0.0006028738571330905, - -0.0002429460291750729, - -0.0023189077619463205, - -0.002758360467851162, - 0.000564602785743773, - 0.0009697531932033598, - -0.0002504526637494564, - 4.737317794933915e-06, - -0.0010070882271975279, - 0.00039470894262194633, - -0.00154874159488827, - 0.0014034928753972054, - -0.0012653048615902662, - -0.0011358022456988692, - -0.00281596090644598, - -0.0029645217582583427, - 0.0029190476052463055, - 0.0025743592996150255, - 0.00036239007022231817, - 0.0017548729665577412, - 0.0005569400964304805, - -0.001126631861552596, - -0.0017353934235870838, - 
-0.0014514457434415817, - -0.00028735760133713484, - 0.0017211002996191382, - 0.0026658899150788784, - 0.00311466702260077, - 0.0005667927907779813, - -0.003666515462100506, - -0.0018847601022571325, - 7.039372576400638e-06, - -0.0007264417363330722, - 0.00011364505917299539, - 0.0014301587361842394, - 0.0007490540738217533, - 0.0020184689201414585, - 0.0007436950691044331, - -0.00046178390039131045, - -0.0039057559333741665, - 0.0011406694538891315, - -4.022853681817651e-05, - -0.0013293239753693342, - -0.0017636751290410757, - -0.0028280913829803467, - 0.00033634810824878514, - -0.0014248639345169067, - -0.003777273464947939, - 0.0015998880844563246, - 0.0002989505883306265, - -0.000804675742983818, - 0.002038792008534074, - -0.0015593919670209289, - -0.0006436670082621276, - 0.0011168173514306545, - -0.00035012533771805465, - 0.0011338205076754093, - 0.0011259170714765787, - -0.002516670385375619, - -0.0014790185960009694, - 0.0003878737334161997, - -6.408110493794084e-05, - -0.0005096744280308485, - -0.0008840755908749998, - 0.0006398351397365332, - -0.0010097370250150561, - -0.006759158335626125, - 0.0033667823299765587, - -0.01514742337167263, - -0.0021350777242332697, - 0.002593174111098051, - -0.00042678468162193894, - -0.005558924749493599, - 0.0026658528950065374, - 0.006411008536815643, - -0.003826778382062912, - -0.0003843410813715309, - -0.0016430341638624668, - -0.0013344454346224666, - -9.20506427064538e-05, - -9.476230479776859e-05, - -0.0057889921590685844, - -0.0006383581785485148, - 0.13493388891220093, - -0.001768707763403654, - -0.018917907029390335, - 0.003873429261147976, - -0.0021450775675475597, - -0.010327338241040707, - 0.18325845897197723, - -0.0007747983909212053, - -0.00104526337236166, - -0.003833949100226164, - -0.0008046097937040031, - -0.012673400342464447, - 0.00804573018103838, - 0.003604492638260126, - -0.009398287162184715, - -0.08272082358598709, - 0.003555194940418005, - -0.018404025584459305, - 0.0017587244510650635, - 
0.2896133363246918, - 0.022854052484035492, - 0.08595258742570877, - -0.0006932877004146576, - 0.06817055493593216, - 0.013111240230500698, - -0.021098043769598007, - 0.05112447217106819, - 1.3844914436340332, - 0.045836858451366425, - -0.03830280900001526, - 2.985445976257324, - 0.0019662054255604744, - -0.008030137047171593, - 0.5608693957328796, - 0.17083050310611725, - -0.03361757844686508, - 0.05821544677019119, - -0.0024530249647796154, - 0.0018771197646856308, - 0.28827205300331116, - -1.8986485004425049, - -0.0015286931302398443, - -0.035129792988300323, - 0.4802178740501404, - -0.0009115453576669097, - 0.016075748950242996, - -0.03986122086644173, - -0.3879126012325287, - 0.011123123578727245, - -0.005477819126099348, - -0.0025129620917141438, - -0.08056175708770752, - 0.007518616039305925, - 0.0430111438035965, - -0.040082238614559174, - -0.9702364802360535, - 0.011862239800393581 - ], - "yaxis": "y" - } - ], - "layout": { - "legend": { - "tracegroupgap": 0 + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So what's up with this? As before, we can look at the direct logit attribution of each head to see what's going on. It's easiest to interpret if plotted as a scatter plot against the initial per head logit difference.\n", + "\n", + "And we can see a *really* big difference in a few heads! (Hover to see labels) In particular the negative name mover L10H7 decreases its negative effect a lot, adding +1 to the logit diff, and the backup name mover L10H10 adjusts its effect to be more positive, adding +0.8 to the logit diff (with several other marginal changes). 
(And obviously the ablated head has gone down to zero!)" + ] }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - 
[ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": 
"" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - 
"table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tried to stack head results when they weren't cached. Computing head results now\n" + ] + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.002156503964215517, + -0.0004650682385545224, + 0.00024167183437384665, + 0.0002806585980579257, + -0.0004162999684922397, + -0.0004892416181974113, + -0.002620948012918234, + -0.002935677068307996, + 0.00042561208829283714, + 0.0005418329383246601, + 0.00023754138965159655, + -0.0000748957390896976 + ], + [ + -0.000658505829051137, + 0.0004060641804244369, + -0.0009330413886345923, + 0.0008937822422012687, + -0.0009785268921405077, + -0.000533820129930973, + -0.0027988189831376076, + -0.004214101936668158, + 0.002578593324869871, + 0.0024506838526576757, + 0.0005351756699383259, + 0.0012349633034318686 + ], + [ + 0.0009405204327777028, + -0.0011168691562488675, + -0.0011541967978700995, + -0.0015697095077484846, + -0.0005699327448382974, + 0.001451514894142747, + 0.002439911477267742, + 0.003158293664455414, + 0.000923738582059741, + -0.003578126197680831, + -0.0010650777257978916, + -0.0003558753523975611 + ], + [ + -0.0005624951445497572, + -0.000011960582924075425, + 0.0011531109921634197, + 0.0007360265008173883, + 0.0016493839211761951, + 0.0008800819050520658, + -0.0006905529880896211, + -0.003031972097232938, + 0.0008080147090367973, + 0.00010368914809077978, + -0.0005807994166389108, + -0.0011067037703469396 + ], + [ + -0.0026375530287623405, + 0.0002691895351745188, + -0.0016417437000200152, + -0.003406986128538847, + 0.0017449699807912111, + 0.00046454701805487275, + -0.0007899806369096041, + 0.0018328562146052718, + -0.00086324627045542, + -0.0003978293389081955, + 0.0007879206677898765, + -0.00012048585631418973 + ], + [ + 0.0008688560919836164, + 0.0009473530226387084, + -0.0022812988609075546, + -0.0011803123634308577, + 0.0002407809515716508, + -0.0004318578285165131, + -0.0003728170122485608, + -0.000738416681997478, + 0.0008113418589346111, + -0.00040444196201860905, + -0.007074396125972271, + 0.003946478478610516 + ], + [ + -0.014917617663741112, + 
-0.0022801742888987064, + 0.0022679336834698915, + -0.00008302251808345318, + -0.004980948753654957, + 0.0027670026756823063, + 0.006266288459300995, + -0.003485947148874402, + -0.0013348984066396952, + -0.0017918883822858334, + -0.0012231896398589015, + 0.00040514359716326 + ], + [ + -0.0002460568503011018, + -0.005790225230157375, + -0.0004975841729901731, + 0.142182856798172, + -0.0014961492270231247, + -0.019006317481398582, + 0.003133433870971203, + -0.001858205534517765, + -0.011305196210741997, + 0.1922595500946045, + -0.0011892566690221429, + -0.0010282933944836259 + ], + [ + -0.0038003993686288595, + -0.0008570950012654066, + -0.013956742361187935, + 0.00828910805284977, + 0.004315475933253765, + -0.009073829278349876, + -0.08315148949623108, + 0.0034569751005619764, + -0.01805492490530014, + 0.002178061753511429, + 0.29780513048171997, + 0.02409379370510578 + ], + [ + 0.08904723823070526, + -0.0007931794971227646, + 0.07247699797153473, + 0.015016308054327965, + -0.02120928093791008, + 0.05205465108156204, + 1.4411165714263916, + 0.04743674397468567, + -0.03229031339287758, + 0, + 0.0019993737805634737, + -0.00807223655283451 + ], + [ + 0.8600788116455078, + 0.3260062038898468, + 0.16344408690929413, + 0.07133537530899048, + -0.00444837287068367, + 0.000681330740917474, + 0.36613449454307556, + -0.7105098962783813, + -0.002031375654041767, + -0.032143525779247284, + 1.2294330596923828, + 0.0018453558441251516 + ], + [ + 0.016877274960279465, + -0.001730365096591413, + -0.5010868310928345, + 0.02749764919281006, + -0.0059662917628884315, + -0.004944110754877329, + -0.08855228126049042, + 0.006622308399528265, + 0.044124361127614975, + -0.02726735547184944, + -1.134916067123413, + 0.02287953346967697 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 
0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "t": 60 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + 
"#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + 
"#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + 
], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": 
"#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": 
"Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hovertemplate": "%{hovertext}

Ablated=%{x}
Original=%{y}", + "hovertext": [ + "L0H0", + "L0H1", + "L0H2", + "L0H3", + "L0H4", + "L0H5", + "L0H6", + "L0H7", + "L0H8", + "L0H9", + "L0H10", + "L0H11", + "L1H0", + "L1H1", + "L1H2", + "L1H3", + "L1H4", + "L1H5", + "L1H6", + "L1H7", + "L1H8", + "L1H9", + "L1H10", + "L1H11", + "L2H0", + "L2H1", + "L2H2", + "L2H3", + "L2H4", + "L2H5", + "L2H6", + "L2H7", + "L2H8", + "L2H9", + "L2H10", + "L2H11", + "L3H0", + "L3H1", + "L3H2", + "L3H3", + "L3H4", + "L3H5", + "L3H6", + "L3H7", + "L3H8", + "L3H9", + "L3H10", + "L3H11", + "L4H0", + "L4H1", + "L4H2", + "L4H3", + "L4H4", + "L4H5", + "L4H6", + "L4H7", + "L4H8", + "L4H9", + "L4H10", + "L4H11", + "L5H0", + "L5H1", + "L5H2", + "L5H3", + "L5H4", + "L5H5", + "L5H6", + "L5H7", + "L5H8", + "L5H9", + "L5H10", + "L5H11", + "L6H0", + "L6H1", + "L6H2", + "L6H3", + "L6H4", + "L6H5", + "L6H6", + "L6H7", + "L6H8", + "L6H9", + "L6H10", + "L6H11", + "L7H0", + "L7H1", + "L7H2", + "L7H3", + "L7H4", + "L7H5", + "L7H6", + "L7H7", + "L7H8", + "L7H9", + "L7H10", + "L7H11", + "L8H0", + "L8H1", + "L8H2", + "L8H3", + "L8H4", + "L8H5", + "L8H6", + "L8H7", + "L8H8", + "L8H9", + "L8H10", + "L8H11", + "L9H0", + "L9H1", + "L9H2", + "L9H3", + "L9H4", + "L9H5", + "L9H6", + "L9H7", + "L9H8", + "L9H9", + "L9H10", + "L9H11", + "L10H0", + "L10H1", + "L10H2", + "L10H3", + "L10H4", + "L10H5", + "L10H6", + "L10H7", + "L10H8", + "L10H9", + "L10H10", + "L10H11", + "L11H0", + "L11H1", + "L11H2", + "L11H3", + "L11H4", + "L11H5", + "L11H6", + "L11H7", + "L11H8", + "L11H9", + "L11H10", + "L11H11" + ], + "legendgroup": "", + "marker": { + "color": "#636efa", + "symbol": "circle" + }, + "mode": "markers", + "name": "", + "orientation": "v", + "showlegend": false, + "type": "scatter", + "x": [ + -0.002156503964215517, + -0.0004650682385545224, + 0.00024167183437384665, + 0.0002806585980579257, + -0.0004162999684922397, + -0.0004892416181974113, + -0.002620948012918234, + -0.002935677068307996, + 0.00042561208829283714, + 0.0005418329383246601, + 0.00023754138965159655, 
+ -0.0000748957390896976, + -0.000658505829051137, + 0.0004060641804244369, + -0.0009330413886345923, + 0.0008937822422012687, + -0.0009785268921405077, + -0.000533820129930973, + -0.0027988189831376076, + -0.004214101936668158, + 0.002578593324869871, + 0.0024506838526576757, + 0.0005351756699383259, + 0.0012349633034318686, + 0.0009405204327777028, + -0.0011168691562488675, + -0.0011541967978700995, + -0.0015697095077484846, + -0.0005699327448382974, + 0.001451514894142747, + 0.002439911477267742, + 0.003158293664455414, + 0.000923738582059741, + -0.003578126197680831, + -0.0010650777257978916, + -0.0003558753523975611, + -0.0005624951445497572, + -0.000011960582924075425, + 0.0011531109921634197, + 0.0007360265008173883, + 0.0016493839211761951, + 0.0008800819050520658, + -0.0006905529880896211, + -0.003031972097232938, + 0.0008080147090367973, + 0.00010368914809077978, + -0.0005807994166389108, + -0.0011067037703469396, + -0.0026375530287623405, + 0.0002691895351745188, + -0.0016417437000200152, + -0.003406986128538847, + 0.0017449699807912111, + 0.00046454701805487275, + -0.0007899806369096041, + 0.0018328562146052718, + -0.00086324627045542, + -0.0003978293389081955, + 0.0007879206677898765, + -0.00012048585631418973, + 0.0008688560919836164, + 0.0009473530226387084, + -0.0022812988609075546, + -0.0011803123634308577, + 0.0002407809515716508, + -0.0004318578285165131, + -0.0003728170122485608, + -0.000738416681997478, + 0.0008113418589346111, + -0.00040444196201860905, + -0.007074396125972271, + 0.003946478478610516, + -0.014917617663741112, + -0.0022801742888987064, + 0.0022679336834698915, + -0.00008302251808345318, + -0.004980948753654957, + 0.0027670026756823063, + 0.006266288459300995, + -0.003485947148874402, + -0.0013348984066396952, + -0.0017918883822858334, + -0.0012231896398589015, + 0.00040514359716326, + -0.0002460568503011018, + -0.005790225230157375, + -0.0004975841729901731, + 0.142182856798172, + -0.0014961492270231247, + 
-0.019006317481398582, + 0.003133433870971203, + -0.001858205534517765, + -0.011305196210741997, + 0.1922595500946045, + -0.0011892566690221429, + -0.0010282933944836259, + -0.0038003993686288595, + -0.0008570950012654066, + -0.013956742361187935, + 0.00828910805284977, + 0.004315475933253765, + -0.009073829278349876, + -0.08315148949623108, + 0.0034569751005619764, + -0.01805492490530014, + 0.002178061753511429, + 0.29780513048171997, + 0.02409379370510578, + 0.08904723823070526, + -0.0007931794971227646, + 0.07247699797153473, + 0.015016308054327965, + -0.02120928093791008, + 0.05205465108156204, + 1.4411165714263916, + 0.04743674397468567, + -0.03229031339287758, + 0, + 0.0019993737805634737, + -0.00807223655283451, + 0.8600788116455078, + 0.3260062038898468, + 0.16344408690929413, + 0.07133537530899048, + -0.00444837287068367, + 0.000681330740917474, + 0.36613449454307556, + -0.7105098962783813, + -0.002031375654041767, + -0.032143525779247284, + 1.2294330596923828, + 0.0018453558441251516, + 0.016877274960279465, + -0.001730365096591413, + -0.5010868310928345, + 0.02749764919281006, + -0.0059662917628884315, + -0.004944110754877329, + -0.08855228126049042, + 0.006622308399528265, + 0.044124361127614975, + -0.02726735547184944, + -1.134916067123413, + 0.02287953346967697 + ], + "xaxis": "x", + "y": [ + -0.0020563392899930477, + -0.0005101899732835591, + 0.0004685786843765527, + 0.00012512074317783117, + -0.0006028738571330905, + -0.0002429460291750729, + -0.0023189077619463205, + -0.002758360467851162, + 0.000564602785743773, + 0.0009697531932033598, + -0.0002504526637494564, + 0.000004737317794933915, + -0.0010070882271975279, + 0.00039470894262194633, + -0.00154874159488827, + 0.0014034928753972054, + -0.0012653048615902662, + -0.0011358022456988692, + -0.00281596090644598, + -0.0029645217582583427, + 0.0029190476052463055, + 0.0025743592996150255, + 0.00036239007022231817, + 0.0017548729665577412, + 0.0005569400964304805, + -0.001126631861552596, + 
-0.0017353934235870838, + -0.0014514457434415817, + -0.00028735760133713484, + 0.0017211002996191382, + 0.0026658899150788784, + 0.00311466702260077, + 0.0005667927907779813, + -0.003666515462100506, + -0.0018847601022571325, + 0.000007039372576400638, + -0.0007264417363330722, + 0.00011364505917299539, + 0.0014301587361842394, + 0.0007490540738217533, + 0.0020184689201414585, + 0.0007436950691044331, + -0.00046178390039131045, + -0.0039057559333741665, + 0.0011406694538891315, + -0.00004022853681817651, + -0.0013293239753693342, + -0.0017636751290410757, + -0.0028280913829803467, + 0.00033634810824878514, + -0.0014248639345169067, + -0.003777273464947939, + 0.0015998880844563246, + 0.0002989505883306265, + -0.000804675742983818, + 0.002038792008534074, + -0.0015593919670209289, + -0.0006436670082621276, + 0.0011168173514306545, + -0.00035012533771805465, + 0.0011338205076754093, + 0.0011259170714765787, + -0.002516670385375619, + -0.0014790185960009694, + 0.0003878737334161997, + -0.00006408110493794084, + -0.0005096744280308485, + -0.0008840755908749998, + 0.0006398351397365332, + -0.0010097370250150561, + -0.006759158335626125, + 0.0033667823299765587, + -0.01514742337167263, + -0.0021350777242332697, + 0.002593174111098051, + -0.00042678468162193894, + -0.005558924749493599, + 0.0026658528950065374, + 0.006411008536815643, + -0.003826778382062912, + -0.0003843410813715309, + -0.0016430341638624668, + -0.0013344454346224666, + -0.0000920506427064538, + -0.00009476230479776859, + -0.0057889921590685844, + -0.0006383581785485148, + 0.13493388891220093, + -0.001768707763403654, + -0.018917907029390335, + 0.003873429261147976, + -0.0021450775675475597, + -0.010327338241040707, + 0.18325845897197723, + -0.0007747983909212053, + -0.00104526337236166, + -0.003833949100226164, + -0.0008046097937040031, + -0.012673400342464447, + 0.00804573018103838, + 0.003604492638260126, + -0.009398287162184715, + -0.08272082358598709, + 0.003555194940418005, + -0.018404025584459305, 
+ 0.0017587244510650635, + 0.2896133363246918, + 0.022854052484035492, + 0.08595258742570877, + -0.0006932877004146576, + 0.06817055493593216, + 0.013111240230500698, + -0.021098043769598007, + 0.05112447217106819, + 1.3844914436340332, + 0.045836858451366425, + -0.03830280900001526, + 2.985445976257324, + 0.0019662054255604744, + -0.008030137047171593, + 0.5608693957328796, + 0.17083050310611725, + -0.03361757844686508, + 0.05821544677019119, + -0.0024530249647796154, + 0.0018771197646856308, + 0.28827205300331116, + -1.8986485004425049, + -0.0015286931302398443, + -0.035129792988300323, + 0.4802178740501404, + -0.0009115453576669097, + 0.016075748950242996, + -0.03986122086644173, + -0.3879126012325287, + 0.011123123578727245, + -0.005477819126099348, + -0.0025129620917141438, + -0.08056175708770752, + 0.007518616039305925, + 0.0430111438035965, + -0.040082238614559174, + -0.9702364802360535, + 0.011862239800393581 + ], + "yaxis": "y" + } + ], + "layout": { + "legend": { + "tracegroupgap": 0 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + 
"colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 
0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } 
+ }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + 
"#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": 
"white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Original vs Post-Ablation Direct Logit Attribution of Heads" + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "range": [ + -3, + 3 + ], + "title": { + "text": "Ablated" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "range": [ + -3, + 3 + ], + "title": { + "text": "Original" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "per_head_ablated_residual, labels = ablated_cache.stack_head_results(\n", + " layer=-1, pos_slice=-1, return_labels=True\n", + ")\n", + "per_head_ablated_logit_diffs = residual_stack_to_logit_diff(\n", + " per_head_ablated_residual, ablated_cache\n", + ")\n", + "per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(\n", + " model.cfg.n_layers, model.cfg.n_heads\n", + ")\n", + "imshow(per_head_ablated_logit_diffs, labels={\"x\": \"Head\", \"y\": \"Layer\"})\n", + "scatter(\n", + " y=per_head_logit_diffs.flatten(),\n", + " x=per_head_ablated_logit_diffs.flatten(),\n", + " hover_name=head_labels,\n", + " range_x=(-3, 3),\n", + " range_y=(-3, 3),\n", + " xaxis=\"Ablated\",\n", + " yaxis=\"Original\",\n", + " title=\"Original vs Post-Ablation Direct Logit Attribution of Heads\",\n", + ")" ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 
0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One natural hypothesis is that this is because the final LayerNorm scaling has changed, which can scale up or down the final residual stream. This is slightly true, and we can see that the typical head is a bit off from the x=y line. But the average LN scaling ratio is 1.04, and this should uniformly change *all* heads by the same factor, so this can't be sufficient" ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average LN scaling ratio: 1.042\n", + "Ablation LN scale tensor([[18.5200],\n", + " [17.4700],\n", + " [17.8200],\n", + " [17.5100],\n", + " [17.2600],\n", + " [18.2500],\n", + " [16.1800],\n", + " [17.4300]])\n", + "Original LN scale tensor([[19.5700],\n", + " [18.3500],\n", + " [18.2900],\n", + " [18.6800],\n", + " [17.4900],\n", + " [18.8700],\n", + " [16.4200],\n", + " [18.6800]])\n" + ] + } + ], + "source": [ + "print(\n", + " \"Average LN scaling ratio:\",\n", + " round(\n", + " (\n", + " cache[\"ln_final.hook_scale\"][:, -1]\n", + " / ablated_cache[\"ln_final.hook_scale\"][:, -1]\n", + " )\n", + " .mean()\n", + " .item(),\n", + " 
3,\n", + " ),\n", + ")\n", + "print(\n", + " \"Ablation LN scale\",\n", + " ablated_cache[\"ln_final.hook_scale\"][:, -1].detach().cpu().round(decimals=2),\n", + ")\n", + "print(\n", + " \"Original LN scale\",\n", + " cache[\"ln_final.hook_scale\"][:, -1].detach().cpu().round(decimals=2),\n", + ")" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - 
"automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } }, - "title": { - "text": "Original vs Post-Ablation Direct Logit Attribution of Heads" + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Exercise to the reader:** Can you finish off this analysis? What's going on here? Why are the backup name movers changing their behaviour? Why is one negative name mover becoming significantly less important?" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" }, - "xaxis": { - "anchor": "y", - "domain": [ - 0, - 1 - ], - "range": [ - -3, - 3 - ], - "title": { - "text": "Ablated" - } + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" }, - "yaxis": { - "anchor": "x", - "domain": [ - 0, - 1 - ], - "range": [ - -3, - 3 - ], - "title": { - "text": "Original" - } + "vscode": { + "interpreter": { + "hash": "eb812820b5094695c8a581672e17220e30dd2c15d704c018326e3cc2e1a566f1" + } } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "per_head_ablated_residual, labels = ablated_cache.stack_head_results(\n", - " layer=-1, pos_slice=-1, return_labels=True\n", - ")\n", - "per_head_ablated_logit_diffs = residual_stack_to_logit_diff(\n", - " per_head_ablated_residual, ablated_cache\n", - ")\n", - "per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(\n", - " model.cfg.n_layers, model.cfg.n_heads\n", - ")\n", - 
"imshow(per_head_ablated_logit_diffs, labels={\"x\": \"Head\", \"y\": \"Layer\"})\n", - "scatter(\n", - " y=per_head_logit_diffs.flatten(),\n", - " x=per_head_ablated_logit_diffs.flatten(),\n", - " hover_name=head_labels,\n", - " range_x=(-3, 3),\n", - " range_y=(-3, 3),\n", - " xaxis=\"Ablated\",\n", - " yaxis=\"Original\",\n", - " title=\"Original vs Post-Ablation Direct Logit Attribution of Heads\",\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "One natural hypothesis is that this is because the final LayerNorm scaling has changed, which can scale up or down the final residual stream. This is slightly true, and we can see that the typical head is a bit off from the x=y line. But the average LN scaling ratio is 1.04, and this should uniformly change *all* heads by the same factor, so this can't be sufficient" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Average LN scaling ratio: 1.042\n", - "Ablation LN scale tensor([[18.5200],\n", - " [17.4700],\n", - " [17.8200],\n", - " [17.5100],\n", - " [17.2600],\n", - " [18.2500],\n", - " [16.1800],\n", - " [17.4300]])\n", - "Original LN scale tensor([[19.5700],\n", - " [18.3500],\n", - " [18.2900],\n", - " [18.6800],\n", - " [17.4900],\n", - " [18.8700],\n", - " [16.4200],\n", - " [18.6800]])\n" - ] - } - ], - "source": [ - "print(\n", - " \"Average LN scaling ratio:\",\n", - " round(\n", - " (\n", - " cache[\"ln_final.hook_scale\"][:, -1]\n", - " / ablated_cache[\"ln_final.hook_scale\"][:, -1]\n", - " )\n", - " .mean()\n", - " .item(),\n", - " 3,\n", - " ),\n", - ")\n", - "print(\n", - " \"Ablation LN scale\",\n", - " ablated_cache[\"ln_final.hook_scale\"][:, -1].detach().cpu().round(decimals=2),\n", - ")\n", - "print(\n", - " \"Original LN scale\",\n", - " cache[\"ln_final.hook_scale\"][:, -1].detach().cpu().round(decimals=2),\n", - ")" - ] - }, - { 
- "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Exercise to the reader:** Can you finish off this analysis? What's going on here? Why are the backup name movers changing their behaviour? Why is one negative name mover becoming significantly less important?" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - }, - "vscode": { - "interpreter": { - "hash": "eb812820b5094695c8a581672e17220e30dd2c15d704c018326e3cc2e1a566f1" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 + }, + "nbformat": 4, + "nbformat_minor": 2 } From ce1559ab4a0263be3db974ceb0b809e468acc8ba Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Fri, 27 Feb 2026 14:26:38 -0600 Subject: [PATCH 3/7] Fix boot_transformers kwargs and clear stale outputs - Move weight processing args (center_unembed, fold_ln, etc.) from boot_transformers() to enable_compatibility_mode() where they belong - Clear stale outputs from cell with execution_count=None Notebook blocked on missing TransformerBridge features: W_U property delegation (Bug 6), tokens_to_residual_directions (Bug 7), and pos_embed batch dim mismatch (Bug 3). See .claude/plans/transformer_bridge_bugs.md. 
--- demos/Exploratory_Analysis_Demo.ipynb | 40432 ++++++++++++------------ 1 file changed, 20208 insertions(+), 20224 deletions(-) diff --git a/demos/Exploratory_Analysis_Demo.ipynb b/demos/Exploratory_Analysis_Demo.ipynb index b12304844..0ec844270 100644 --- a/demos/Exploratory_Analysis_Demo.ipynb +++ b/demos/Exploratory_Analysis_Demo.ipynb @@ -1,20355 +1,20339 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Exploratory Analysis Demo\n", - "\n", - "This notebook demonstrates how to use the\n", - "[TransformerLens](https://github.com/TransformerLensOrg/TransformerLens/) library to perform exploratory\n", - "analysis. The notebook tries to replicate the analysis of the Indirect Object Identification circuit\n", - "in the [Interpretability in the Wild](https://arxiv.org/abs/2211.00593) paper." 
- ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Tips for Reading This\n", - "\n", - "* If running in Google Colab, go to Runtime > Change Runtime Type and select GPU as the hardware\n", - "accelerator.\n", - "* Look up unfamiliar terms in [the mech interp explainer](https://neelnanda.io/glossary)\n", - "* You can run all this code for yourself\n", - "* The graphs are interactive\n", - "* Use the table of contents pane in the sidebar to navigate (in Colab) or VSCode's \"Outline\" in the\n", - " explorer tab.\n", - "* Collapse irrelevant sections with the dropdown arrows\n", - "* Search the page using the search in the sidebar (with Colab) not CTRL+F" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Environment Setup (ignore)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**You can ignore this part:** It's just for use internally to setup the tutorial in different\n", - "environments. You can delete this section if using in your own repo." 
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# Detect if we're running in Google Colab\n", - "try:\n", - " import google.colab\n", - " IN_COLAB = True\n", - " print(\"Running as a Colab notebook\")\n", - "except:\n", - " IN_COLAB = False\n", - "\n", - "# Install if in Colab\n", - "if IN_COLAB:\n", - " %pip install transformer_lens\n", - " %pip install circuitsvis\n", - " # Install a faster Node version\n", - " !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs # noqa\n", - "\n", - "# Hot reload in development mode & not running on the CD\n", - "if not IN_COLAB:\n", - " from IPython import get_ipython\n", - " ip = get_ipython()\n", - " if not ip.extension_manager.loaded:\n", - " ip.extension_manager.load('autoreload')\n", - " %autoreload 2\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Imports" - ] - }, + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Exploratory Analysis Demo\n", + "\n", + "This notebook demonstrates how to use the\n", + "[TransformerLens](https://github.com/TransformerLensOrg/TransformerLens/) library to perform exploratory\n", + "analysis. The notebook tries to replicate the analysis of the Indirect Object Identification circuit\n", + "in the [Interpretability in the Wild](https://arxiv.org/abs/2211.00593) paper." 
+ ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tips for Reading This\n", + "\n", + "* If running in Google Colab, go to Runtime > Change Runtime Type and select GPU as the hardware\n", + "accelerator.\n", + "* Look up unfamiliar terms in [the mech interp explainer](https://neelnanda.io/glossary)\n", + "* You can run all this code for yourself\n", + "* The graphs are interactive\n", + "* Use the table of contents pane in the sidebar to navigate (in Colab) or VSCode's \"Outline\" in the\n", + " explorer tab.\n", + "* Collapse irrelevant sections with the dropdown arrows\n", + "* Search the page using the search in the sidebar (with Colab) not CTRL+F" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Environment Setup (ignore)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**You can ignore this part:** It's just for use internally to setup the tutorial in different\n", + "environments. You can delete this section if using in your own repo." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Detect if we're running in Google Colab\n", + "try:\n", + " import google.colab\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + "except:\n", + " IN_COLAB = False\n", + "\n", + "# Install if in Colab\n", + "if IN_COLAB:\n", + " %pip install transformer_lens\n", + " %pip install circuitsvis\n", + " # Install a faster Node version\n", + " !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs # noqa\n", + "\n", + "# Hot reload in development mode & not running on the CD\n", + "if not IN_COLAB:\n", + " from IPython import get_ipython\n", + " ip = get_ipython()\n", + " if not ip.extension_manager.loaded:\n", + " ip.extension_manager.load('autoreload')\n", + " %autoreload 2\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "from typing import List, Optional, Union\n", + "\n", + "import einops\n", + "import numpy as np\n", + "import plotly.express as px\n", + "import plotly.io as pio\n", + "import torch\n", + "from circuitsvis.attention import attention_heads\n", + "from fancy_einsum import einsum\n", + "from IPython.display import HTML, IFrame\n", + "from jaxtyping import Float\n", + "\n", + "import transformer_lens.utils as utils\n", + "from transformer_lens import ActivationCache\n", + "from transformer_lens.model_bridge import TransformerBridge" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### PyTorch Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Disabled automatic differentiation\n" + ] + } + ], + "source": [ + "torch.set_grad_enabled(False)\n", + "print(\"Disabled automatic differentiation\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting Helper Functions (ignore)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Some plotting helper functions are included here (for simplicity)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def imshow(tensor, **kwargs):\n", + " px.imshow(\n", + " utils.to_numpy(tensor),\n", + " color_continuous_midpoint=0.0,\n", + " color_continuous_scale=\"RdBu\",\n", + " **kwargs,\n", + " ).show()\n", + "\n", + "\n", + "def line(tensor, **kwargs):\n", + " px.line(\n", + " y=utils.to_numpy(tensor),\n", + " **kwargs,\n", + " ).show()\n", + "\n", + "\n", + "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", **kwargs):\n", + " x = utils.to_numpy(x)\n", + " y = utils.to_numpy(y)\n", + " px.scatter(\n", + " y=y,\n", + " x=x,\n", + " labels={\"x\": xaxis, \"y\": yaxis, \"color\": caxis},\n", + " **kwargs,\n", + " ).show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introduction\n", + "\n", + "This is a demo notebook for [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens), a library for mechanistic interpretability of GPT-2 style transformer language models. A core design principle of the library is to enable exploratory analysis - one of the most fun parts of mechanistic interpretability compared to normal ML is the extremely short feedback loops! 
The point of this library is to keep the gap between having an experiment idea and seeing the results as small as possible, to make it easy for **research to feel like play** and to enter a flow state.\n", + "\n", + "The goal of this notebook is to demonstrate what exploratory analysis looks like in practice with the library. I use my standard toolkit of basic mechanistic interpretability techniques to try interpreting a real circuit in GPT-2 small. Check out [the main demo](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Main_Demo.ipynb) for an introduction to the library and how to use it. \n", + "\n", + "Stylistically, I will go fairly slowly and explain in detail what I'm doing and why, aiming to help convey how to do this kind of research yourself! But the code itself is written to be simple and generic, and easy to copy and paste into your own projects for different tasks and models.\n", + "\n", + "Details tags contain asides, flavour + interpretability intuitions. These are more in the weeds and you don't need to read them or understand them, but they're helpful if you want to learn how to do mechanistic interpretability yourself! I star the ones I think are most important.\n", + "
<details> <summary>(*) Example details tag</summary>Example aside!</details>
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Indirect Object Identification\n", + "\n", + "The first step when trying to reverse engineer a circuit in a model is to identify *what* capability\n", + "I want to reverse engineer. Indirect Object Identification is a task studied in Redwood Research's\n", + "excellent [Interpretability in the Wild](https://arxiv.org/abs/2211.00593) paper (see [my interview\n", + "with the authors](https://www.youtube.com/watch?v=gzwj0jWbvbo) or [Kevin Wang's Twitter\n", + "thread](https://threadreaderapp.com/thread/1587601532639494146.html) for an overview). The task is\n", + "to complete sentences like \"After John and Mary went to the shops, John gave a bottle of milk to\"\n", + "with \" Mary\" rather than \" John\". \n", + "\n", + "In the paper they rigorously reverse engineer a 26 head circuit, with 7 separate categories of heads\n", + "used to perform this capability. Their rigorous methods are fairly involved, so in this notebook,\n", + "I'm going to skimp on rigour and instead try to speed run the process of finding suggestive evidence\n", + "for this circuit!\n", + "\n", + "The circuit they found roughly breaks down into three parts:\n", + "1. Identify what names are in the sentence\n", + "2. Identify which names are duplicated\n", + "3. Predict the name that is *not* duplicated" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The first step is to load in our model, GPT-2 Small, a 12 layer and 80M parameter transformer with `TransformerBridge.boot_transformers`. The various flags are simplifications that preserve the model's output but simplify its internals." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "model = TransformerBridge.boot_transformers(\"gpt2\")\n", + "model.enable_compatibility_mode(\n", + " center_unembed=True,\n", + " center_writing_weights=True,\n", + " fold_ln=True,\n", + " refactor_factored_attn_matrices=True,\n", + ")\n", + "\n", + "# Get the default device used\n", + "device: torch.device = utils.get_device()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The next step is to verify that the model can *actually* do the task! Here we use `utils.test_prompt`, and see that the model is significantly better at predicting Mary than John! \n", + "\n", + "
<details> <summary>Asides:</summary>\n", + "\n", + "Note: If we were being careful, we'd want to run the model on a range of prompts and find the average performance\n", + "\n", + "`prepend_bos` is a flag to add a BOS (beginning of sequence) to the start of the prompt. GPT-2 was not trained with this, but I find that it often makes model behaviour more stable, as the first token is treated weirdly.\n", + "</details>
" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']\n", + "Tokenized answer: [' Mary']\n" + ] + }, + { + "data": { + "text/html": [ + "
Performance on answer token:\n",
+       "Rank: 0        Logit: 18.09 Prob: 70.07% Token: | Mary|\n",
+       "
\n" + ], + "text/plain": [ + "Performance on answer token:\n", + "\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m18.09\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m70.07\u001b[0m\u001b[1m% Token: | Mary|\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|\n", + "Top 1th token. Logit: 15.38 Prob: 4.67% Token: | the|\n", + "Top 2th token. Logit: 15.35 Prob: 4.54% Token: | John|\n", + "Top 3th token. Logit: 15.25 Prob: 4.11% Token: | them|\n", + "Top 4th token. Logit: 14.84 Prob: 2.73% Token: | his|\n", + "Top 5th token. Logit: 14.06 Prob: 1.24% Token: | her|\n", + "Top 6th token. Logit: 13.54 Prob: 0.74% Token: | a|\n", + "Top 7th token. Logit: 13.52 Prob: 0.73% Token: | their|\n", + "Top 8th token. Logit: 13.13 Prob: 0.49% Token: | Jesus|\n", + "Top 9th token. Logit: 12.97 Prob: 0.42% Token: | him|\n" + ] + }, + { + "data": { + "text/html": [ + "
Ranks of the answer tokens: [(' Mary', 0)]\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Mary'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "example_prompt = \"After John and Mary went to the store, John gave a bottle of milk to\"\n", + "example_answer = \" Mary\"\n", + "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now want to find a reference prompt to run the model on. Even though our ultimate goal is to reverse engineer how this behaviour is done in general, often the best way to start out in mechanistic interpretability is by zooming in on a concrete example and understanding it in detail, and only *then* zooming out and verifying that our analysis generalises.\n", + "\n", + "We'll run the model on 4 instances of this task, each prompt given twice - one with the first name as the indirect object, one with the second name. To make our lives easier, we'll carefully choose prompts with single token names and the corresponding names in the same token positions.\n", + "\n", + "
(*) Aside on tokenization\n", + "\n", + "We want models that can take in arbitrary text, but models need to have a fixed vocabulary. So the solution is to define a vocabulary of **tokens** and to deterministically break up arbitrary text into tokens. Tokens are, essentially, subwords, and are determined by finding the most frequent substrings - this means that tokens vary a lot in length and frequency! \n", + "\n", + "Tokens are a *massive* headache and are one of the most annoying things about reverse engineering language models... Different names will be different numbers of tokens, different prompts will have the relevant tokens at different positions, different prompts will have different total numbers of tokens, etc. Language models often devote significant amounts of parameters in early layers to convert inputs from tokens to a more sensible internal format (and do the reverse in later layers). You really, really want to avoid needing to think about tokenization wherever possible when doing exploratory analysis (though, of course, it's relevant later when trying to flesh out your analysis and make it rigorous!). TransformerBridge comes with several helper methods to deal with tokens: `to_tokens, to_string, to_str_tokens, to_single_token, get_token_position`\n", + "\n", + "**Exercise:** I recommend using `model.to_str_tokens` to explore how the model tokenizes different strings. In particular, try adding or removing spaces at the start, or changing capitalization - these change tokenization!
" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']\n", + "[(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n" + ] + } + ], + "source": [ + "prompt_format = [\n", + " \"When John and Mary went to the shops,{} gave the bag to\",\n", + " \"When Tom and James went to the park,{} gave the ball to\",\n", + " \"When Dan and Sid went to the shops,{} gave an apple to\",\n", + " \"After Martin and Amy went to the park,{} gave a drink to\",\n", + "]\n", + "names = [\n", + " (\" Mary\", \" John\"),\n", + " (\" Tom\", \" James\"),\n", + " (\" Dan\", \" Sid\"),\n", + " (\" Martin\", \" Amy\"),\n", + "]\n", + "# List of prompts\n", + "prompts = []\n", + "# List of answers, in the format (correct, incorrect)\n", + "answers = []\n", + "# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)\n", + "answer_tokens = []\n", + "for i in range(len(prompt_format)):\n", + " for j in range(2):\n", + " answers.append((names[i][j], names[i][1 - j]))\n", + " answer_tokens.append(\n", + " (\n", + " model.to_single_token(answers[-1][0]),\n", + " model.to_single_token(answers[-1][1]),\n", + " )\n", + " )\n", + " # Insert the *incorrect* answer to the prompt, making the correct answer the indirect object.\n", + " 
prompts.append(prompt_format[i].format(answers[-1][1]))\n", + "answer_tokens = torch.tensor(answer_tokens).to(device)\n", + "print(prompts)\n", + "print(answers)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Gotcha**: It's important that all of your prompts have the same number of tokens. If they're different lengths, then the position of the \"final\" logit where you can check logit difference will differ between prompts, and this will break the below code. The easiest solution is just to choose your prompts carefully to have the same number of tokens (you can eg add filler words like The, or newlines to start).\n", + "\n", + "There's a range of other ways of solving this, eg you can index more intelligently to get the final logit. A better way is to just use left padding by setting `model.tokenizer.padding_side = 'left'` before tokenizing the inputs and running the model; this way, you can use something like `logits[:, -1, :]` to easily access the final token outputs without complicated indexing. TransformerLens checks the value of `padding_side` of the tokenizer internally, and if the flag is set to be `'left'`, it adjusts the calculation of absolute position embedding and causal masking accordingly.\n", + "\n", + "In this demo, though, we stick to using the prompts of the same number of tokens because we want to show some visualisations aggregated along the batch dimension later in the demo." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' John', ' gave', ' the', ' bag', ' to']\n", + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' Mary', ' gave', ' the', ' bag', ' to']\n", + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'When', ' Tom', ' and', ' James', ' went', ' to', ' the', ' park', ',', ' James', ' gave', ' the', ' ball', ' to']\n", + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'When', ' Tom', ' and', ' James', ' went', ' to', ' the', ' park', ',', ' Tom', ' gave', ' the', ' ball', ' to']\n", + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'When', ' Dan', ' and', ' Sid', ' went', ' to', ' the', ' shops', ',', ' Sid', ' gave', ' an', ' apple', ' to']\n", + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'When', ' Dan', ' and', ' Sid', ' went', ' to', ' the', ' shops', ',', ' Dan', ' gave', ' an', ' apple', ' to']\n", + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'After', ' Martin', ' and', ' Amy', ' went', ' to', ' the', ' park', ',', ' Amy', ' gave', ' a', ' drink', ' to']\n", + "Prompt length: 15\n", + "Prompt as tokens: ['<|endoftext|>', 'After', ' Martin', ' and', ' Amy', ' went', ' to', ' the', ' park', ',', ' Martin', ' gave', ' a', ' drink', ' to']\n" + ] + } + ], + "source": [ + "for prompt in prompts:\n", + " str_tokens = model.to_str_tokens(prompt)\n", + " print(\"Prompt length:\", len(str_tokens))\n", + " print(\"Prompt as tokens:\", str_tokens)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now run the model on these prompts and use `run_with_cache` to get both the logits and a cache of all 
internal activations for later analysis" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "tokens = model.to_tokens(prompts, prepend_bos=True)\n", + "\n", + "# Run the model and cache all activations\n", + "original_logits, cache = model.run_with_cache(tokens)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll later be evaluating how model performance differs upon performing various interventions, so it's useful to have a metric to measure model performance. Our metric here will be the **logit difference**, the difference in logit between the indirect object's name and the subject's name (eg, `logit(Mary)-logit(John)`). " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Per prompt logit difference: tensor([3.3370, 3.2020, 2.7090, 3.7970, 1.7200, 5.2810, 2.6010, 5.7670])\n", + "Average logit difference: 3.552\n" + ] + } + ], + "source": [ + "def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):\n", + " # Only the final logits are relevant for the answer\n", + " final_logits = logits[:, -1, :]\n", + " answer_logits = final_logits.gather(dim=-1, index=answer_tokens)\n", + " answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]\n", + " if per_prompt:\n", + " return answer_logit_diff\n", + " else:\n", + " return answer_logit_diff.mean()\n", + "\n", + "\n", + "print(\n", + " \"Per prompt logit difference:\",\n", + " logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)\n", + " .detach()\n", + " .cpu()\n", + " .round(decimals=3),\n", + ")\n", + "original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)\n", + "print(\n", + " \"Average logit difference:\",\n", + " round(logits_to_ave_logit_diff(original_logits, answer_tokens).item(), 3),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": 
{}, + "source": [ + "We see that the average logit difference is 3.5 - for context, this represents putting an $e^{3.5}\\approx 33\\times$ higher probability on the correct answer. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Brainstorm What's Actually Going On (Optional)\n", + "\n", + "Before diving into running experiments, it's often useful to spend some time actually reasoning about how the behaviour in question could be implemented in the transformer. **This is optional, and you'll likely get the most out of engaging with this section if you have a decent understanding already of what a transformer is and how it works!**\n", + "\n", + "You don't have to do this and forming hypotheses after exploration is also reasonable, but I think it's often easier to explore and interpret results with some grounding in what you might find. In this particular case, I'm cheating somewhat, since I know the answer, but I'm trying to simulate the process of reasoning about it!\n", + "\n", + "Note that often your hypothesis will be wrong in some ways and often be completely off. We're doing science here, and the goal is to understand how the model *actually* works, and to form true beliefs! There are two separate traps here at two extremes that it's worth tracking:\n", + "* Confusion: Having no hypotheses at all, getting a lot of data and not knowing what to do with it, and just floundering around\n", + "* Dogmatism: Being overconfident in an incorrect hypothesis and being unwilling to let go of it when reality contradicts you, or flinching away from running the experiments that might disconfirm it.\n", + "\n", + "**Exercise:** Spend some time thinking through how you might imagine this behaviour being implemented in a transformer. Try to think through this for yourself before reading through my thoughts! \n", + "\n", + "
(*) My reasoning\n", + "\n", + "

Brainstorming:

\n", + "\n", + "So, what's hard about the task? Let's focus on the concrete example of the first prompt, \"When John and Mary went to the shops, John gave the bag to\" -> \" Mary\". \n", + "\n", + "A good starting point is thinking through whether a tiny model could do this, eg a 1L Attn-Only model. I'm pretty sure the answer is no! Attention is really good at the primitive operations of looking nearby, or copying information. I can believe a tiny model could figure out that at `to` it should look for names and predict that those names came next (eg the skip trigram \" John...to -> John\"). But it's much harder to tell how many of each previous name there are - attending 0.3 to each copy of John will look exactly the same as attending 0.6 to a single John token. So this will be pretty hard to figure out on the \" to\" token!\n", + "\n", + "The natural place to break this symmetry is on the second \" John\" token - telling whether there is an earlier copy of the current token should be a much easier task. So I might expect there to be a head which detects duplicate tokens on the second \" John\" token, and then another head which moves that information from the second \" John\" token to the \" to\" token. \n", + "\n", + "The model then needs to learn to predict \" Mary\" and not \" John\". I can see two natural ways to do this: \n", + "1. Detect all preceding names and move this information to \" to\" and then delete any name corresponding to the duplicate token feature. This feels easier done with a non-linearity, since precisely cancelling out vectors is hard, so I'd imagine an MLP layer deletes the \" John\" direction of the residual stream\n", + "2. Have a head which attends to all previous names, but where the duplicate token features inhibit it from attending to specific names. So this only attends to Mary. And then the output of this head maps to the logits. \n", + "\n", + "(Spoiler: It's the second one).\n", + "\n", + "

Experiment Ideas

\n", + "\n", + "A test that could distinguish these two is to look at which components of the model add directly to the logits - if it's mostly attention heads which attend to \" Mary\" and to neither \" John\" it's probably hypothesis 2, if it's mostly MLPs it's probably hypothesis 1.\n", + "\n", + "And we should be able to identify duplicate token heads by finding ones which attend from \" John\" to \" John\", and whose outputs are then moved to the \" to\" token by V-Composition with another head (Spoiler: It's more complicated than that!)\n", + "\n", + "Note that all of the above reasoning is very simplistic and could easily break in a real model! There'll be significant parts of the model that figure out whether to use this circuit at all (we don't want to inhibit duplicated names when, eg, figuring out what goes at the start of the next sentence), and may be parts towards the end of the model that do \"post-processing\" just before the final output. But it's a good starting point for thinking about what's going on." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Direct Logit Attribution" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "*Look up unfamiliar terms in the [mech interp explainer](https://neelnanda.io/glossary)*\n", + "\n", + "Further, the easiest part of the model to understand is the output - this is what the model is trained to optimize, and so it can always be directly interpreted! Often the right approach to reverse engineering a circuit is to start at the end, understand how the model produces the right answer, and to then work backwards. The main technique used to do this is called **direct logit attribution**\n", + "\n", + "**Background:** The central object of a transformer is the **residual stream**. This is the sum of the outputs of each layer and of the original token and positional embedding. 
Importantly, this means that any linear function of the residual stream can be perfectly decomposed into the contribution of each layer of the transformer. Further, each attention layer's output can be broken down into the sum of the output of each head (See [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html) for details), and each MLP layer's output can be broken down into the sum of the output of each neuron (and a bias term for each layer). \n", + "\n", + "The logits of a model are `logits=Unembed(LayerNorm(final_residual_stream))`. The Unembed is a linear map, and LayerNorm is approximately a linear map, so we can decompose the logits into the sum of the contributions of each component, and look at which components contribute the most to the logit of the correct token! This is called **direct logit attribution**. Here we look at the direct attribution to the logit difference!\n", + "\n", + "
(*) Background and motivation of the logit difference\n", + "\n", + "Logit difference is actually a *really* nice and elegant metric and is a particularly nice aspect of the setup of Indirect Object Identification. In general, there are two natural ways to interpret the model's outputs: the output logits, or the output log probabilities (or probabilities). \n", + "\n", + "The logits are much nicer and easier to understand, as noted above. However, the model is trained to optimize the cross-entropy loss (the average of log probability of the correct token). This means it does not directly optimize the logits, and indeed if the model adds an arbitrary constant to every logit, the log probabilities are unchanged. \n", + "\n", + "But `log_probs == logits.log_softmax(dim=-1) == logits - logsumexp(logits)`, and so `log_probs(\" Mary\") - log_probs(\" John\") = logits(\" Mary\") - logits(\" John\")` - the ability to add an arbitrary constant cancels out!\n", + "\n", + "Further, the metric helps us isolate the precise capability we care about - figuring out *which* name is the Indirect Object. There are many other components of the task - deciding whether to return an article (the) or pronoun (her) or name, realising that the sentence wants a person next at all, etc. By taking the logit difference we control for all of that.\n", + "\n", + "Our metric is further refined, because each prompt is repeated twice, for each possible indirect object. This controls for irrelevant behaviour such as the model learning that John is a more frequent token than Mary (this actually happens! The final layernorm bias increases the John logit by 1 relative to the Mary logit)\n", + "\n", + "
\n", + "\n", + "
Ignoring LayerNorm\n", + "\n", + "LayerNorm is an analogous normalization technique to BatchNorm (that's friendlier to massive parallelization) that transformers use. Every time a transformer layer reads information from the residual stream, it applies a LayerNorm to normalize the vector at each position (translating to set the mean to 0 and scaling to set the variance to 1) and then applies a learned vector of weights and biases to scale and translate the normalized vector. This is *almost* a linear map, apart from the scaling step, because that divides by the norm of the vector and the norm is not a linear function. (The `fold_ln` flag when loading a model factors out all the linear parts).\n", + "\n", + "But if we fixed the scale factor, the LayerNorm would be fully linear. And the scale of the residual stream is a global property that's a function of *all* components of the stream, while in practice there are normally just a few directions relevant to any particular component, so in practice this is an acceptable approximation. So when doing direct logit attribution we use the `apply_ln` flag on the `cache` to apply the global layernorm scaling factor to each component. See [my clean GPT-2 implementation](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/clean-transformer-demo/Clean_Transformer_Demo.ipynb#scrollTo=Clean_Transformer_Implementation) for more on LayerNorm.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Getting an output logit is equivalent to projecting onto a direction in the residual stream. We use `model.tokens_to_residual_directions` to map the answer tokens to that direction, and then convert this to a logit difference direction for each batch" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Answer residual directions shape: torch.Size([8, 2, 768])\n", + "Logit difference directions shape: torch.Size([8, 768])\n" + ] + } + ], + "source": [ + "answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)\n", + "print(\"Answer residual directions shape:\", answer_residual_directions.shape)\n", + "logit_diff_directions = (\n", + " answer_residual_directions[:, 0] - answer_residual_directions[:, 1]\n", + ")\n", + "print(\"Logit difference directions shape:\", logit_diff_directions.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To verify that this works, we can apply this to the final residual stream for our cached prompts (after applying LayerNorm scaling) and verify that we get the same answer. \n", + "\n", + "
Technical details\n", + "\n", + "`logits = Unembed(LayerNorm(final_residual_stream))`, so we technically need to account for the centering, and then the learned translation and scaling of the layernorm, not just the variance 1 scaling. \n", + "\n", + "The centering is accounted for with the preprocessing flag `center_writing_weights` which ensures that every weight matrix writing to the residual stream has mean zero. \n", + "\n", + "The learned scaling is folded into the unembedding weights `model.unembed.W_U` via `W_U_fold = layer_norm.weights[:, None] * unembed.W_U`\n", + "\n", + "The learned translation is folded to `model.unembed.b_U`, a bias added to the logits (note that GPT-2 is not trained with an existing `b_U`). This roughly represents unigram statistics. But we can ignore this because each prompt occurs twice with names in the opposite order, so this perfectly cancels out. \n", + "\n", + "Note that rather than using layernorm scaling we could just study `cache[\"ln_final.hook_normalized\"]`\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Final residual stream shape: torch.Size([8, 15, 768])\n", + "Calculated average logit diff: 3.552\n", + "Original logit difference: 3.552\n" + ] + } + ], + "source": [ + "# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].\n", + "final_residual_stream = cache[\"resid_post\", -1]\n", + "print(\"Final residual stream shape:\", final_residual_stream.shape)\n", + "final_token_residual_stream = final_residual_stream[:, -1, :]\n", + "# Apply LayerNorm scaling\n", + "# pos_slice is the subset of the positions we take - here the final token of each prompt\n", + "scaled_final_token_residual_stream = cache.apply_ln_to_stack(\n", + " final_token_residual_stream, layer=-1, pos_slice=-1\n", + ")\n", + "\n", + "average_logit_diff = einsum(\n", + " \"batch d_model, batch d_model -> \",\n", + " scaled_final_token_residual_stream,\n", + " logit_diff_directions,\n", + ") / len(prompts)\n", + "print(\"Calculated average logit diff:\", round(average_logit_diff.item(), 3))\n", + "print(\"Original logit difference:\", round(original_average_logit_diff.item(), 3))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Logit Lens" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now decompose the residual stream! First we apply a technique called the [**logit lens**](https://www.alignmentforum.org/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens) - this looks at the residual stream after each layer and calculates the logit difference from that. This simulates what happens if we delete all subsequence layers. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "def residual_stack_to_logit_diff(\n", + " residual_stack: Float[torch.Tensor, \"components batch d_model\"],\n", + " cache: ActivationCache,\n", + ") -> float:\n", + " scaled_residual_stack = cache.apply_ln_to_stack(\n", + " residual_stack, layer=-1, pos_slice=-1\n", + " )\n", + " return einsum(\n", + " \"... batch d_model, batch d_model -> ...\",\n", + " scaled_residual_stack,\n", + " logit_diff_directions,\n", + " ) / len(prompts)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Fascinatingly, we see that the model is utterly unable to do the task until layer 7, almost all performance comes from attention layer 9, and performance actually *decreases* from there.\n", + "\n", + "**Note:** Hover over each data point to see what residual stream position it's from!\n", + "\n", + "
Details on `accumulated_resid`\n", + "**Key:** `n_pre` means the residual stream at the start of layer n, `n_mid` means the residual stream after the attention part of layer n (`n_post` is the same as `n+1_pre` so is not included)\n", + "\n", + "* `layer` is the layer for which we input the residual stream (this is used to identify *which* layer norm scaling factor we want)\n", + "* `incl_mid` is whether to include the residual stream in the middle of a layer, ie after attention & before MLP\n", + "* `pos_slice` is the subset of the positions used. See `utils.Slice` for details on the syntax.\n", + "* `return_labels` is whether to return the labels for each component returned (useful for plotting)\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from functools import partial\n", - "from typing import List, Optional, Union\n", - "\n", - "import einops\n", - "import numpy as np\n", - "import plotly.express as px\n", - "import plotly.io as pio\n", - "import torch\n", - "from circuitsvis.attention import attention_heads\n", - "from fancy_einsum import einsum\n", - "from IPython.display import HTML, IFrame\n", - "from jaxtyping import Float\n", - "\n", - "import transformer_lens.utils as utils\n", - "from transformer_lens import ActivationCache\n", - "from transformer_lens.model_bridge import TransformerBridge" - ] + "hovertemplate": "%{hovertext}

x=%{x}
y=%{y}", + "hovertext": [ + "0_pre", + "0_mid", + "1_pre", + "1_mid", + "2_pre", + "2_mid", + "3_pre", + "3_mid", + "4_pre", + "4_mid", + "5_pre", + "5_mid", + "6_pre", + "6_mid", + "7_pre", + "7_mid", + "8_pre", + "8_mid", + "9_pre", + "9_mid", + "10_pre", + "10_mid", + "11_pre", + "11_mid", + "final_post" + ], + "legendgroup": "", + "line": { + "color": "#636efa", + "dash": "solid" + }, + "marker": { + "symbol": "circle" + }, + "mode": "lines", + "name": "", + "orientation": "v", + "showlegend": false, + "type": "scatter", + "x": [ + 0, + 0.5, + 1, + 1.5, + 2, + 2.5, + 3, + 3.5, + 4, + 4.5, + 5, + 5.5, + 6, + 6.5, + 7, + 7.5, + 8, + 8.5, + 9, + 9.5, + 10, + 10.5, + 11, + 11.5, + 12 + ], + "xaxis": "x", + "y": [ + 1.2937933206558228e-05, + -0.006643360480666161, + -0.007525032386183739, + -0.009075596928596497, + -0.008736769668757915, + -0.008685456588864326, + -0.006480347365140915, + -0.007939882576465607, + -0.009661720134317875, + -0.015095856040716171, + -0.01419061329215765, + -0.019930001348257065, + -0.00912435818463564, + -0.027298055589199066, + -0.02985510788857937, + 0.2497255504131317, + 0.250558078289032, + 0.45005205273628235, + 0.45996904373168945, + 5.02545166015625, + 5.142900466918945, + 4.730565071105957, + 4.887058258056641, + 3.445383071899414, + 3.5518720149993896 + ], + "yaxis": "y" + } + ], + "layout": { + "legend": { + "tracegroupgap": 0 }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### PyTorch Setup" + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + 
"aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 
0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + 
{ + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, 
+ "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training." + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Disabled automatic differentiation\n" - ] - } - ], - "source": [ - "torch.set_grad_enabled(False)\n", - "print(\"Disabled automatic differentiation\")" + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" 
+ }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Plotting Helper Functions (ignore)" - ] + "title": { + "text": "Logit Difference From Accumulate Residual Stream" }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Some plotting helper functions are included here (for simplicity)." 
- ] + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "title": { + "text": "x" + } }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "title": { + "text": "y" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "accumulated_residual, labels = cache.accumulated_resid(\n", + " layer=-1, incl_mid=True, pos_slice=-1, return_labels=True\n", + ")\n", + "logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)\n", + "line(\n", + " logit_lens_logit_diffs,\n", + " x=np.arange(model.cfg.n_layers * 2 + 1) / 2,\n", + " hover_name=labels,\n", + " title=\"Logit Difference From Accumulate Residual Stream\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Layer Attribution" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can repeat the above analysis but for each layer (this is equivalent to the differences between adjacent residual streams)\n", + "\n", + "Note: Annoying terminology overload - layer k of a transformer means the kth **transformer block**, but each block consists of an **attention layer** (to move information around) *and* an **MLP layer** (to process information). \n", + "\n", + "We see that only attention layers matter, which makes sense! The IOI task is about moving information around (ie moving the correct name and not the incorrect name), and less about processing it. 
And again we note that attention layer 9 improves things a lot, while attention 10 and attention 11 *decrease* performance" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "def imshow(tensor, **kwargs):\n", - " px.imshow(\n", - " utils.to_numpy(tensor),\n", - " color_continuous_midpoint=0.0,\n", - " color_continuous_scale=\"RdBu\",\n", - " **kwargs,\n", - " ).show()\n", - "\n", - "\n", - "def line(tensor, **kwargs):\n", - " px.line(\n", - " y=utils.to_numpy(tensor),\n", - " **kwargs,\n", - " ).show()\n", - "\n", - "\n", - "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", **kwargs):\n", - " x = utils.to_numpy(x)\n", - " y = utils.to_numpy(y)\n", - " px.scatter(\n", - " y=y,\n", - " x=x,\n", - " labels={\"x\": xaxis, \"y\": yaxis, \"color\": caxis},\n", - " **kwargs,\n", - " ).show()" - ] + "hovertemplate": "%{hovertext}

x=%{x}
y=%{y}", + "hovertext": [ + "embed", + "pos_embed", + "0_attn_out", + "0_mlp_out", + "1_attn_out", + "1_mlp_out", + "2_attn_out", + "2_mlp_out", + "3_attn_out", + "3_mlp_out", + "4_attn_out", + "4_mlp_out", + "5_attn_out", + "5_mlp_out", + "6_attn_out", + "6_mlp_out", + "7_attn_out", + "7_mlp_out", + "8_attn_out", + "8_mlp_out", + "9_attn_out", + "9_mlp_out", + "10_attn_out", + "10_mlp_out", + "11_attn_out", + "11_mlp_out" + ], + "legendgroup": "", + "line": { + "color": "#636efa", + "dash": "solid" + }, + "marker": { + "symbol": "circle" + }, + "mode": "lines", + "name": "", + "orientation": "v", + "showlegend": false, + "type": "scatter", + "x": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25 + ], + "xaxis": "x", + "y": [ + -0.00028366726473905146, + 0.00029660604195669293, + -0.0066563040018081665, + -0.0008816685294732451, + -0.0015505650080740452, + 0.00033882574643939734, + 5.131529178470373e-05, + 0.0022051138803362846, + -0.0014595506945624948, + -0.0017218313878402114, + -0.005434143822640181, + 0.0009052485693246126, + -0.0057394010946154594, + 0.010805649682879448, + -0.018173698335886, + -0.002557049971073866, + 0.27958065271377563, + 0.0008325176313519478, + 0.19949400424957275, + 0.00991708692163229, + 4.565483093261719, + 0.11744903028011322, + -0.4123360514640808, + 0.15649384260177612, + -1.4416757822036743, + 0.10648896545171738 + ], + "yaxis": "y" + } + ], + "layout": { + "legend": { + "tracegroupgap": 0 }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Introduction\n", - "\n", - "This is a demo notebook for [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens), a library for mechanistic interpretability of GPT-2 style transformer language models. 
A core design principle of the library is to enable exploratory analysis - one of the most fun parts of mechanistic interpretability compared to normal ML is the extremely short feedback loops! The point of this library is to keep the gap between having an experiment idea and seeing the results as small as possible, to make it easy for **research to feel like play** and to enter a flow state.\n", - "\n", - "The goal of this notebook is to demonstrate what exploratory analysis looks like in practice with the library. I use my standard toolkit of basic mechanistic interpretability techniques to try interpreting a real circuit in GPT-2 small. Check out [the main demo](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Main_Demo.ipynb) for an introduction to the library and how to use it. \n", - "\n", - "Stylistically, I will go fairly slowly and explain in detail what I'm doing and why, aiming to help convey how to do this kind of research yourself! But the code itself is written to be simple and generic, and easy to copy and paste into your own projects for different tasks and models.\n", - "\n", - "Details tags contain asides, flavour + interpretability intuitions. These are more in the weeds and you don't need to read them or understand them, but they're helpful if you want to learn how to do mechanistic interpretability yourself! I star the ones I think are most important.\n", - "
(*) Example details tagExample aside!
" + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 
0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + 
"parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": 
{ + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Indirect Object Identification\n", - "\n", - "The first step when trying to reverse engineer a circuit in a model is to identify *what* capability\n", - "I want to reverse engineer. Indirect Object Identification is a task studied in Redwood Research's\n", - "excellent [Interpretability in the Wild](https://arxiv.org/abs/2211.00593) paper (see [my interview\n", - "with the authors](https://www.youtube.com/watch?v=gzwj0jWbvbo) or [Kevin Wang's Twitter\n", - "thread](https://threadreaderapp.com/thread/1587601532639494146.html) for an overview). The task is\n", - "to complete sentences like \"After John and Mary went to the shops, John gave a bottle of milk to\"\n", - "with \" Mary\" rather than \" John\". \n", - "\n", - "In the paper they rigorously reverse engineer a 26 head circuit, with 7 separate categories of heads\n", - "used to perform this capability. Their rigorous methods are fairly involved, so in this notebook,\n", - "I'm going to skimp on rigour and instead try to speed run the process of finding suggestive evidence\n", - "for this circuit!\n", - "\n", - "The circuit they found roughly breaks down into three parts:\n", - "1. 
Identify what names are in the sentence\n", - "2. Identify which names are duplicated\n", - "3. Predict the name that is *not* duplicated" + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The first step is to load in our model, GPT-2 Small, a 12 layer and 80M parameter transformer with `TransformerBridge.boot_transformers`. The various flags are simplifications that preserve the model's output but simplify its internals." + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": 
"" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using pad_token, but it is not set yet.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded pretrained model gpt2-small into HookedTransformer\n" - ] - } - ], - "source": [ - "# NBVAL_IGNORE_OUTPUT\n", - "model = TransformerBridge.boot_transformers(\n", - " \"gpt2\",\n", - " center_unembed=True,\n", - " center_writing_weights=True,\n", - " fold_ln=True,\n", - " refactor_factored_attn_matrices=True,\n", - ")\n", - "model.enable_compatibility_mode()\n", - "\n", - "# Get the 
default device used\n", - "device: torch.device = utils.get_device()" - ] + "title": { + "text": "Logit Difference From Each Layer" }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The next step is to verify that the model can *actually* do the task! Here we use `utils.test_prompt`, and see that the model is significantly better at predicting Mary than John! \n", - "\n", - "
Asides:\n", - "\n", - "Note: If we were being careful, we'd want to run the model on a range of prompts and find the average performance\n", - "\n", - "`prepend_bos` is a flag to add a BOS (beginning of sequence) to the start of the prompt. GPT-2 was not trained with this, but I find that it often makes model behaviour more stable, as the first token is treated weirdly.\n", - "
" - ] + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "title": { + "text": "x" + } }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "title": { + "text": "y" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "per_layer_residual, labels = cache.decompose_resid(\n", + " layer=-1, pos_slice=-1, return_labels=True\n", + ")\n", + "per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)\n", + "line(per_layer_logit_diffs, hover_name=labels, title=\"Logit Difference From Each Layer\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Head Attribution" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can further break down the output of each attention layer into the sum of the outputs of each attention head. Each attention layer consists of 12 heads, which each act independently and additively.\n", + "\n", + "
Decomposing attention output into sums of heads \n", + "The standard way to compute the output of an attention layer is by concatenating the mixed values of each head, and multiplying by a big output weight matrix. But as described in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) this is equivalent to splitting the output weight matrix into a per-head output (here `model.blocks[k].attn.W_O`) and adding them up (including an overall bias term for the entire layer)\n", + "
\n", + "\n", + "We see that only a few heads really matter - heads L9H6 and L9H9 contribute a lot positively (explaining why attention layer 9 is so important), while heads L10H7 and L11H10 contribute a lot negatively (explaining why attention layer 10 and layer 11 are actively harmful). These correspond to (some of) the name movers and negative name movers discussed in the paper. There are also several heads that matter positively or negatively but less strongly (other name movers and backup name movers)\n", + "\n", + "There are a few meta observations worth making here - our model has 144 heads, yet we could localise this behaviour to a handful of specific heads, using straightforward, general techniques. This supports the claim in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) that attention heads are the right level of abstraction to understand attention. It is also really surprising that there are *negative* heads - eg L10H7 makes the incorrect logit 7x *more* likely. I'm not sure what's going on there, though the paper discusses some possibilities." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tried to stack head results when they weren't cached. Computing head results now\n" + ] + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']\n", - "Tokenized answer: [' Mary']\n" - ] - }, - { - "data": { - "text/html": [ - "
Performance on answer token:\n",
-                            "Rank: 0        Logit: 18.09 Prob: 70.07% Token: | Mary|\n",
-                            "
\n" - ], - "text/plain": [ - "Performance on answer token:\n", - "\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m18.09\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m70.07\u001b[0m\u001b[1m% Token: | Mary|\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|\n", - "Top 1th token. Logit: 15.38 Prob: 4.67% Token: | the|\n", - "Top 2th token. Logit: 15.35 Prob: 4.54% Token: | John|\n", - "Top 3th token. Logit: 15.25 Prob: 4.11% Token: | them|\n", - "Top 4th token. Logit: 14.84 Prob: 2.73% Token: | his|\n", - "Top 5th token. Logit: 14.06 Prob: 1.24% Token: | her|\n", - "Top 6th token. Logit: 13.54 Prob: 0.74% Token: | a|\n", - "Top 7th token. Logit: 13.52 Prob: 0.73% Token: | their|\n", - "Top 8th token. Logit: 13.13 Prob: 0.49% Token: | Jesus|\n", - "Top 9th token. Logit: 12.97 Prob: 0.42% Token: | him|\n" - ] - }, - { - "data": { - "text/html": [ - "
Ranks of the answer tokens: [(' Mary', 0)]\n",
-                            "
\n" - ], - "text/plain": [ - "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Mary'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "example_prompt = \"After John and Mary went to the store, John gave a bottle of milk to\"\n", - "example_answer = \" Mary\"\n", - "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)" - ] + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.0020563392899930477, + -0.0005101899732835591, + 0.0004685786843765527, + 0.00012512074317783117, + -0.0006028738571330905, + -0.0002429460291750729, + -0.0023189077619463205, + -0.002758360467851162, + 0.000564602785743773, + 0.0009697531932033598, + -0.0002504526637494564, + 4.737317794933915e-06 + ], + [ + -0.0010070882271975279, + 0.00039470894262194633, + -0.00154874159488827, + 0.0014034928753972054, + -0.0012653048615902662, + -0.0011358022456988692, + -0.00281596090644598, + -0.0029645217582583427, + 0.0029190476052463055, + 0.0025743592996150255, + 0.00036239007022231817, + 0.0017548729665577412 + ], + [ + 0.0005569400964304805, + -0.001126631861552596, + -0.0017353934235870838, + -0.0014514457434415817, + -0.00028735760133713484, + 0.0017211002996191382, + 0.0026658899150788784, + 0.00311466702260077, + 0.0005667927907779813, + -0.003666515462100506, + -0.0018847601022571325, + 7.039372576400638e-06 + ], + [ + -0.0007264417363330722, + 0.00011364505917299539, + 0.0014301587361842394, + 0.0007490540738217533, + 0.0020184689201414585, + 0.0007436950691044331, + -0.00046178390039131045, + -0.0039057559333741665, + 0.0011406694538891315, + -4.022853681817651e-05, + -0.0013293239753693342, + -0.0017636751290410757 + ], + [ + -0.0028280913829803467, + 0.00033634810824878514, + -0.0014248639345169067, + -0.003777273464947939, + 0.0015998880844563246, + 0.0002989505883306265, + -0.000804675742983818, + 0.002038792008534074, + -0.0015593919670209289, + -0.0006436670082621276, + 0.0011168173514306545, + -0.00035012533771805465 + ], + [ + 0.0011338205076754093, + 0.0011259170714765787, + -0.002516670385375619, + -0.0014790185960009694, + 0.0003878737334161997, + -6.408110493794084e-05, + -0.0005096744280308485, + -0.0008840755908749998, + 0.0006398351397365332, + -0.0010097370250150561, + -0.006759158335626125, + 0.0033667823299765587 + ], + [ + -0.01514742337167263, 
+ -0.0021350777242332697, + 0.002593174111098051, + -0.00042678468162193894, + -0.005558924749493599, + 0.0026658528950065374, + 0.006411008536815643, + -0.003826778382062912, + -0.0003843410813715309, + -0.0016430341638624668, + -0.0013344454346224666, + -9.20506427064538e-05 + ], + [ + -9.476230479776859e-05, + -0.0057889921590685844, + -0.0006383581785485148, + 0.13493388891220093, + -0.001768707763403654, + -0.018917907029390335, + 0.003873429261147976, + -0.0021450775675475597, + -0.010327338241040707, + 0.18325845897197723, + -0.0007747983909212053, + -0.00104526337236166 + ], + [ + -0.003833949100226164, + -0.0008046097937040031, + -0.012673400342464447, + 0.00804573018103838, + 0.003604492638260126, + -0.009398287162184715, + -0.08272082358598709, + 0.003555194940418005, + -0.018404025584459305, + 0.0017587244510650635, + 0.2896133363246918, + 0.022854052484035492 + ], + [ + 0.08595258742570877, + -0.0006932877004146576, + 0.06817055493593216, + 0.013111240230500698, + -0.021098043769598007, + 0.05112447217106819, + 1.3844914436340332, + 0.045836858451366425, + -0.03830280900001526, + 2.985445976257324, + 0.0019662054255604744, + -0.008030137047171593 + ], + [ + 0.5608693957328796, + 0.17083050310611725, + -0.03361757844686508, + 0.05821544677019119, + -0.0024530249647796154, + 0.0018771197646856308, + 0.28827205300331116, + -1.8986485004425049, + -0.0015286931302398443, + -0.035129792988300323, + 0.4802178740501404, + -0.0009115453576669097 + ], + [ + 0.016075748950242996, + -0.03986122086644173, + -0.3879126012325287, + 0.011123123578727245, + -0.005477819126099348, + -0.0025129620917141438, + -0.08056175708770752, + 0.007518616039305925, + 0.0430111438035965, + -0.040082238614559174, + -0.9702364802360535, + 0.011862239800393581 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, 
+ "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We now want to find a reference prompt to run the model on. Even though our ultimate goal is to reverse engineer how this behaviour is done in general, often the best way to start out in mechanistic interpretability is by zooming in on a concrete example and understanding it in detail, and only *then* zooming out and verifying that our analysis generalises.\n", - "\n", - "We'll run the model on 4 instances of this task, each prompt given twice - one with the first name as the indirect object, one with the second name. To make our lives easier, we'll carefully choose prompts with single token names and the corresponding names in the same token positions.\n", - "\n", - "
(*) Aside on tokenization\n", - "\n", - "We want models that can take in arbitrary text, but models need to have a fixed vocabulary. So the solution is to define a vocabulary of **tokens** and to deterministically break up arbitrary text into tokens. Tokens are, essentially, subwords, and are determined by finding the most frequent substrings - this means that tokens vary a lot in length and frequency! \n", - "\n", - "Tokens are a *massive* headache and are one of the most annoying things about reverse engineering language models... Different names will be different numbers of tokens, different prompts will have the relevant tokens at different positions, different prompts will have different total numbers of tokens, etc. Language models often devote significant amounts of parameters in early layers to convert inputs from tokens to a more sensible internal format (and do the reverse in later layers). You really, really want to avoid needing to think about tokenization wherever possible when doing exploratory analysis (though, of course, it's relevant later when trying to flesh out your analysis and make it rigorous!). TransformerBridge comes with several helper methods to deal with tokens: `to_tokens, to_string, to_str_tokens, to_single_token, get_token_position`\n", - "\n", - "**Exercise:** I recommend using `model.to_str_tokens` to explore how the model tokenizes different strings. In particular, try adding or removing spaces at the start, or changing capitalization - these change tokenization!
" + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 
0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + 
"parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": 
{ + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']\n", - "[(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n" - ] - } - ], - "source": [ - "prompt_format = [\n", - " \"When John and Mary went to the shops,{} gave the bag to\",\n", - " \"When Tom and James went to the park,{} gave the ball to\",\n", - " \"When Dan and Sid went to the shops,{} gave an apple to\",\n", - " \"After Martin and Amy went to the park,{} gave a drink to\",\n", - "]\n", - "names = [\n", - " (\" Mary\", \" John\"),\n", - " (\" Tom\", \" James\"),\n", - " (\" Dan\", \" 
Sid\"),\n", - " (\" Martin\", \" Amy\"),\n", - "]\n", - "# List of prompts\n", - "prompts = []\n", - "# List of answers, in the format (correct, incorrect)\n", - "answers = []\n", - "# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)\n", - "answer_tokens = []\n", - "for i in range(len(prompt_format)):\n", - " for j in range(2):\n", - " answers.append((names[i][j], names[i][1 - j]))\n", - " answer_tokens.append(\n", - " (\n", - " model.to_single_token(answers[-1][0]),\n", - " model.to_single_token(answers[-1][1]),\n", - " )\n", - " )\n", - " # Insert the *incorrect* answer to the prompt, making the correct answer the indirect object.\n", - " prompts.append(prompt_format[i].format(answers[-1][1]))\n", - "answer_tokens = torch.tensor(answer_tokens).to(device)\n", - "print(prompts)\n", - "print(answers)" + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Gotcha**: It's important that all of your prompts have the same number of tokens. If they're different lengths, then the position of the \"final\" logit where you can check logit difference will differ between prompts, and this will break the below code. The easiest solution is just to choose your prompts carefully to have the same number of tokens (you can eg add filler words like The, or newlines to start).\n", - "\n", - "There's a range of other ways of solving this, eg you can index more intelligently to get the final logit. 
A better way is to just use left padding by setting `model.tokenizer.padding_side = 'left'` before tokenizing the inputs and running the model; this way, you can use something like `logits[:, -1, :]` to easily access the final token outputs without complicated indexing. TransformerLens checks the value of `padding_side` of the tokenizer internally, and if the flag is set to be `'left'`, it adjusts the calculation of absolute position embedding and causal masking accordingly.\n", - "\n", - "In this demo, though, we stick to using the prompts of the same number of tokens because we want to show some visualisations aggregated along the batch dimension later in the demo." + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + 
"zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' John', ' gave', ' the', ' bag', ' to']\n", - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' Mary', ' gave', ' the', ' bag', ' to']\n", - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'When', ' Tom', ' and', ' James', ' went', ' to', ' the', ' park', ',', ' James', ' gave', ' the', ' ball', ' to']\n", - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'When', ' Tom', ' and', ' James', ' went', ' to', ' the', ' park', ',', ' Tom', ' 
gave', ' the', ' ball', ' to']\n", - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'When', ' Dan', ' and', ' Sid', ' went', ' to', ' the', ' shops', ',', ' Sid', ' gave', ' an', ' apple', ' to']\n", - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'When', ' Dan', ' and', ' Sid', ' went', ' to', ' the', ' shops', ',', ' Dan', ' gave', ' an', ' apple', ' to']\n", - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'After', ' Martin', ' and', ' Amy', ' went', ' to', ' the', ' park', ',', ' Amy', ' gave', ' a', ' drink', ' to']\n", - "Prompt length: 15\n", - "Prompt as tokens: ['<|endoftext|>', 'After', ' Martin', ' and', ' Amy', ' went', ' to', ' the', ' park', ',', ' Martin', ' gave', ' a', ' drink', ' to']\n" - ] - } - ], - "source": [ - "for prompt in prompts:\n", - " str_tokens = model.to_str_tokens(prompt)\n", - " print(\"Prompt length:\", len(str_tokens))\n", - " print(\"Prompt as tokens:\", str_tokens)" - ] + "title": { + "text": "Logit Difference From Each Head" }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We now run the model on these prompts and use `run_with_cache` to get both the logits and a cache of all internal activations for later analysis" - ] + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "per_head_residual, labels = cache.stack_head_results(\n", + " layer=-1, pos_slice=-1, return_labels=True\n", + ")\n", + "per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)\n", + "per_head_logit_diffs = einops.rearrange(\n", + " per_head_logit_diffs,\n", + " \"(layer head_index) -> layer head_index\",\n", + " layer=model.cfg.n_layers,\n", + " 
head_index=model.cfg.n_heads,\n", + ")\n", + "imshow(\n", + " per_head_logit_diffs,\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", + " title=\"Logit Difference From Each Head\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Attention Analysis" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Attention heads are particularly easy to study because we can look directly at their attention patterns and study from what positions they move information from and to. This is particularly easy here as we're looking at the direct effect on the logits so we need only look at the attention patterns from the final token. \n", + "\n", + "We use Alan Cooney's circuitsvis library to visualize the attention patterns! We visualize the top 3 positive and negative heads by direct logit attribution, and show these for the first prompt (as an illustration).\n", + "\n", + "
Interpreting Attention Patterns \n", + "An easy mistake to make when looking at attention patterns is thinking that they must convey information about the token looked at (maybe accounting for the context of the token). But actually, all we can confidently say is that it moves information from the *residual stream position* corresponding to that input token. Especially later on in the model, there may be components in the residual stream that are nothing to do with the input token! Eg the period at the end of a sentence may contain summary information for that sentence, and the head may solely move that, rather than caring about whether it ends in \".\", \"!\" or \"?\"\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def visualize_attention_patterns(\n", + " heads: Union[List[int], int, Float[torch.Tensor, \"heads\"]],\n", + " local_cache: ActivationCache,\n", + " local_tokens: torch.Tensor,\n", + " title: Optional[str] = \"\",\n", + " max_width: Optional[int] = 700,\n", + ") -> str:\n", + " # If a single head is given, convert to a list\n", + " if isinstance(heads, int):\n", + " heads = [heads]\n", + "\n", + " # Create the plotting data\n", + " labels: List[str] = []\n", + " patterns: List[Float[torch.Tensor, \"dest_pos src_pos\"]] = []\n", + "\n", + " # Assume we have a single batch item\n", + " batch_index = 0\n", + "\n", + " for head in heads:\n", + " # Set the label\n", + " layer = head // model.cfg.n_heads\n", + " head_index = head % model.cfg.n_heads\n", + " labels.append(f\"L{layer}H{head_index}\")\n", + "\n", + " # Get the attention patterns for the head\n", + " # Attention patterns have shape [batch, head_index, query_pos, key_pos]\n", + " patterns.append(local_cache[\"attn\", layer][batch_index, head_index])\n", + "\n", + " # Convert the tokens to strings (for the axis labels)\n", + " str_tokens = model.to_str_tokens(local_tokens)\n", + "\n", + " # Combine the patterns into a single tensor\n", + " patterns: Float[torch.Tensor, \"head_index dest_pos src_pos\"] = torch.stack(\n", + " patterns, dim=0\n", + " )\n", + "\n", + " # Circuitsvis Plot (note we get the code version so we can concatenate with the title)\n", + " plot = attention_heads(\n", + " attention=patterns, tokens=str_tokens, attention_head_names=labels\n", + " ).show_code()\n", + "\n", + " # Display the title\n", + " title_html = f\"

{title}


\"\n", + "\n", + " # Return the visualisation as raw code\n", + " return f\"
{title_html + plot}
\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Inspecting the patterns, we can see that both types of name movers attend to the indirect object - this suggests they're simply copying the name attended to (with the OV circuit) and that the interesting part is the circuit behind the attention pattern that calculates *where* to move information from (the QK circuit)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

Top 3 Positive Logit Attribution Heads


\n", + "

Top 3 Negative Logit Attribution Heads


\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "top_k = 3\n", + "\n", + "top_positive_logit_attr_heads = torch.topk(\n", + " per_head_logit_diffs.flatten(), k=top_k\n", + ").indices\n", + "\n", + "positive_html = visualize_attention_patterns(\n", + " top_positive_logit_attr_heads,\n", + " cache,\n", + " tokens[0],\n", + " f\"Top {top_k} Positive Logit Attribution Heads\",\n", + ")\n", + "\n", + "top_negative_logit_attr_heads = torch.topk(\n", + " -per_head_logit_diffs.flatten(), k=top_k\n", + ").indices\n", + "\n", + "negative_html = visualize_attention_patterns(\n", + " top_negative_logit_attr_heads,\n", + " cache,\n", + " tokens[0],\n", + " title=f\"Top {top_k} Negative Logit Attribution Heads\",\n", + ")\n", + "\n", + "HTML(positive_html + negative_html)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Activation Patching" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**This section explains how to do activation patching conceptually by implementing it from scratch. To use it in practice with TransformerLens, see [this demonstration instead](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Activation_Patching_in_TL_Demo.ipynb)**.\n", + "\n", + "The obvious limitation to the techniques used above is that they only look at the very end of the circuit - the parts that directly affect the logits. Clearly this is not sufficient to understand the circuit! We want to understand how things compose together to produce this final output, and ideally to produce an end-to-end circuit fully explaining this behaviour. \n", + "\n", + "The technique we'll use to investigate this is called **activation patching**. This was first introduced in [David Bau and Kevin Meng's excellent ROME paper](https://rome.baulab.info/), there called causal tracing. 
\n", + "\n", + "The setup of activation patching is to take two runs of the model on two different inputs, the clean run and the corrupted run. The clean run outputs the correct answer and the corrupted run does not. The key idea is that we give the model the corrupted input, but then **intervene** on a specific activation and **patch** in the corresponding activation from the clean run (ie replace the corrupted activation with the clean activation), and then continue the run. And we then measure how much the output has updated towards the correct answer. \n", + "\n", + "We can then iterate over many possible activations and look at how much they affect the corrupted run. If patching in an activation significantly increases the probability of the correct answer, this allows us to *localise* which activations matter. \n", + "\n", + "The ability to localise is a key move in mechanistic interpretability - if the computation is diffuse and spread across the entire model, it is likely much harder to form a clean mechanistic story for what's going on. But if we can identify precisely which parts of the model matter, we can then zoom in and determine what they represent and how they connect up with each other, and ultimately reverse engineer the underlying circuit that they represent. 
\n", + "\n", + "Here's an animation from the ROME paper demonstrating this technique (they studied factual recall, and use stars to represent corruption applied to the subject of the sentence, but the same principles apply):\n", + "\n", + "![CT Animation](https://rome.baulab.info/images/small-ct-animation.gif)\n", + "\n", + "See also [the explanation in a mech interp explainer](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx) and [this piece](https://www.neelnanda.io/mechanistic-interpretability/attribution-patching#how-to-think-about-activation-patching) describing how to think about patching on a conceptual level" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The above was all fairly abstract, so let's zoom in and lay out a concrete example to understand Indirect Object Identification.\n", + "\n", + "Here our clean input will be eg \"After John and Mary went to the store, **John** gave a bottle of milk to\" and our corrupted input will be eg \"After John and Mary went to the store, **Mary** gave a bottle of milk to\". These prompts are identical except for the name of the indirect object, and so patching is a causal intervention which will allow us to understand precisely which parts of the network are identifying the indirect object. \n", + "\n", + "One natural thing to patch in is the residual stream at a specific layer and specific position. For example, the model is likely initially doing some processing on the second subject token to realise that it's a duplicate, but then uses attention to move that information to the \" to\" token. So patching in the residual stream at the \" to\" token will likely matter a lot in later layers but not at all in early layers.\n", + "\n", + "We can zoom in much further and patch in specific activations from specific layers. 
For example, we think that the output of head L9H9 on the final token is significant for directly connecting to the logits.\n", + "\n", + "We can patch in specific activations, and can zoom in as far as seems reasonable. For example, if we patch in the output of head L9H9 on the final token, we would predict that it will significantly affect performance. \n", + "\n", + "Note that this technique does *not* tell us how the components of the circuit connect up, just what they are. \n", + "\n", + "
Technical details \n", + "The choice of clean and corrupted prompt has both pros and cons. By carefully setting up the counterfactual, that only differs in the second subject, we avoid detecting the parts of the model doing irrelevant computation like detecting that the indirect object task is relevant at all or that it should be outputting a name rather than an article or pronoun. Or even context like that John and Mary are names at all. \n", + "\n", + "However, it *also* bakes in some details that *are* relevant to the task. Such as finding the location of the second subject, and of the names in the first clause. Or that the name mover heads have learned to copy whatever they look at. \n", + "\n", + "Some of these could be patched by also changing up the order of the names in the original sentence - patching in \"After John and Mary went to the store, John gave a bottle of milk to\" vs \"After Mary and John went to the store, John gave a bottle of milk to\".\n", + "\n", + "In the ROME paper they take a different tack. Rather than carefully setting up counterfactuals between two different but related inputs, they **corrupt** the clean input by adding Gaussian noise to the token embedding for the subject. This is in some ways much lower effort (you don't need to set up a similar but different prompt) but can also introduce some issues, such as ways this noise might break things. In practice, you should take care about how you choose your counterfactuals and try out several. Try to reason beforehand about what they will and will not tell you, and compare the results between different counterfactuals.\n", + "\n", + "I discuss some of these limitations and how the authors solved them with much more refined usage of these techniques in our interview\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Residual Stream" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets begin by patching in the residual stream at the start of each layer and for each token position. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We first create a set of corrupted tokens - where we swap each pair of prompts to have the opposite answer." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Corrupted Average Logit Diff -3.55\n", + "Clean Average Logit Diff 3.55\n" + ] + } + ], + "source": [ + "corrupted_prompts = []\n", + "for i in range(0, len(prompts), 2):\n", + " corrupted_prompts.append(prompts[i + 1])\n", + " corrupted_prompts.append(prompts[i])\n", + "corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)\n", + "corrupted_logits, corrupted_cache = model.run_with_cache(\n", + " corrupted_tokens, return_type=\"logits\"\n", + ")\n", + "corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)\n", + "print(\"Corrupted Average Logit Diff\", round(corrupted_average_logit_diff.item(), 2))\n", + "print(\"Clean Average Logit Diff\", round(original_average_logit_diff.item(), 2))" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['<|endoftext|>When John and Mary went to the shops, Mary gave the bag to',\n", + " '<|endoftext|>When John and Mary went to the shops, John gave the bag to',\n", + " '<|endoftext|>When Tom and James went to the park, Tom gave the ball to',\n", + " '<|endoftext|>When Tom and James went to the park, James gave the ball to',\n", + " '<|endoftext|>When Dan and Sid went to the shops, Dan gave an apple to',\n", + " '<|endoftext|>When Dan and Sid went to the shops, Sid gave an apple to',\n", + " 
'<|endoftext|>After Martin and Amy went to the park, Martin gave a drink to',\n", + " '<|endoftext|>After Martin and Amy went to the park, Amy gave a drink to']" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.to_string(corrupted_tokens)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now intervene on the corrupted run and patch in the clean residual stream at a specific layer and position.\n", + "\n", + "We do the intervention using TransformerLens's `HookPoint` feature. We can design a hook function that takes in a specific activation and returns an edited copy, and temporarily add it in with `model.run_with_hooks`. " + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "def patch_residual_component(\n", + " corrupted_residual_component: Float[torch.Tensor, \"batch pos d_model\"],\n", + " hook,\n", + " pos,\n", + " clean_cache,\n", + "):\n", + " corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]\n", + " return corrupted_residual_component\n", + "\n", + "\n", + "def normalize_patched_logit_diff(patched_logit_diff):\n", + " # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalise\n", + " # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance\n", + " return (patched_logit_diff - corrupted_average_logit_diff) / (\n", + " original_average_logit_diff - corrupted_average_logit_diff\n", + " )\n", + "\n", + "\n", + "patched_residual_stream_diff = torch.zeros(\n", + " model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32\n", + ")\n", + "for layer in range(model.cfg.n_layers):\n", + " for position in range(tokens.shape[1]):\n", + " hook_fn = partial(patch_residual_component, 
pos=position, clean_cache=cache)\n", + " patched_logits = model.run_with_hooks(\n", + " corrupted_tokens,\n", + " fwd_hooks=[(utils.get_act_name(\"resid_pre\", layer), hook_fn)],\n", + " return_type=\"logits\",\n", + " )\n", + " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", + "\n", + " patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(\n", + " patched_logit_diff\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can immediately see that, exactly as predicted, originally all relevant computation happens on the second subject token, and at layers 7 and 8, the information is moved to the final token. Moving the residual stream at the correct position near *exactly* recovers performance!\n", + "\n", + "For reference, tokens and their index from the first prompt are on the x-axis. In an abuse of notation, note that the difference here is averaged over *all* 8 prompts, while the labels only come from the *first* prompt. \n", + "\n", + "To be easier to interpret, we normalise the logit difference, by subtracting the corrupted logit difference, and dividing by the total improvement from clean to corrupted to normalise\n", + "0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "tokens = model.to_tokens(prompts, prepend_bos=True)\n", - "\n", - "# Run the model and cache all activations\n", - "original_logits, cache = model.run_with_cache(tokens)" - ] + "coloraxis": "coloraxis", + "hovertemplate": "Position: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "x": [ + "<|endoftext|>_0", + "When_1", + " John_2", + " and_3", + " Mary_4", + " went_5", + " to_6", + " the_7", + " shops_8", + ",_9", + " John_10", + " gave_11", + " the_12", + " bag_13", + " to_14" + ], + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1.000650405883789, + -0.0002469856117386371, + 9.76665523921838e-06, + -0.00036458822432905436, + -4.8967522161547095e-05 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1.001051902770996, + -2.7621845219982788e-05, + -1.9768245692830533e-05, + -0.0004596704675350338, + -0.0005947590689174831 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1.0002663135528564, + 0.0008680911851115525, + 0.0005157867562957108, + -0.0009929431835189462, + -0.0008658089209347963 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.994907796382904, + 0.005429857410490513, + 0.0016050540143623948, + -0.0006193603039719164, + -0.0016324409516528249 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.9675672054290771, + 0.03134213387966156, + 0.0028418952133506536, + -0.0012302964460104704, + -0.000985861523076892 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.967520534992218, + 0.03100077249109745, + 0.0017823305679485202, + -0.00048668819363228977, + -0.0006467136554419994 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.9228319525718689, + 0.05134531855583191, + 0.004728672094643116, + 0.0009345446596853435, + 0.017046840861439705 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.6565483808517456, + 0.02385685034096241, + 0.002357019344344735, + -1.7183941963594407e-05, + 0.3186916410923004 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.027302566915750504, + 0.03142499923706055, + 0.0018202561186626554, + 0.0007990868762135506, + 0.9383866190910339 + 
], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.026841485872864723, + 0.02098155952990055, + 0.0012512058019638062, + 0.00032317222212441266, + 1.0048279762268066 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.005687985569238663, + 0.014263377524912357, + 0.00048709093243815005, + -8.977938705356792e-05, + 0.9914212226867676 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We'll later be evaluating how model performance differs upon performing various interventions, so it's useful to have a metric to measure model performance. Our metric here will be the **logit difference**, the difference in logit between the indirect object's name and the subject's name (eg, `logit(Mary)-logit(John)`). 
" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Per prompt logit difference: tensor([3.3370, 3.2020, 2.7090, 3.7970, 1.7200, 5.2810, 2.6010, 5.7670])\n", - "Average logit difference: 3.552\n" - ] - } - ], - "source": [ - "def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):\n", - " # Only the final logits are relevant for the answer\n", - " final_logits = logits[:, -1, :]\n", - " answer_logits = final_logits.gather(dim=-1, index=answer_tokens)\n", - " answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]\n", - " if per_prompt:\n", - " return answer_logit_diff\n", - " else:\n", - " return answer_logit_diff.mean()\n", - "\n", - "\n", - "print(\n", - " \"Per prompt logit difference:\",\n", - " logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)\n", - " .detach()\n", - " .cpu()\n", - " .round(decimals=3),\n", - ")\n", - "original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)\n", - "print(\n", - " \"Average logit difference:\",\n", - " round(logits_to_ave_logit_diff(original_logits, answer_tokens).item(), 3),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We see that the average logit difference is 3.5 - for context, this represents putting an $e^{3.5}\\approx 33\\times$ higher probability on the correct answer. " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Brainstorm What's Actually Going On (Optional)\n", - "\n", - "Before diving into running experiments, it's often useful to spend some time actually reasoning about how the behaviour in question could be implemented in the transformer. 
**This is optional, and you'll likely get the most out of engaging with this section if you have a decent understanding already of what a transformer is and how it works!**\n", - "\n", - "You don't have to do this and forming hypotheses after exploration is also reasonable, but I think it's often easier to explore and interpret results with some grounding in what you might find. In this particular case, I'm cheating somewhat, since I know the answer, but I'm trying to simulate the process of reasoning about it!\n", - "\n", - "Note that often your hypothesis will be wrong in some ways and often be completely off. We're doing science here, and the goal is to understand how the model *actually* works, and to form true beliefs! There are two separate traps here at two extremes that it's worth tracking:\n", - "* Confusion: Having no hypotheses at all, getting a lot of data and not knowing what to do with it, and just floundering around\n", - "* Dogmatism: Being overconfident in an incorrect hypothesis and being unwilling to let go of it when reality contradicts you, or flinching away from running the experiments that might disconfirm it.\n", - "\n", - "**Exercise:** Spend some time thinking through how you might imagine this behaviour being implemented in a transformer. Try to think through this for yourself before reading through my thoughts! \n", - "\n", - "
(*) My reasoning\n", - "\n", - "

Brainstorming:

\n", - "\n", - "So, what's hard about the task? Let's focus on the concrete example of the first prompt, \"When John and Mary went to the shops, John gave the bag to\" -> \" Mary\". \n", - "\n", - "A good starting point is thinking though whether a tiny model could do this, eg a 1L Attn-Only model. I'm pretty sure the answer is no! Attention is really good at the primitive operations of looking nearby, or copying information. I can believe a tiny model could figure out that at `to` it should look for names and predict that those names came next (eg the skip trigram \" John...to -> John\"). But it's much harder to tell how many of each previous name there are - attending 0.3 to each copy of John will look exactly the same as attending 0.6 to a single John token. So this will be pretty hard to figure out on the \" to\" token!\n", - "\n", - "The natural place to break this symmetry is on the second \" John\" token - telling whether there is an earlier copy of the current token should be a much easier task. So I might expect there to be a head which detects duplicate tokens on the second \" John\" token, and then another head which moves that information from the second \" John\" token to the \" to\" token. \n", - "\n", - "The model then needs to learn to predict \" Mary\" and not \" John\". I can see two natural ways to do this: \n", - "1. Detect all preceding names and move this information to \" to\" and then delete the any name corresponding to the duplicate token feature. This feels easier done with a non-linearity, since precisely cancelling out vectors is hard, so I'd imagine an MLP layer deletes the \" John\" direction of the residual stream\n", - "2. Have a head which attends to all previous names, but where the duplicate token features inhibit it from attending to specific names. So this only attends to Mary. And then the output of this head maps to the logits. \n", - "\n", - "(Spoiler: It's the second one).\n", - "\n", - "

Experiment Ideas

\n", - "\n", - "A test that could distinguish these two is to look at which components of the model add directly to the logits - if it's mostly attention heads which attend to \" Mary\" and to neither \" John\" it's probably hypothesis 2, if it's mostly MLPs it's probably hypothesis 1.\n", - "\n", - "And we should be able to identify duplicate token heads by finding ones which attend from \" John\" to \" John\", and whose outputs are then moved to the \" to\" token by V-Composition with another head (Spoiler: It's more complicated than that!)\n", - "\n", - "Note that all of the above reasoning is very simplistic and could easily break in a real model! There'll be significant parts of the model that figure out whether to use this circuit at all (we don't want to inhibit duplicated names when, eg, figuring out what goes at the start of the next sentence), and may be parts towards the end of the model that do \"post-processing\" just before the final output. But it's a good starting point for thinking about what's going on." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Direct Logit Attribution" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "*Look up unfamiliar terms in the [mech interp explainer](https://neelnanda.io/glossary)*\n", - "\n", - "Further, the easiest part of the model to understand is the output - this is what the model is trained to optimize, and so it can always be directly interpreted! Often the right approach to reverse engineering a circuit is to start at the end, understand how the model produces the right answer, and to then work backwards. The main technique used to do this is called **direct logit attribution**\n", - "\n", - "**Background:** The central object of a transformer is the **residual stream**. This is the sum of the outputs of each layer and of the original token and positional embedding. 
Importantly, this means that any linear function of the residual stream can be perfectly decomposed into the contribution of each layer of the transformer. Further, each attention layer's output can be broken down into the sum of the output of each head (See [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html) for details), and each MLP layer's output can be broken down into the sum of the output of each neuron (and a bias term for each layer). \n", - "\n", - "The logits of a model are `logits=Unembed(LayerNorm(final_residual_stream))`. The Unembed is a linear map, and LayerNorm is approximately a linear map, so we can decompose the logits into the sum of the contributions of each component, and look at which components contribute the most to the logit of the correct token! This is called **direct logit attribution**. Here we look at the direct attribution to the logit difference!\n", - "\n", - "
(*) Background and motivation of the logit difference\n", - "\n", - "Logit difference is actually a *really* nice and elegant metric and is a particularly nice aspect of the setup of Indirect Object Identification. In general, there are two natural ways to interpret the model's outputs: the output logits, or the output log probabilities (or probabilities). \n", - "\n", - "The logits are much nicer and easier to understand, as noted above. However, the model is trained to optimize the cross-entropy loss (the average of log probability of the correct token). This means it does not directly optimize the logits, and indeed if the model adds an arbitrary constant to every logit, the log probabilities are unchanged. \n", - "\n", - "But `log_probs == logits.log_softmax(dim=-1) == logits - logsumexp(logits)`, and so `log_probs(\" Mary\") - log_probs(\" John\") = logits(\" Mary\") - logits(\" John\")` - the ability to add an arbitrary constant cancels out!\n", - "\n", - "Further, the metric helps us isolate the precise capability we care about - figuring out *which* name is the Indirect Object. There are many other components of the task - deciding whether to return an article (the) or pronoun (her) or name, realising that the sentence wants a person next at all, etc. By taking the logit difference we control for all of that.\n", - "\n", - "Our metric is further refined, because each prompt is repeated twice, for each possible indirect object. This controls for irrelevant behaviour such as the model learning that John is a more frequent token than Mary (this actually happens! The final layernorm bias increases the John logit by 1 relative to the Mary logit)\n", - "\n", - "
\n", - "\n", - "
Ignoring LayerNorm\n", - "\n", - "LayerNorm is an analogous normalization technique to BatchNorm (that's friendlier to massive parallelization) that transformers use. Every time a transformer layer reads information from the residual stream, it applies a LayerNorm to normalize the vector at each position (translating to set the mean to 0 and scaling to set the variance to 1) and then applying a learned vector of weights and biases to scale and translate the normalized vector. This is *almost* a linear map, apart from the scaling step, because that divides by the norm of the vector and the norm is not a linear function. (The `fold_ln` flag when loading a model factors out all the linear parts).\n", - "\n", - "But if we fixed the scale factor, the LayerNorm would be fully linear. And the scale of the residual stream is a global property that's a function of *all* components of the stream, while in practice there is normally just a few directions relevant to any particular component, so in practice this is an acceptable approximation. So when doing direct logit attribution we use the `apply_ln` flag on the `cache` to apply the global layernorm scaling factor to each constant. See [my clean GPT-2 implementation](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/clean-transformer-demo/Clean_Transformer_Demo.ipynb#scrollTo=Clean_Transformer_Implementation) for more on LayerNorm.\n", - "
" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Getting an output logit is equivalent to projecting onto a direction in the residual stream. We use `model.tokens_to_residual_directions` to map the answer tokens to that direction, and then convert this to a logit difference direction for each batch" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Answer residual directions shape: torch.Size([8, 2, 768])\n", - "Logit difference directions shape: torch.Size([8, 768])\n" - ] - } - ], - "source": [ - "answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)\n", - "print(\"Answer residual directions shape:\", answer_residual_directions.shape)\n", - "logit_diff_directions = (\n", - " answer_residual_directions[:, 0] - answer_residual_directions[:, 1]\n", - ")\n", - "print(\"Logit difference directions shape:\", logit_diff_directions.shape)" + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + 
"colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 
0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } 
+ }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To verify that this works, we can apply this to the final residual stream for our cached prompts (after applying LayerNorm scaling) and verify that we get the same answer. 
\n", - "\n", - "
Technical details\n", - "\n", - "`logits = Unembed(LayerNorm(final_residual_stream))`, so we technically need to account for the centering, and then learned translation and scaling of the layernorm, not just the variance 1 scaling. \n", - "\n", - "The centering is accounted for with the preprocessing flag `center_writing_weights` which ensures that every weight matrix writing to the residual stream has mean zero. \n", - "\n", - "The learned scaling is folded into the unembedding weights `model.unembed.W_U` via `W_U_fold = layer_norm.weights[:, None] * unembed.W_U`\n", - "\n", - "The learned translation is folded to `model.unembed.b_U`, a bias added to the logits (note that GPT-2 is not trained with an existing `b_U`). This roughly represents unigram statistics. But we can ignore this because each prompt occurs twice with names in the opposite order, so this perfectly cancels out. \n", - "\n", - "Note that rather than using layernorm scaling we could just study cache[\"ln_final.hook_normalised\"]\n", - "\n", - "
" + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Final residual stream shape: torch.Size([8, 15, 768])\n", - "Calculated average logit diff: 3.552\n", - "Original logit difference: 3.552\n" - ] - } - ], - "source": [ - "# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].\n", - "final_residual_stream = cache[\"resid_post\", -1]\n", - "print(\"Final residual stream shape:\", final_residual_stream.shape)\n", - "final_token_residual_stream = final_residual_stream[:, -1, :]\n", - "# Apply LayerNorm scaling\n", - "# pos_slice is the subset of the positions we take - here the final token of each prompt\n", - "scaled_final_token_residual_stream = cache.apply_ln_to_stack(\n", - " final_token_residual_stream, layer=-1, pos_slice=-1\n", - ")\n", - "\n", - "average_logit_diff = einsum(\n", - " \"batch d_model, batch d_model -> \",\n", - " scaled_final_token_residual_stream,\n", - " logit_diff_directions,\n", - ") / len(prompts)\n", - "print(\"Calculated average logit diff:\", round(average_logit_diff.item(), 3))\n", - "print(\"Original logit difference:\", round(original_average_logit_diff.item(), 3))" + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 
0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + 
"zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Logit Lens" - ] + "title": { + "text": "Logit Difference From Patched Residual Stream" }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can now decompose the residual stream! First we apply a technique called the [**logit lens**](https://www.alignmentforum.org/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens) - this looks at the residual stream after each layer and calculates the logit difference from that. This simulates what happens if we delete all subsequence layers. " - ] + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Position" + } }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "prompt_position_labels = [\n", + " f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(tokens[0]))\n", + "]\n", + "imshow(\n", + " patched_residual_stream_diff,\n", + " x=prompt_position_labels,\n", + " title=\"Logit Difference From Patched Residual Stream\",\n", + " labels={\"x\": \"Position\", \"y\": \"Layer\"},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Layers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can apply exactly the same idea, but this time patching in attention or MLP layers. These are also residual components with identical shapes to the residual stream terms, so we can reuse the same hooks." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "patched_attn_diff = torch.zeros(\n", + " model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32\n", + ")\n", + "patched_mlp_diff = torch.zeros(\n", + " model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32\n", + ")\n", + "for layer in range(model.cfg.n_layers):\n", + " for position in range(tokens.shape[1]):\n", + " hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)\n", + " patched_attn_logits = model.run_with_hooks(\n", + " corrupted_tokens,\n", + " fwd_hooks=[(utils.get_act_name(\"attn_out\", layer), hook_fn)],\n", + " return_type=\"logits\",\n", + " )\n", + " patched_attn_logit_diff = logits_to_ave_logit_diff(\n", + " patched_attn_logits, answer_tokens\n", + " )\n", + " patched_mlp_logits = model.run_with_hooks(\n", + " corrupted_tokens,\n", + " fwd_hooks=[(utils.get_act_name(\"mlp_out\", layer), hook_fn)],\n", + " return_type=\"logits\",\n", + " )\n", + " patched_mlp_logit_diff = logits_to_ave_logit_diff(\n", + " patched_mlp_logits, answer_tokens\n", + " )\n", + "\n", + " patched_attn_diff[layer, position] = normalize_patched_logit_diff(\n", + " patched_attn_logit_diff\n", + " )\n", + " patched_mlp_diff[layer, position] = normalize_patched_logit_diff(\n", + " patched_mlp_logit_diff\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that several attention layers are significant but that, matching the residual stream results, early layers matter on the second subject token, and later layers matter on the final token, and layers essentially don't matter on any other token. Extremely localised! As with direct logit attribution, layer 9 is positive and layers 10 and 11 are not, suggesting that the late layers only matter for direct logit effects, but we also see that layers 7 and 8 matter significantly. 
Presumably these are the heads that move information about which name is duplicated from the second subject token to the final token." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "def residual_stack_to_logit_diff(\n", - " residual_stack: Float[torch.Tensor, \"components batch d_model\"],\n", - " cache: ActivationCache,\n", - ") -> float:\n", - " scaled_residual_stack = cache.apply_ln_to_stack(\n", - " residual_stack, layer=-1, pos_slice=-1\n", - " )\n", - " return einsum(\n", - " \"... batch d_model, batch d_model -> ...\",\n", - " scaled_residual_stack,\n", - " logit_diff_directions,\n", - " ) / len(prompts)" - ] + "coloraxis": "coloraxis", + "hovertemplate": "Position: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "x": [ + "<|endoftext|>_0", + "When_1", + " John_2", + " and_3", + " Mary_4", + " went_5", + " to_6", + " the_7", + " shops_8", + ",_9", + " John_10", + " gave_11", + " the_12", + " bag_13", + " to_14" + ], + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.035456884652376175, + -0.0002469856117386371, + 9.76665523921838e-06, + -0.00036458822432905436, + -4.8967522161547095e-05 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0029848709236830473, + 7.950929284561425e-05, + 2.0842242520302534e-05, + 8.088535105343908e-05, + -0.0005967392353340983 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0019131568260490894, + 0.0006668510613963008, + 0.00039482791908085346, + -0.0007051457650959492, + -0.00027282864903099835 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.1546323299407959, + 0.0038019807543605566, + 0.0005171628436073661, + -0.00011964991426793858, + -0.0005599213181994855 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.005406397394835949, + 0.019581740722060204, + 0.001007509301416576, + -0.0002424211270408705, + 0.0007936497568152845 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.3520970046520233, + 0.0010525835677981377, + 0.00022436455765273422, + 0.00013367898645810783, + 8.172441448550671e-05 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.11986024677753448, + 0.021243548020720482, + 0.002727783052250743, + 0.0013409851817414165, + 0.01797366514801979 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.013310473412275314, + 0.011509180068969727, + 0.00037542887730523944, + -4.094611358596012e-05, + 0.29760244488716125 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0015009435592219234, + 0.017351653426885605, + 0.0005848917062394321, + 0.0010122752282768488, + 0.5697318911552429 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, 
+ -0.00012901381705887616, + 0.00630143890157342, + 0.00014156615361571312, + 0.00031229801243171096, + 0.27152299880981445 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0009373303619213402, + 8.669164526509121e-05, + 0.00033243544748984277, + 9.73309283835988e-07, + -0.1929796040058136 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.40617984533309937 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Fascinatingly, we see that the model is utterly unable to do the task until layer 7, almost all performance comes from attention layer 9, and performance actually *decreases* from there.\n", - "\n", - "**Note:** Hover over each data point to see what residual stream position it's from!\n", - "\n", - "
Details on `accumulated_resid`\n", - "**Key:** `n_pre` means the residual stream at the start of layer n, `n_mid` means the residual stream after the attention part of layer n (`n_post` is the same as `n+1_pre` so is not included)\n", - "\n", - "* `layer` is the layer for which we input the residual stream (this is used to identify *which* layer norm scaling factor we want)\n", - "* `incl_mid` is whether to include the residual stream in the middle of a layer, ie after attention & before MLP\n", - "* `pos_slice` is the subset of the positions used. See `utils.Slice` for details on the syntax.\n", - "* return_labels is whether to return the labels for each component returned (useful for plotting)\n", - "
" + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 
0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + 
"parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": 
{ + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "hovertemplate": "%{hovertext}

x=%{x}
y=%{y}", - "hovertext": [ - "0_pre", - "0_mid", - "1_pre", - "1_mid", - "2_pre", - "2_mid", - "3_pre", - "3_mid", - "4_pre", - "4_mid", - "5_pre", - "5_mid", - "6_pre", - "6_mid", - "7_pre", - "7_mid", - "8_pre", - "8_mid", - "9_pre", - "9_mid", - "10_pre", - "10_mid", - "11_pre", - "11_mid", - "final_post" - ], - "legendgroup": "", - "line": { - "color": "#636efa", - "dash": "solid" - }, - "marker": { - "symbol": "circle" - }, - "mode": "lines", - "name": "", - "orientation": "v", - "showlegend": false, - "type": "scatter", - "x": [ - 0, - 0.5, - 1, - 1.5, - 2, - 2.5, - 3, - 3.5, - 4, - 4.5, - 5, - 5.5, - 6, - 6.5, - 7, - 7.5, - 8, - 8.5, - 9, - 9.5, - 10, - 10.5, - 11, - 11.5, - 12 - ], - "xaxis": "x", - "y": [ - 0.000012937933206558228, - -0.006643360480666161, - -0.007525032386183739, - -0.009075596928596497, - -0.008736769668757915, - -0.008685456588864326, - -0.006480347365140915, - -0.007939882576465607, - -0.009661720134317875, - -0.015095856040716171, - -0.01419061329215765, - -0.019930001348257065, - -0.00912435818463564, - -0.027298055589199066, - -0.02985510788857937, - 0.2497255504131317, - 0.250558078289032, - 0.45005205273628235, - 0.45996904373168945, - 5.02545166015625, - 5.142900466918945, - 4.730565071105957, - 4.887058258056641, - 3.445383071899414, - 3.5518720149993896 - ], - "yaxis": "y" - } - ], - "layout": { - "legend": { - "tracegroupgap": 0 - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", 
- "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - 
"#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": 
"scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - 
"#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", 
- "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Logit Difference From Accumulate Residual Stream" - }, - "xaxis": { - "anchor": "y", - "domain": [ - 0, - 1 - ], - "title": { - "text": "x" - } - }, - "yaxis": { - "anchor": "x", - "domain": [ - 0, - 1 - ], - "title": { - "text": "y" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "accumulated_residual, labels = cache.accumulated_resid(\n", - " layer=-1, incl_mid=True, pos_slice=-1, return_labels=True\n", - ")\n", - "logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)\n", - "line(\n", - " logit_lens_logit_diffs,\n", - " x=np.arange(model.cfg.n_layers * 2 + 1) / 2,\n", - " hover_name=labels,\n", - " title=\"Logit Difference From Accumulate Residual Stream\",\n", - ")" + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 
0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Layer Attribution" + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + 
"ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can repeat the above analysis but for each layer (this is equivalent to the differences between adjacent residual streams)\n", - "\n", - "Note: Annoying terminology overload - layer k of a transformer means the kth **transformer block**, but each block consists of an **attention layer** (to move information around) *and* an **MLP layer** (to process information). \n", - "\n", - "We see that only attention layers matter, which makes sense! The IOI task is about moving information around (ie moving the correct name and not the incorrect name), and less about processing it. And again we note that attention layer 9 improves things a lot, while attention 10 and attention 11 *decrease* performance" - ] + "title": { + "text": "Logit Difference From Patched Attention Layer" }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "hovertemplate": "%{hovertext}

x=%{x}
y=%{y}", - "hovertext": [ - "embed", - "pos_embed", - "0_attn_out", - "0_mlp_out", - "1_attn_out", - "1_mlp_out", - "2_attn_out", - "2_mlp_out", - "3_attn_out", - "3_mlp_out", - "4_attn_out", - "4_mlp_out", - "5_attn_out", - "5_mlp_out", - "6_attn_out", - "6_mlp_out", - "7_attn_out", - "7_mlp_out", - "8_attn_out", - "8_mlp_out", - "9_attn_out", - "9_mlp_out", - "10_attn_out", - "10_mlp_out", - "11_attn_out", - "11_mlp_out" - ], - "legendgroup": "", - "line": { - "color": "#636efa", - "dash": "solid" - }, - "marker": { - "symbol": "circle" - }, - "mode": "lines", - "name": "", - "orientation": "v", - "showlegend": false, - "type": "scatter", - "x": [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 25 - ], - "xaxis": "x", - "y": [ - -0.00028366726473905146, - 0.00029660604195669293, - -0.0066563040018081665, - -0.0008816685294732451, - -0.0015505650080740452, - 0.00033882574643939734, - 0.00005131529178470373, - 0.0022051138803362846, - -0.0014595506945624948, - -0.0017218313878402114, - -0.005434143822640181, - 0.0009052485693246126, - -0.0057394010946154594, - 0.010805649682879448, - -0.018173698335886, - -0.002557049971073866, - 0.27958065271377563, - 0.0008325176313519478, - 0.19949400424957275, - 0.00991708692163229, - 4.565483093261719, - 0.11744903028011322, - -0.4123360514640808, - 0.15649384260177612, - -1.4416757822036743, - 0.10648896545171738 - ], - "yaxis": "y" - } - ], - "layout": { - "legend": { - "tracegroupgap": 0 - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - 
"solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - 
"#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": 
"" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], 
- [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - 
"ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Logit Difference From Each Layer" - }, - "xaxis": { - "anchor": "y", - "domain": [ - 0, - 1 - ], - "title": { - "text": "x" - } - }, - "yaxis": { - "anchor": "x", - "domain": [ - 0, - 1 - ], - "title": { - "text": "y" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "per_layer_residual, labels = cache.decompose_resid(\n", - " layer=-1, pos_slice=-1, return_labels=True\n", - ")\n", - "per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)\n", - "line(per_layer_logit_diffs, hover_name=labels, title=\"Logit Difference From Each Layer\")" - ] + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Position" + } }, + "yaxis": { + "anchor": "x", + 
"autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow(\n", + " patched_attn_diff,\n", + " x=prompt_position_labels,\n", + " title=\"Logit Difference From Patched Attention Layer\",\n", + " labels={\"x\": \"Position\", \"y\": \"Layer\"},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In contrast, the MLP layers do not matter much. This makes sense, since this is more a task about moving information than about processing it, and the MLP layers specialise in processing information.\n", + "\n", + "The one exception is MLP 0, which matters a lot, but I think this is misleading and just a generally true statement about MLP 0 rather than being about the circuit on this task.\n", + "\n", + "
My takes on MLP0 \n", + "It's often observed on GPT-2 Small that MLP0 matters a lot, and that ablating it utterly destroys performance. My current best guess is that the first MLP layer is essentially acting as an extension of the embedding (for whatever reason) and that when later layers want to access the input tokens they mostly read in the output of the first MLP layer, rather than the token embeddings. Within this frame, the first attention layer doesn't do much. \n", + "\n", + "In this framing, it makes sense that MLP0 matters on the second subject token, because that's the one position with a different input token!\n", + "\n", + "I'm not entirely sure why this happens, but I would guess that it's because the embedding and unembedding matrices in GPT-2 Small are the same. This is pretty unprincipled, as the tasks of embedding and unembedding tokens are not inverses, but this is common practice, and plausibly models want to dedicate some parameters to overcoming this. \n", + "\n", + "I only have suggestive evidence of this, and would love to see someone look into this properly!\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Head Attribution" - ] + "coloraxis": "coloraxis", + "hovertemplate": "Position: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "x": [ + "<|endoftext|>_0", + "When_1", + " John_2", + " and_3", + " Mary_4", + " went_5", + " to_6", + " the_7", + " shops_8", + ",_9", + " John_10", + " gave_11", + " the_12", + " bag_13", + " to_14" + ], + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.8507890701293945, + -0.00027843358111567795, + -7.293107046280056e-05, + -0.00047373308916576207, + 4.0039929444901645e-05 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.008863994851708412, + 0.000222149450564757, + 0.00014938619278836995, + -4.853121208725497e-05, + 0.000304041663184762 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.013550343923270702, + 5.86334899708163e-05, + -0.0003296833310741931, + -0.0006382559076882899, + 0.0007730424986220896 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.0019468198297545314, + 0.0004995090421289206, + 0.00017318192112725228, + 0.00016871812113095075, + 0.00040764876757748425 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.019787074998021126, + 0.004128609783947468, + -4.86990247736685e-05, + -0.00017019486404024065, + 0.0007914346642792225 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.09652391821146011, + -0.0018826150335371494, + -0.0004844730719923973, + 0.0007094081956893206, + -0.00018335132335778326 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.015900013968348503, + -0.0008501688134856522, + 0.00012337534280959517, + 2.7521158699528314e-05, + -0.007238299585878849 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.010360540822148323, + 0.0031509376130998135, + 0.0005309234256856143, + 0.0002361114020459354, + 0.008496351540088654 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.012533102184534073, + 2.201692586822901e-05, + -0.00035374757135286927, + 8.615465048933402e-05, + -0.021631328389048576 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 
0, + 0, + 0, + -0.00033465056912973523, + 0.0008094912045635283, + 1.6244195649051107e-05, + 0.00012924875773023814, + 0.03162466362118721 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.0013599144294857979, + -0.00019499746849760413, + -9.934466652339324e-05, + -0.00014217027637641877, + 0.028764141723513603 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.02044912613928318 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can further break down the output of each attention layer into the sum of the outputs of each attention head. Each attention layer consists of 12 heads, which each act independently and additively.\n", - "\n", - "
Decomposing attention output into sums of heads \n", - "The standard way to compute the output of an attention layer is by concatenating the mixed values of each head, and multiplying by a big output weight matrix. But as described in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) this is equivalent to splitting the output weight matrix into a per-head output (here `model.blocks[k].attn.W_O`) and adding them up (including an overall bias term for the entire layer)\n", - "
\n", - "\n", - "We see that only a few heads really matter - heads L9H6 and L9H9 contribute a lot positively (explaining why attention layer 9 is so important), while heads L10H7 and L11H10 contribute a lot negatively (explaining why attention layer 10 and layer 11 are actively harmful). These correspond to (some of) the name movers and negative name movers discussed in the paper. There are also several heads that matter positively or negatively but less strongly (other name movers and backup name movers)\n", - "\n", - "There are a few meta observations worth making here - our model has 144 heads, yet we could localise this behaviour to a handful of specific heads, using straightforward, general techniques. This supports the claim in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) that attention heads are the right level of abstraction to understand attention. It also really surprising that there are *negative* heads - eg L10H7 makes the incorrect logit 7x *more* likely. I'm not sure what's going on there, though the paper discusses some possibilities." 
+ "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 
0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + 
"parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": 
{ + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tried to stack head results when they weren't cached. Computing head results now\n" - ] - }, - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - -0.0020563392899930477, - -0.0005101899732835591, - 0.0004685786843765527, - 0.00012512074317783117, - -0.0006028738571330905, - -0.0002429460291750729, - -0.0023189077619463205, - -0.002758360467851162, - 0.000564602785743773, - 0.0009697531932033598, - -0.0002504526637494564, - 0.000004737317794933915 - ], - [ - -0.0010070882271975279, - 0.00039470894262194633, - -0.00154874159488827, - 0.0014034928753972054, - -0.0012653048615902662, - -0.0011358022456988692, - -0.00281596090644598, - -0.0029645217582583427, - 0.0029190476052463055, - 0.0025743592996150255, - 0.00036239007022231817, - 0.0017548729665577412 - ], - [ - 0.0005569400964304805, - -0.001126631861552596, - -0.0017353934235870838, - -0.0014514457434415817, - -0.00028735760133713484, - 0.0017211002996191382, - 0.0026658899150788784, - 0.00311466702260077, - 0.0005667927907779813, - -0.003666515462100506, - -0.0018847601022571325, - 0.000007039372576400638 - ], - [ - -0.0007264417363330722, - 0.00011364505917299539, - 0.0014301587361842394, - 0.0007490540738217533, - 0.0020184689201414585, - 0.0007436950691044331, - -0.00046178390039131045, - -0.0039057559333741665, - 0.0011406694538891315, - -0.00004022853681817651, - -0.0013293239753693342, - -0.0017636751290410757 - ], - [ - -0.0028280913829803467, - 0.00033634810824878514, - -0.0014248639345169067, - -0.003777273464947939, - 0.0015998880844563246, - 0.0002989505883306265, - -0.000804675742983818, - 0.002038792008534074, - -0.0015593919670209289, - -0.0006436670082621276, - 0.0011168173514306545, - -0.00035012533771805465 - ], - [ - 0.0011338205076754093, - 0.0011259170714765787, - -0.002516670385375619, - -0.0014790185960009694, - 0.0003878737334161997, - -0.00006408110493794084, - -0.0005096744280308485, - -0.0008840755908749998, - 0.0006398351397365332, - -0.0010097370250150561, - -0.006759158335626125, - 0.0033667823299765587 - ], - [ - 
-0.01514742337167263, - -0.0021350777242332697, - 0.002593174111098051, - -0.00042678468162193894, - -0.005558924749493599, - 0.0026658528950065374, - 0.006411008536815643, - -0.003826778382062912, - -0.0003843410813715309, - -0.0016430341638624668, - -0.0013344454346224666, - -0.0000920506427064538 - ], - [ - -0.00009476230479776859, - -0.0057889921590685844, - -0.0006383581785485148, - 0.13493388891220093, - -0.001768707763403654, - -0.018917907029390335, - 0.003873429261147976, - -0.0021450775675475597, - -0.010327338241040707, - 0.18325845897197723, - -0.0007747983909212053, - -0.00104526337236166 - ], - [ - -0.003833949100226164, - -0.0008046097937040031, - -0.012673400342464447, - 0.00804573018103838, - 0.003604492638260126, - -0.009398287162184715, - -0.08272082358598709, - 0.003555194940418005, - -0.018404025584459305, - 0.0017587244510650635, - 0.2896133363246918, - 0.022854052484035492 - ], - [ - 0.08595258742570877, - -0.0006932877004146576, - 0.06817055493593216, - 0.013111240230500698, - -0.021098043769598007, - 0.05112447217106819, - 1.3844914436340332, - 0.045836858451366425, - -0.03830280900001526, - 2.985445976257324, - 0.0019662054255604744, - -0.008030137047171593 - ], - [ - 0.5608693957328796, - 0.17083050310611725, - -0.03361757844686508, - 0.05821544677019119, - -0.0024530249647796154, - 0.0018771197646856308, - 0.28827205300331116, - -1.8986485004425049, - -0.0015286931302398443, - -0.035129792988300323, - 0.4802178740501404, - -0.0009115453576669097 - ], - [ - 0.016075748950242996, - -0.03986122086644173, - -0.3879126012325287, - 0.011123123578727245, - -0.005477819126099348, - -0.0025129620917141438, - -0.08056175708770752, - 0.007518616039305925, - 0.0430111438035965, - -0.040082238614559174, - -0.9702364802360535, - 0.011862239800393581 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - 
"rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - 
], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - 
[ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 
0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - 
"#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Logit Difference From Each Head" - }, - "xaxis": { - "anchor": 
"y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } - }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "per_head_residual, labels = cache.stack_head_results(\n", - " layer=-1, pos_slice=-1, return_labels=True\n", - ")\n", - "per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)\n", - "per_head_logit_diffs = einops.rearrange(\n", - " per_head_logit_diffs,\n", - " \"(layer head_index) -> layer head_index\",\n", - " layer=model.cfg.n_layers,\n", - " head_index=model.cfg.n_heads,\n", - ")\n", - "imshow(\n", - " per_head_logit_diffs,\n", - " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", - " title=\"Logit Difference From Each Head\",\n", - ")" + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Attention Analysis" + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + 
"#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Attention heads are 
particularly easy to study because we can look directly at their attention patterns and study from what positions they move information from and two. This is particularly easy here as we're looking at the direct effect on the logits so we need only look at the attention patterns from the final token. \n", - "\n", - "We use Alan Cooney's circuitsvis library to visualize the attention patterns! We visualize the top 3 positive and negative heads by direct logit attribution, and show these for the first prompt (as an illustration).\n", - "\n", - "
Interpreting Attention Patterns \n", - "An easy mistake to make when looking at attention patterns is thinking that they must convey information about the token looked at (maybe accounting for the context of the token). But actually, all we can confidently say is that it moves information from the *residual stream position* corresponding to that input token. Especially later on in the model, there may be components in the residual stream that are nothing to do with the input token! Eg the period at the end of a sentence may contain summary information for that sentence, and the head may solely move that, rather than caring about whether it ends in \".\", \"!\" or \"?\"\n", - "
" - ] + "title": { + "text": "Logit Difference From Patched MLP Layer" }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "def visualize_attention_patterns(\n", - " heads: Union[List[int], int, Float[torch.Tensor, \"heads\"]],\n", - " local_cache: ActivationCache,\n", - " local_tokens: torch.Tensor,\n", - " title: Optional[str] = \"\",\n", - " max_width: Optional[int] = 700,\n", - ") -> str:\n", - " # If a single head is given, convert to a list\n", - " if isinstance(heads, int):\n", - " heads = [heads]\n", - "\n", - " # Create the plotting data\n", - " labels: List[str] = []\n", - " patterns: List[Float[torch.Tensor, \"dest_pos src_pos\"]] = []\n", - "\n", - " # Assume we have a single batch item\n", - " batch_index = 0\n", - "\n", - " for head in heads:\n", - " # Set the label\n", - " layer = head // model.cfg.n_heads\n", - " head_index = head % model.cfg.n_heads\n", - " labels.append(f\"L{layer}H{head_index}\")\n", - "\n", - " # Get the attention patterns for the head\n", - " # Attention patterns have shape [batch, head_index, query_pos, key_pos]\n", - " patterns.append(local_cache[\"attn\", layer][batch_index, head_index])\n", - "\n", - " # Convert the tokens to strings (for the axis labels)\n", - " str_tokens = model.to_str_tokens(local_tokens)\n", - "\n", - " # Combine the patterns into a single tensor\n", - " patterns: Float[torch.Tensor, \"head_index dest_pos src_pos\"] = torch.stack(\n", - " patterns, dim=0\n", - " )\n", - "\n", - " # Circuitsvis Plot (note we get the code version so we can concatenate with the title)\n", - " plot = attention_heads(\n", - " attention=patterns, tokens=str_tokens, attention_head_names=labels\n", - " ).show_code()\n", - "\n", - " # Display the title\n", - " title_html = f\"

{title}


\"\n", - "\n", - " # Return the visualisation as raw code\n", - " return f\"
{title_html + plot}
\"" - ] + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Position" + } }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow(\n", + " patched_mlp_diff,\n", + " x=prompt_position_labels,\n", + " title=\"Logit Difference From Patched MLP Layer\",\n", + " labels={\"x\": \"Position\", \"y\": \"Layer\"},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Heads\n", + "\n", + "We can refine the above analysis by patching in individual heads! This is somewhat more annoying, because there are now three dimensions (head_index, position and layer), so for now let's patch in a head's output across all positions.\n", + "\n", + "The easiest way to do this is to patch in the activation `z`, the \"mixed value\" of the attention head. That is, the average of all previous values weighted by the attention pattern, ie the activation that is then multiplied by `W_O`, the output weights. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "def patch_head_vector(\n", + " corrupted_head_vector: Float[torch.Tensor, \"batch pos head_index d_head\"],\n", + " hook,\n", + " head_index,\n", + " clean_cache,\n", + "):\n", + " corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][\n", + " :, :, head_index, :\n", + " ]\n", + " return corrupted_head_vector\n", + "\n", + "\n", + "patched_head_z_diff = torch.zeros(\n", + " model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32\n", + ")\n", + "for layer in range(model.cfg.n_layers):\n", + " for head_index in range(model.cfg.n_heads):\n", + " hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)\n", + " patched_logits = model.run_with_hooks(\n", + " corrupted_tokens,\n", + " fwd_hooks=[(utils.get_act_name(\"z\", layer, \"attn\"), hook_fn)],\n", + " return_type=\"logits\",\n", + " )\n", + " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", + "\n", + " patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(\n", + " patched_logit_diff\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now see that, in addition to the name mover heads identified before, in mid-late layers the heads L8H6, L8H10, L7H9 matter and are presumably responsible for moving information from the second subject to the final token. And heads L5H5, L6H9, L3H0 also matter a lot, and are presumably involved in detecting duplicated tokens." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Inspecting the patterns, we can see that both types of name movers attend to the indirect object - this suggests they're simply copying the name attended to (with the OV circuit) and that the interesting part is the circuit behind the attention pattern that calculates *where* to move information from (the QK circuit)" - ] + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.0009487751522101462, + 0.016124747693538666, + 0.0018548924708738923, + 0.0034389030188322067, + -0.00982347596436739, + 0.011058605276048183, + -0.004063969012349844, + -0.0015792781487107277, + -0.0012082795146852732, + 0.003828897839412093, + -0.004256919026374817, + -0.0011422622483223677 + ], + [ + -0.0010771177476271987, + -0.00037898647133260965, + 2.5171791548928013e-06, + -0.00026067905128002167, + -0.00014146546891424805, + 0.0038321535103023052, + -0.0004293300735298544, + -0.00142992555629462, + -0.0009228314156644046, + 0.0006944393389858305, + 0.00043302192352712154, + -0.0035714071709662676 + ], + [ + -0.0004967569257132709, + 0.0008057993836700916, + 0.0005424688570201397, + -0.0005309234256856143, + -0.0007159864180721343, + -0.0010389237431809306, + -0.0009490771917626262, + -8.649027586216107e-05, + 0.0002766547549981624, + 0.0021084228064864874, + -0.0001975146442418918, + -0.0016405630158260465 + ], + [ + 0.1162627637386322, + 0.0002507446042727679, + -0.0014675153652206063, + -0.00039680811460129917, + 0.018962211906909943, + -0.00018764731066767126, + 0.011170871555805206, + -0.0013301445869728923, + -0.0007356539717875421, + -0.00030253134900704026, + -0.00014683544577565044, + -0.00022228369198273867 + ], + [ + -0.001650598249398172, + 0.0002927311579696834, + -0.00143563118763268, + 0.03084198758006096, + -0.007432155776768923, + -0.00028236035723239183, + 0.006017433945089579, + -0.011007187888026237, + -0.001266107545234263, + 0.0014901700196787715, + -0.0001800622121663764, + 0.002944394713267684 + ], + [ + -0.004211106337606907, + 0.0029597999528050423, + 0.002045023487880826, + 0.0013397098518908024, + -0.0012190865818411112, + 0.34349915385246277, + 0.0005632104002870619, + -0.0001262281439267099, + -0.00515326950699091, + 0.016240738332271576, + 0.01709030382335186, + -0.004175194539129734 + ], + [ + 0.039775289595127106, + 
0.015226684510707855, + -0.0010229480685666203, + 0.0008072761120274663, + -0.004935584031045437, + -0.002123525831848383, + -0.014274083077907562, + 0.0013746818294748664, + 0.0014838266652077436, + 0.1302703619003296, + -0.00033616088330745697, + 0.0012919505825266242 + ], + [ + 0.00037177055492065847, + 0.019514480605721474, + 0.00022255218937061727, + 0.124249167740345, + -0.00040352059295400977, + -0.007652895525097847, + 0.0013010123511776328, + -0.0011253133416175842, + -0.007449474185705185, + 0.19224143028259277, + -0.003275118535384536, + -0.0005017912480980158 + ], + [ + -0.001007912098430097, + 3.091096004936844e-05, + -0.0008595998515374959, + 0.012359987013041973, + -0.0004041247011628002, + -0.004328910261392593, + 0.3185553252696991, + 0.002330605871975422, + 0.0021182901691645384, + 0.0001405928487656638, + 0.2779357433319092, + 0.005738262087106705 + ], + [ + 0.0058898297138512135, + -0.0009689796715974808, + 0.00912561360746622, + 0.020675739273428917, + -0.03700518235564232, + 0.014263041317462921, + -0.04828466475009918, + 0.05834139883518219, + 0.0006514795240946114, + 0.26360899209976196, + 0.0004918567719869316, + -0.00261044898070395 + ], + [ + 0.08374208211898804, + 0.020676210522651672, + -0.003743582172319293, + 0.01085072010755539, + -0.001096583902835846, + 0.00047430366976186633, + 0.04818058758974075, + -0.4799128472805023, + 0.00018429107149131596, + 0.011861988343298435, + 0.06088569387793541, + 0.0008461413672193885 + ], + [ + 0.005328264087438583, + -0.011493473313748837, + -0.11350836604833603, + 0.006329597905278206, + 0.00031669469899497926, + -0.0011600167490541935, + -0.022669579833745956, + 0.004070379305630922, + 0.0073160636238753796, + -0.00834545586258173, + -0.27817651629447937, + 0.0036344374530017376 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + 
[ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "

Top 3 Positive Logit Attribution Heads


\n", - "

Top 3 Negative Logit Attribution Heads


\n", - "
" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "top_k = 3\n", - "\n", - "top_positive_logit_attr_heads = torch.topk(\n", - " per_head_logit_diffs.flatten(), k=top_k\n", - ").indices\n", - "\n", - "positive_html = visualize_attention_patterns(\n", - " top_positive_logit_attr_heads,\n", - " cache,\n", - " tokens[0],\n", - " f\"Top {top_k} Positive Logit Attribution Heads\",\n", - ")\n", - "\n", - "top_negative_logit_attr_heads = torch.topk(\n", - " -per_head_logit_diffs.flatten(), k=top_k\n", - ").indices\n", - "\n", - "negative_html = visualize_attention_patterns(\n", - " top_negative_logit_attr_heads,\n", - " cache,\n", - " tokens[0],\n", - " title=f\"Top {top_k} Negative Logit Attribution Heads\",\n", - ")\n", - "\n", - "HTML(positive_html + negative_html)" + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 
0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + 
], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" 
+ } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Activation Patching" + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 
0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**This section explains how to do activation patching conceptually by implementing it from scratch. To use it in practice with TransformerLens, see [this demonstration instead](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Activation_Patching_in_TL_Demo.ipynb)**.\n", - "\n", - "The obvious limitation to the techniques used above is that they only look at the very end of the circuit - the parts that directly affect the logits. Clearly this is not sufficient to understand the circuit! We want to understand how things compose together to produce this final output, and ideally to produce an end-to-end circuit fully explaining this behaviour. \n", - "\n", - "The technique we'll use to investigate this is called **activation patching**. This was first introduced in [David Bau and Kevin Meng's excellent ROME paper](https://rome.baulab.info/), there called causal tracing. \n", - "\n", - "The setup of activation patching is to take two runs of the model on two different inputs, the clean run and the corrupted run. The clean run outputs the correct answer and the corrupted run does not. The key idea is that we give the model the corrupted input, but then **intervene** on a specific activation and **patch** in the corresponding activation from the clean run (ie replace the corrupted activation with the clean activation), and then continue the run. And we then measure how much the output has updated towards the correct answer. \n", - "\n", - "We can then iterate over many possible activations and look at how much they affect the corrupted run. If patching in an activation significantly increases the probability of the correct answer, this allows us to *localise* which activations matter. 
\n", - "\n", - "The ability to localise is a key move in mechanistic interpretability - if the computation is diffuse and spread across the entire model, it is likely much harder to form a clean mechanistic story for what's going on. But if we can identify precisely which parts of the model matter, we can then zoom in and determine what they represent and how they connect up with each other, and ultimately reverse engineer the underlying circuit that they represent. \n", - "\n", - "Here's an animation from the ROME paper demonstrating this technique (they studied factual recall, and use stars to represent corruption applied to the subject of the sentence, but the same principles apply):\n", - "\n", - "![CT Animation](https://rome.baulab.info/images/small-ct-animation.gif)\n", - "\n", - "See also [the explanation in a mech interp explainer](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx) and [this piece](https://www.neelnanda.io/mechanistic-interpretability/attribution-patching#how-to-think-about-activation-patching) describing how to think about patching on a conceptual level" + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, 
+ "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The above was all fairly abstract, so let's zoom in and lay out a concrete example to understand Indirect Object Identification.\n", - "\n", - "Here our clean input will be eg \"After John and Mary went to the store, **John** gave a bottle of milk to\" and our corrupted input will be eg \"After John and Mary went to the store, **Mary** 
gave a bottle of milk to\". These prompts are identical except for the name of the indirect object, and so patching is a causal intervention which will allow us to understand precisely which parts of the network are identifying the indirect object. \n", - "\n", - "One natural thing to patch in is the residual stream at a specific layer and specific position. For example, the model is likely initially doing some processing on the second subject token to realise that it's a duplicate, but then uses attention to move that information to the \" to\" token. So patching in the residual stream at the \" to\" token will likely matter a lot in later layers but not at all in early layers.\n", - "\n", - "We can zoom in much further and patch in specific activations from specific layers. For example, we think that the output of head L9H9 on the final token is significant for directly connecting to the logits\n", - "\n", - "We can patch in specific activations, and can zoom in as far as seems reasonable. For example, if we patch in the output of head L9H9 on the final token, we would predict that it will significantly affect performance. \n", - "\n", - "Note that this technique does *not* tell us how the components of the circuit connect up, just what they are. \n", - "\n", - "
Technical details \n", - "The choice of clean and corrupted prompt has both pros and cons. By carefully setting up the counterfactual, that only differs in the second subject, we avoid detecting the parts of the model doing irrelevant computation like detecting that the indirect object task is relevant at all or that it should be outputting a name rather than an article or pronoun. Or even context like that John and Mary are names at all. \n", - "\n", - "However, it *also* bakes in some details that *are* relevant to the task. Such as finding the location of the second subject, and of the names in the first clause. Or that the name mover heads have learned to copy whatever they look at. \n", - "\n", - "Some of these could be patched by also changing up the order of the names in the original sentence - patching in \"After John and Mary went to the store, John gave a bottle of milk to\" vs \"After Mary and John went to the store, John gave a bottle of milk to\".\n", - "\n", - "In the ROME paper they take a different tack. Rather than carefully setting up counterfactuals between two different but related inputs, they **corrupt** the clean input by adding Gaussian noise to the token embedding for the subject. This is in some ways much lower effort (you don't need to set up a similar but different prompt) but can also introduce some issues, such as ways this noise might break things. In practice, you should take care about how you choose your counterfactuals and try out several. Try to reason beforehand about what they will and will not tell you, and compare the results between different counterfactuals.\n", - "\n", - "I discuss some of these limitations and how the author's solved them with much more refined usage of these techniques in our interview\n", - "
" - ] + "title": { + "text": "Logit Difference From Patched Head Output" }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Residual Stream" - ] + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow(\n", + " patched_head_z_diff,\n", + " title=\"Logit Difference From Patched Head Output\",\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Decomposing Heads" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Decomposing attention layers into patching in individual heads has already helped us localise the behaviour a lot. But we can understand it further by decomposing heads. An attention head consists of two semi-independent operations - calculating *where* to move information from and to (represented by the attention pattern and implemented via the QK-circuit) and calculating *what* information to move (represented by the value vectors and implemented by the OV circuit). We can disentangle which of these is important by patching in just the attention pattern *or* the value vectors. (See [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) or [my walkthrough video](https://www.youtube.com/watch?v=KV5gbOmHbjU) for more on this decomposition. 
If you're not familiar with the details of how attention is implemented, I recommend checking out [my clean transformer implementation](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/clean-transformer-demo/Clean_Transformer_Demo.ipynb#scrollTo=3Pb0NYbZ900e) to see how the code works))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First let's patch in the value vectors, to measure when figuring out what to move is important. . This has the same shape as z ([batch, pos, head_index, d_head]) so we can reuse the same hook." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "patched_head_v_diff = torch.zeros(\n", + " model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32\n", + ")\n", + "for layer in range(model.cfg.n_layers):\n", + " for head_index in range(model.cfg.n_heads):\n", + " hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)\n", + " patched_logits = model.run_with_hooks(\n", + " corrupted_tokens,\n", + " fwd_hooks=[(utils.get_act_name(\"v\", layer, \"attn\"), hook_fn)],\n", + " return_type=\"logits\",\n", + " )\n", + " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", + "\n", + " patched_head_v_diff[layer, head_index] = normalize_patched_logit_diff(\n", + " patched_logit_diff\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can plot this as a heatmap and it's initially hard to interpret." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Lets begin by patching in the residual stream at the start of each layer and for each token position. 
" - ] + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.00019892427371814847, + 0.005339574534446001, + 0.0006527548539452255, + 0.003504416672512889, + -0.00898387935012579, + 0.0034814265090972185, + -0.0008631910313852131, + -3.406582254683599e-05, + 0.0005166929331608117, + 0.00044255363172851503, + -0.0039068968035280704, + -0.0001880836207419634 + ], + [ + -0.0004399022145662457, + -0.00044510437874123454, + -6.73597096465528e-05, + 7.242763240355998e-05, + -3.6549441574607044e-05, + -0.0019323208834975958, + -0.0001572397886775434, + 1.6143509128596634e-05, + 0.00020593880617525429, + 0.000336798548232764, + 0.0003515324497129768, + -0.0005669358652085066 + ], + [ + 0.00021013410878367722, + -0.0007199132232926786, + 0.0004868560063187033, + -0.0005974104860797524, + -0.0005921411793678999, + -0.0005443819100037217, + -0.000227552984142676, + -0.0004809825913980603, + 0.00020570388005580753, + 0.001183376181870699, + -0.0003574058646336198, + -0.0009104468626901507 + ], + [ + 0.0010395278222858906, + -0.00012042184971505776, + -7.762980385450646e-05, + -0.0007275318494066596, + -0.001310007064603269, + -0.0023108376190066338, + 0.010987084358930588, + -5.0712766096694395e-05, + 0.00014314358122646809, + 0.00015069512301124632, + -7.957642083056271e-05, + -2.0238119759596884e-05 + ], + [ + -0.0005373673629947007, + -0.0008137872209772468, + -0.00013334336108528078, + 0.030609702691435814, + -0.007185807917267084, + 0.000148916311445646, + 0.0013340713921934366, + -0.01142292469739914, + -0.0005336419562809169, + 0.0005126654868945479, + 0.00037344868178479373, + 0.0029547319281846285 + ], + [ + 8.22278525447473e-06, + 6.477540864580078e-06, + 0.0015973682748153806, + 0.00034015480196103454, + -0.0012577504385262728, + -5.450531898532063e-05, + 0.0006331544718705118, + -0.00027081489679403603, + 7.427356467815116e-05, + -0.006704355590045452, + 0.003175975289195776, + -0.0017300404142588377 + ], + [ + 
0.04863045737147331, + 0.015314852818846703, + -0.0004648726317100227, + -0.00011676354915834963, + -4.930314753437415e-05, + -0.003952810075134039, + -0.01737578585743904, + -0.00015421917487401515, + 0.0012194222072139382, + -0.00018090127559844404, + -0.00042647725786082447, + 0.00012334177154116333 + ], + [ + -2.956846401502844e-05, + -0.0013855225406587124, + -0.00012129446986364201, + 0.1332160234451294, + -0.00024490474606864154, + -0.007315828464925289, + 0.00033297244226559997, + -0.000795092957559973, + -0.007938209921121597, + 0.208413764834404, + -0.00019127204723190516, + -0.00020650937221944332 + ], + [ + -0.0020483459811657667, + -0.0003764357534237206, + -0.0033135139383375645, + -0.009666135534644127, + -0.00031723169377073646, + -0.005141589790582657, + 0.31717124581336975, + 0.0028427678626030684, + 0.0004723234742414206, + -0.0011529687326401472, + 0.2726709246635437, + -0.003175639547407627 + ], + [ + -0.00043929810635745525, + 5.7089622714556754e-05, + -0.0020629793871194124, + 0.020066648721694946, + -0.007871017791330814, + 0.011316264048218727, + 0.003056862158700824, + 0.06856372952461243, + -0.002747517777606845, + -0.009279227815568447, + 0.000506624230183661, + -0.0013159140944480896 + ], + [ + -0.012957162223756313, + -0.0030454176012426615, + -0.01792328804731369, + -0.0043589151464402676, + -0.0011521632550284266, + 0.0004999117809347808, + -0.0031131464056670666, + 0.019585633650422096, + 4.34632929682266e-05, + 0.01297028549015522, + -0.007695754989981651, + -0.0009146086522378027 + ], + [ + 0.004100752994418144, + -0.020459463819861412, + -0.035875942558050156, + 0.014656225219368935, + 0.0008441276149824262, + 0.0017804511589929461, + -0.01804223284125328, + 0.003519016318023205, + 0.008253024891018867, + -0.0017665562918409705, + 0.044167667627334595, + 0.006474285386502743 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, 
+ "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We first create a set of corrupted tokens - where we swap each pair of prompts to have the opposite answer." + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + 
{ + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 
0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ 
+ [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Corrupted Average Logit Diff -3.55\n", - "Clean Average Logit Diff 3.55\n" - ] - } - ], - "source": [ - "corrupted_prompts = []\n", - "for i in range(0, len(prompts), 2):\n", - " corrupted_prompts.append(prompts[i + 1])\n", - " corrupted_prompts.append(prompts[i])\n", - "corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)\n", - "corrupted_logits, corrupted_cache = model.run_with_cache(\n", - " corrupted_tokens, return_type=\"logits\"\n", - ")\n", - "corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)\n", - "print(\"Corrupted Average Logit Diff\", 
round(corrupted_average_logit_diff.item(), 2))\n", - "print(\"Clean Average Logit Diff\", round(original_average_logit_diff.item(), 2))" + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['<|endoftext|>When John and Mary went to the shops, Mary gave the bag to',\n", - " '<|endoftext|>When John and Mary went to the shops, John gave the bag to',\n", - " '<|endoftext|>When Tom and James went to the park, Tom gave the ball to',\n", - " '<|endoftext|>When Tom and James went to the park, James gave the ball to',\n", - " '<|endoftext|>When Dan and Sid went to the shops, Dan gave an apple to',\n", - " '<|endoftext|>When Dan and Sid went to the shops, Sid gave an apple to',\n", - " '<|endoftext|>After Martin and Amy went to the park, Martin gave a drink to',\n", - " '<|endoftext|>After Martin and Amy went to the park, Amy gave a drink to']" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.to_string(corrupted_tokens)" + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + 
"#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "attachments": {}, - "cell_type": 
"markdown", - "metadata": {}, - "source": [ - "We now intervene on the corrupted run and patch in the clean residual stream at a specific layer and position.\n", - "\n", - "We do the intervention using TransformerLens's `HookPoint` feature. We can design a hook function that takes in a specific activation and returns an edited copy, and temporarily add it in with `model.run_with_hooks`. " - ] + "title": { + "text": "Logit Difference From Patched Head Value" }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "def patch_residual_component(\n", - " corrupted_residual_component: Float[torch.Tensor, \"batch pos d_model\"],\n", - " hook,\n", - " pos,\n", - " clean_cache,\n", - "):\n", - " corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]\n", - " return corrupted_residual_component\n", - "\n", - "\n", - "def normalize_patched_logit_diff(patched_logit_diff):\n", - " # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalise\n", - " # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance\n", - " return (patched_logit_diff - corrupted_average_logit_diff) / (\n", - " original_average_logit_diff - corrupted_average_logit_diff\n", - " )\n", - "\n", - "\n", - "patched_residual_stream_diff = torch.zeros(\n", - " model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32\n", - ")\n", - "for layer in range(model.cfg.n_layers):\n", - " for position in range(tokens.shape[1]):\n", - " hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)\n", - " patched_logits = model.run_with_hooks(\n", - " corrupted_tokens,\n", - " fwd_hooks=[(utils.get_act_name(\"resid_pre\", layer), hook_fn)],\n", - " return_type=\"logits\",\n", - " )\n", - " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, 
answer_tokens)\n", - "\n", - " patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(\n", - " patched_logit_diff\n", - " )" - ] + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow(\n", + " patched_head_v_diff,\n", + " title=\"Logit Difference From Patched Head Value\",\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "But it's very easy to interpret if we plot a scatter plot against patching head outputs. Here we see that the earlier heads (L5H5, L6H9, L3H0) and late name movers (L9H9, L10H7, L11H10) don't matter at all now, while the mid-late heads (L8H6, L8H10, L7H9) do. \n", + "\n", + "Meta lesson: Plot things early, often and in diverse ways as you explore a model's internals!" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can immediately see that, exactly as predicted, originally all relevant computation happens on the second subject token, and at layers 7 and 8, the information is moved to the final token. Moving the residual stream at the correct position near *exactly* recovers performance!\n", - "\n", - "For reference, tokens and their index from the first prompt are on the x-axis. In an abuse of notation, note that the difference here is averaged over *all* 8 prompts, while the labels only come from the *first* prompt. 
\n", - "\n", - "To be easier to interpret, we normalise the logit difference, by subtracting the corrupted logit difference, and dividing by the total improvement from clean to corrupted to normalise\n", - "0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance" - ] + "hovertemplate": "%{hovertext}

Value Patch=%{x}
Output Patch=%{y}
Layer=%{marker.color}", + "hovertext": [ + "L0H0", + "L0H1", + "L0H2", + "L0H3", + "L0H4", + "L0H5", + "L0H6", + "L0H7", + "L0H8", + "L0H9", + "L0H10", + "L0H11", + "L1H0", + "L1H1", + "L1H2", + "L1H3", + "L1H4", + "L1H5", + "L1H6", + "L1H7", + "L1H8", + "L1H9", + "L1H10", + "L1H11", + "L2H0", + "L2H1", + "L2H2", + "L2H3", + "L2H4", + "L2H5", + "L2H6", + "L2H7", + "L2H8", + "L2H9", + "L2H10", + "L2H11", + "L3H0", + "L3H1", + "L3H2", + "L3H3", + "L3H4", + "L3H5", + "L3H6", + "L3H7", + "L3H8", + "L3H9", + "L3H10", + "L3H11", + "L4H0", + "L4H1", + "L4H2", + "L4H3", + "L4H4", + "L4H5", + "L4H6", + "L4H7", + "L4H8", + "L4H9", + "L4H10", + "L4H11", + "L5H0", + "L5H1", + "L5H2", + "L5H3", + "L5H4", + "L5H5", + "L5H6", + "L5H7", + "L5H8", + "L5H9", + "L5H10", + "L5H11", + "L6H0", + "L6H1", + "L6H2", + "L6H3", + "L6H4", + "L6H5", + "L6H6", + "L6H7", + "L6H8", + "L6H9", + "L6H10", + "L6H11", + "L7H0", + "L7H1", + "L7H2", + "L7H3", + "L7H4", + "L7H5", + "L7H6", + "L7H7", + "L7H8", + "L7H9", + "L7H10", + "L7H11", + "L8H0", + "L8H1", + "L8H2", + "L8H3", + "L8H4", + "L8H5", + "L8H6", + "L8H7", + "L8H8", + "L8H9", + "L8H10", + "L8H11", + "L9H0", + "L9H1", + "L9H2", + "L9H3", + "L9H4", + "L9H5", + "L9H6", + "L9H7", + "L9H8", + "L9H9", + "L9H10", + "L9H11", + "L10H0", + "L10H1", + "L10H2", + "L10H3", + "L10H4", + "L10H5", + "L10H6", + "L10H7", + "L10H8", + "L10H9", + "L10H10", + "L10H11", + "L11H0", + "L11H1", + "L11H2", + "L11H3", + "L11H4", + "L11H5", + "L11H6", + "L11H7", + "L11H8", + "L11H9", + "L11H10", + "L11H11" + ], + "legendgroup": "", + "marker": { + "color": [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 
6, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11 + ], + "coloraxis": "coloraxis", + "symbol": "circle" + }, + "mode": "markers", + "name": "", + "orientation": "v", + "showlegend": false, + "type": "scatter", + "x": [ + -0.00019892427371814847, + 0.005339574534446001, + 0.0006527548539452255, + 0.003504416672512889, + -0.00898387935012579, + 0.0034814265090972185, + -0.0008631910313852131, + -3.406582254683599e-05, + 0.0005166929331608117, + 0.00044255363172851503, + -0.0039068968035280704, + -0.0001880836207419634, + -0.0004399022145662457, + -0.00044510437874123454, + -6.73597096465528e-05, + 7.242763240355998e-05, + -3.6549441574607044e-05, + -0.0019323208834975958, + -0.0001572397886775434, + 1.6143509128596634e-05, + 0.00020593880617525429, + 0.000336798548232764, + 0.0003515324497129768, + -0.0005669358652085066, + 0.00021013410878367722, + -0.0007199132232926786, + 0.0004868560063187033, + -0.0005974104860797524, + -0.0005921411793678999, + -0.0005443819100037217, + -0.000227552984142676, + -0.0004809825913980603, + 0.00020570388005580753, + 0.001183376181870699, + -0.0003574058646336198, + -0.0009104468626901507, + 0.0010395278222858906, + -0.00012042184971505776, + -7.762980385450646e-05, + -0.0007275318494066596, + -0.001310007064603269, + -0.0023108376190066338, + 0.010987084358930588, + -5.0712766096694395e-05, + 0.00014314358122646809, + 0.00015069512301124632, + -7.957642083056271e-05, + -2.0238119759596884e-05, + -0.0005373673629947007, + -0.0008137872209772468, + -0.00013334336108528078, + 0.030609702691435814, + -0.007185807917267084, + 0.000148916311445646, + 0.0013340713921934366, + -0.01142292469739914, + -0.0005336419562809169, + 0.0005126654868945479, + 
0.00037344868178479373, + 0.0029547319281846285, + 8.22278525447473e-06, + 6.477540864580078e-06, + 0.0015973682748153806, + 0.00034015480196103454, + -0.0012577504385262728, + -5.450531898532063e-05, + 0.0006331544718705118, + -0.00027081489679403603, + 7.427356467815116e-05, + -0.006704355590045452, + 0.003175975289195776, + -0.0017300404142588377, + 0.04863045737147331, + 0.015314852818846703, + -0.0004648726317100227, + -0.00011676354915834963, + -4.930314753437415e-05, + -0.003952810075134039, + -0.01737578585743904, + -0.00015421917487401515, + 0.0012194222072139382, + -0.00018090127559844404, + -0.00042647725786082447, + 0.00012334177154116333, + -2.956846401502844e-05, + -0.0013855225406587124, + -0.00012129446986364201, + 0.1332160234451294, + -0.00024490474606864154, + -0.007315828464925289, + 0.00033297244226559997, + -0.000795092957559973, + -0.007938209921121597, + 0.208413764834404, + -0.00019127204723190516, + -0.00020650937221944332, + -0.0020483459811657667, + -0.0003764357534237206, + -0.0033135139383375645, + -0.009666135534644127, + -0.00031723169377073646, + -0.005141589790582657, + 0.31717124581336975, + 0.0028427678626030684, + 0.0004723234742414206, + -0.0011529687326401472, + 0.2726709246635437, + -0.003175639547407627, + -0.00043929810635745525, + 5.7089622714556754e-05, + -0.0020629793871194124, + 0.020066648721694946, + -0.007871017791330814, + 0.011316264048218727, + 0.003056862158700824, + 0.06856372952461243, + -0.002747517777606845, + -0.009279227815568447, + 0.000506624230183661, + -0.0013159140944480896, + -0.012957162223756313, + -0.0030454176012426615, + -0.01792328804731369, + -0.0043589151464402676, + -0.0011521632550284266, + 0.0004999117809347808, + -0.0031131464056670666, + 0.019585633650422096, + 4.34632929682266e-05, + 0.01297028549015522, + -0.007695754989981651, + -0.0009146086522378027, + 0.004100752994418144, + -0.020459463819861412, + -0.035875942558050156, + 0.014656225219368935, + 0.0008441276149824262, + 
0.0017804511589929461, + -0.01804223284125328, + 0.003519016318023205, + 0.008253024891018867, + -0.0017665562918409705, + 0.044167667627334595, + 0.006474285386502743 + ], + "xaxis": "x", + "y": [ + 0.0009487751522101462, + 0.016124747693538666, + 0.0018548924708738923, + 0.0034389030188322067, + -0.00982347596436739, + 0.011058605276048183, + -0.004063969012349844, + -0.0015792781487107277, + -0.0012082795146852732, + 0.003828897839412093, + -0.004256919026374817, + -0.0011422622483223677, + -0.0010771177476271987, + -0.00037898647133260965, + 2.5171791548928013e-06, + -0.00026067905128002167, + -0.00014146546891424805, + 0.0038321535103023052, + -0.0004293300735298544, + -0.00142992555629462, + -0.0009228314156644046, + 0.0006944393389858305, + 0.00043302192352712154, + -0.0035714071709662676, + -0.0004967569257132709, + 0.0008057993836700916, + 0.0005424688570201397, + -0.0005309234256856143, + -0.0007159864180721343, + -0.0010389237431809306, + -0.0009490771917626262, + -8.649027586216107e-05, + 0.0002766547549981624, + 0.0021084228064864874, + -0.0001975146442418918, + -0.0016405630158260465, + 0.1162627637386322, + 0.0002507446042727679, + -0.0014675153652206063, + -0.00039680811460129917, + 0.018962211906909943, + -0.00018764731066767126, + 0.011170871555805206, + -0.0013301445869728923, + -0.0007356539717875421, + -0.00030253134900704026, + -0.00014683544577565044, + -0.00022228369198273867, + -0.001650598249398172, + 0.0002927311579696834, + -0.00143563118763268, + 0.03084198758006096, + -0.007432155776768923, + -0.00028236035723239183, + 0.006017433945089579, + -0.011007187888026237, + -0.001266107545234263, + 0.0014901700196787715, + -0.0001800622121663764, + 0.002944394713267684, + -0.004211106337606907, + 0.0029597999528050423, + 0.002045023487880826, + 0.0013397098518908024, + -0.0012190865818411112, + 0.34349915385246277, + 0.0005632104002870619, + -0.0001262281439267099, + -0.00515326950699091, + 0.016240738332271576, + 0.01709030382335186, + 
-0.004175194539129734, + 0.039775289595127106, + 0.015226684510707855, + -0.0010229480685666203, + 0.0008072761120274663, + -0.004935584031045437, + -0.002123525831848383, + -0.014274083077907562, + 0.0013746818294748664, + 0.0014838266652077436, + 0.1302703619003296, + -0.00033616088330745697, + 0.0012919505825266242, + 0.00037177055492065847, + 0.019514480605721474, + 0.00022255218937061727, + 0.124249167740345, + -0.00040352059295400977, + -0.007652895525097847, + 0.0013010123511776328, + -0.0011253133416175842, + -0.007449474185705185, + 0.19224143028259277, + -0.003275118535384536, + -0.0005017912480980158, + -0.001007912098430097, + 3.091096004936844e-05, + -0.0008595998515374959, + 0.012359987013041973, + -0.0004041247011628002, + -0.004328910261392593, + 0.3185553252696991, + 0.002330605871975422, + 0.0021182901691645384, + 0.0001405928487656638, + 0.2779357433319092, + 0.005738262087106705, + 0.0058898297138512135, + -0.0009689796715974808, + 0.00912561360746622, + 0.020675739273428917, + -0.03700518235564232, + 0.014263041317462921, + -0.04828466475009918, + 0.05834139883518219, + 0.0006514795240946114, + 0.26360899209976196, + 0.0004918567719869316, + -0.00261044898070395, + 0.08374208211898804, + 0.020676210522651672, + -0.003743582172319293, + 0.01085072010755539, + -0.001096583902835846, + 0.00047430366976186633, + 0.04818058758974075, + -0.4799128472805023, + 0.00018429107149131596, + 0.011861988343298435, + 0.06088569387793541, + 0.0008461413672193885, + 0.005328264087438583, + -0.011493473313748837, + -0.11350836604833603, + 0.006329597905278206, + 0.00031669469899497926, + -0.0011600167490541935, + -0.022669579833745956, + 0.004070379305630922, + 0.0073160636238753796, + -0.00834545586258173, + -0.27817651629447937, + 0.0036344374530017376 + ], + "yaxis": "y" + } + ], + "layout": { + "coloraxis": { + "colorbar": { + "title": { + "text": "Layer" + } + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 
0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "coloraxis": "coloraxis", - "hovertemplate": "Position: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "x": [ - "<|endoftext|>_0", - "When_1", - " John_2", - " and_3", - " Mary_4", - " went_5", - " to_6", - " the_7", - " shops_8", - ",_9", - " John_10", - " gave_11", - " the_12", - " bag_13", - " to_14" - ], - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.000650405883789, - -0.0002469856117386371, - 0.00000976665523921838, - -0.00036458822432905436, - -0.000048967522161547095 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.001051902770996, - -0.000027621845219982788, - -0.000019768245692830533, - -0.0004596704675350338, - -0.0005947590689174831 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.0002663135528564, - 0.0008680911851115525, - 0.0005157867562957108, - -0.0009929431835189462, - -0.0008658089209347963 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.994907796382904, - 0.005429857410490513, - 0.0016050540143623948, - -0.0006193603039719164, - -0.0016324409516528249 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.9675672054290771, - 0.03134213387966156, - 0.0028418952133506536, - -0.0012302964460104704, - -0.000985861523076892 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.967520534992218, - 0.03100077249109745, - 0.0017823305679485202, - -0.00048668819363228977, - -0.0006467136554419994 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.9228319525718689, - 0.05134531855583191, - 0.004728672094643116, - 0.0009345446596853435, - 0.017046840861439705 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.6565483808517456, - 0.02385685034096241, - 0.002357019344344735, - -0.000017183941963594407, - 0.3186916410923004 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.027302566915750504, - 0.03142499923706055, - 0.0018202561186626554, - 0.0007990868762135506, - 
0.9383866190910339 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.026841485872864723, - 0.02098155952990055, - 0.0012512058019638062, - 0.00032317222212441266, - 1.0048279762268066 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.005687985569238663, - 0.014263377524912357, - 0.00048709093243815005, - -0.00008977938705356792, - 0.9914212226867676 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 
0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - 
], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" 
- } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - 
"#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - 
"caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Logit Difference From Patched Residual Stream" - }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Position" - } - }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "prompt_position_labels = [\n", - " f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(tokens[0]))\n", - "]\n", - "imshow(\n", - " patched_residual_stream_diff,\n", - " x=prompt_position_labels,\n", - " title=\"Logit Difference From Patched Residual Stream\",\n", - " labels={\"x\": \"Position\", \"y\": \"Layer\"},\n", - ")" - ] + "legend": { + "tracegroupgap": 0 }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Layers" + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + 
"linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + 
], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, 
+ "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + 
[ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can apply exactly the same idea, but this time patching in attention or MLP layers. These are also residual components with identical shapes to the residual stream terms, so we can reuse the same hooks." + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "patched_attn_diff = torch.zeros(\n", - " model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32\n", - ")\n", - "patched_mlp_diff = torch.zeros(\n", - " model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32\n", - ")\n", - "for layer in range(model.cfg.n_layers):\n", - " for position in range(tokens.shape[1]):\n", - " hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)\n", - " patched_attn_logits = model.run_with_hooks(\n", - " corrupted_tokens,\n", - " fwd_hooks=[(utils.get_act_name(\"attn_out\", layer), hook_fn)],\n", - " return_type=\"logits\",\n", - " )\n", - " patched_attn_logit_diff = logits_to_ave_logit_diff(\n", - " patched_attn_logits, answer_tokens\n", - " )\n", - " patched_mlp_logits = model.run_with_hooks(\n", - " corrupted_tokens,\n", - " fwd_hooks=[(utils.get_act_name(\"mlp_out\", layer), hook_fn)],\n", - " return_type=\"logits\",\n", - " )\n", - " patched_mlp_logit_diff = logits_to_ave_logit_diff(\n", - " patched_mlp_logits, answer_tokens\n", - " )\n", - "\n", - " 
patched_attn_diff[layer, position] = normalize_patched_logit_diff(\n", - " patched_attn_logit_diff\n", - " )\n", - " patched_mlp_diff[layer, position] = normalize_patched_logit_diff(\n", - " patched_mlp_logit_diff\n", - " )" + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } 
+ }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We see that several attention layers are significant but that, matching the residual stream results, early layers matter on the second subject token, and later layers matter on the final token, and layers essentially don't matter on any other token. Extremely localised! As with direct logit attribution, layer 9 is positive and layers 10 and 11 are not, suggesting that the late layers only matter for direct logit effects, but we also see that layers 7 and 8 matter significantly. Presumably these are the heads that move information about which name is duplicated from the second subject token to the final token." - ] + "title": { + "text": "Scatter plot of output patching vs value patching" }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "coloraxis": "coloraxis", - "hovertemplate": "Position: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "x": [ - "<|endoftext|>_0", - "When_1", - " John_2", - " and_3", - " Mary_4", - " went_5", - " to_6", - " the_7", - " shops_8", - ",_9", - " John_10", - " gave_11", - " the_12", - " bag_13", - " to_14" - ], - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.035456884652376175, - -0.0002469856117386371, - 0.00000976665523921838, - -0.00036458822432905436, - -0.000048967522161547095 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.0029848709236830473, - 0.00007950929284561425, - 0.000020842242520302534, - 0.00008088535105343908, - -0.0005967392353340983 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.0019131568260490894, - 0.0006668510613963008, - 0.00039482791908085346, - -0.0007051457650959492, - -0.00027282864903099835 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.1546323299407959, - 0.0038019807543605566, - 0.0005171628436073661, - -0.00011964991426793858, - -0.0005599213181994855 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.005406397394835949, - 0.019581740722060204, - 0.001007509301416576, - -0.0002424211270408705, - 0.0007936497568152845 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.3520970046520233, - 0.0010525835677981377, - 0.00022436455765273422, - 0.00013367898645810783, - 0.00008172441448550671 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.11986024677753448, - 0.021243548020720482, - 0.002727783052250743, - 0.0013409851817414165, - 0.01797366514801979 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.013310473412275314, - 0.011509180068969727, - 0.00037542887730523944, - -0.00004094611358596012, - 0.29760244488716125 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.0015009435592219234, - 0.017351653426885605, - 0.0005848917062394321, - 0.0010122752282768488, - 0.5697318911552429 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 
0, - 0, - -0.00012901381705887616, - 0.00630143890157342, - 0.00014156615361571312, - 0.00031229801243171096, - 0.27152299880981445 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.0009373303619213402, - 0.00008669164526509121, - 0.00033243544748984277, - 9.73309283835988e-7, - -0.1929796040058136 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.40617984533309937 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, 
- "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 
0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" 
- } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - 
"#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - 
"bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Logit Difference From Patched Attention Layer" - }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Position" - } - }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "imshow(\n", - " patched_attn_diff,\n", - " x=prompt_position_labels,\n", - " title=\"Logit Difference From Patched Attention Layer\",\n", - " labels={\"x\": \"Position\", \"y\": \"Layer\"},\n", - ")" - ] + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "range": [ + -0.5, + 0.5 + ], + "title": { + "text": "Value Patch" + } }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "range": [ + -0.5, + 0.5 + ], + "title": { + "text": "Output Patch" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "head_labels = [\n", + " f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)\n", + "]\n", + "scatter(\n", + " x=utils.to_numpy(patched_head_v_diff.flatten()),\n", + " y=utils.to_numpy(patched_head_z_diff.flatten()),\n", + " xaxis=\"Value Patch\",\n", + " yaxis=\"Output Patch\",\n", + " caxis=\"Layer\",\n", + " hover_name=head_labels,\n", + " color=einops.repeat(\n", + " np.arange(model.cfg.n_layers), \"layer -> (layer 
head)\", head=model.cfg.n_heads\n", + " ),\n", + " range_x=(-0.5, 0.5),\n", + " range_y=(-0.5, 0.5),\n", + " title=\"Scatter plot of output patching vs value patching\",\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When we patch in attention patterns, we see the opposite effect - early and late heads matter a lot, middle heads don't. (In fact, the sum of value patching and pattern patching is approx the same as output patching)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "def patch_head_pattern(\n", + " corrupted_head_pattern: Float[torch.Tensor, \"batch head_index query_pos d_head\"],\n", + " hook,\n", + " head_index,\n", + " clean_cache,\n", + "):\n", + " corrupted_head_pattern[:, head_index, :, :] = clean_cache[hook.name][\n", + " :, head_index, :, :\n", + " ]\n", + " return corrupted_head_pattern\n", + "\n", + "\n", + "patched_head_attn_diff = torch.zeros(\n", + " model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32\n", + ")\n", + "for layer in range(model.cfg.n_layers):\n", + " for head_index in range(model.cfg.n_heads):\n", + " hook_fn = partial(patch_head_pattern, head_index=head_index, clean_cache=cache)\n", + " patched_logits = model.run_with_hooks(\n", + " corrupted_tokens,\n", + " fwd_hooks=[(utils.get_act_name(\"attn\", layer, \"attn\"), hook_fn)],\n", + " return_type=\"logits\",\n", + " )\n", + " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", + "\n", + " patched_head_attn_diff[layer, head_index] = normalize_patched_logit_diff(\n", + " patched_logit_diff\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In contrast, the MLP layers do 
not matter much. This makes sense, since this is more a task about moving information than about processing it, and the MLP layers specialise in processing information.\n", - "\n", - "The one exception is MLP 0, which matters a lot, but I think this is misleading and just a generally true statement about MLP 0 rather than being about the circuit on this task.\n", - "\n", - "
My takes on MLP0 \n", - "It's often observed on GPT-2 Small that MLP0 matters a lot, and that ablating it utterly destroys performance. My current best guess is that the first MLP layer is essentially acting as an extension of the embedding (for whatever reason) and that when later layers want to access the input tokens they mostly read in the output of the first MLP layer, rather than the token embeddings. Within this frame, the first attention layer doesn't do much. \n", - "\n", - "In this framing, it makes sense that MLP0 matters on the second subject token, because that's the one position with a different input token!\n", - "\n", - "I'm not entirely sure why this happens, but I would guess that it's because the embedding and unembedding matrices in GPT-2 Small are the same. This is pretty unprincipled, as the tasks of embedding and unembedding tokens are not inverses, but this is common practice, and plausibly models want to dedicate some parameters to overcoming this. \n", - "\n", - "I only have suggestive evidence of this, and would love to see someone look into this properly!\n", - "
" - ] + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.0006401354330591857, + 0.005318799521774054, + 0.0011584057938307524, + -5.920405237702653e-05, + -0.00106671336106956, + 0.005079298280179501, + -0.0030818663071841, + -0.0020521720871329308, + -0.0014405983965843916, + 0.003492669900879264, + -0.002568227471783757, + -0.0009168237447738647 + ], + [ + -0.0007600873941555619, + 0.0001683824957581237, + 0.00012246915139257908, + -0.00034914951538667083, + 1.4901700524205808e-05, + 0.0050090523436665535, + -0.0002975976967718452, + -0.0014448943547904491, + -0.001099134678952396, + 0.00047447148244827986, + 5.195457561057992e-05, + -0.0034954219590872526 + ], + [ + -0.0007243098807521164, + 0.0017458146903663874, + -0.00015556166181340814, + 5.7626621128292754e-05, + -9.7398049547337e-05, + -0.0004238593974150717, + -0.0007917031762190163, + 0.00027222454082220793, + 0.00010179472155869007, + 0.0004223826399538666, + 0.00015193692524917424, + -0.0007437760941684246 + ], + [ + 0.11458104848861694, + 0.00021140948229003698, + -0.0009424989693798125, + 0.000429833511589095, + 0.02004295401275158, + 0.002104730810970068, + 7.628730963915586e-05, + -0.001543701975606382, + -0.0008484235731884837, + -0.0005819046637043357, + 0.00011921360419364646, + -1.899631206470076e-05 + ], + [ + -0.001127125695347786, + 0.001237143180333078, + -0.0012324444251134992, + -0.0005952289211563766, + -0.0007541133090853691, + -0.0005842540413141251, + 0.004813014063984156, + 0.00018187458044849336, + -0.0005361591465771198, + 0.0008579217828810215, + -0.0002985374303534627, + -1.144477391790133e-05 + ], + [ + -0.004241178277879953, + 0.0029509058222174644, + 0.0005218615406192839, + 0.0009535074350424111, + 0.0001622070267330855, + 0.34350839257240295, + -0.0003052163519896567, + 0.00010293584637111053, + -0.005300541408360004, + 0.024864863604307175, + 0.014383262023329735, + -0.0023285921197384596 + ], + [ + -0.0023893399629741907, + 
-0.002172795357182622, + -0.00047614958020858467, + 0.00043188079143874347, + -0.004675475414842367, + 0.0018583494238555431, + -0.0026542814448475838, + 0.0014367386465892196, + 0.00030326974228955805, + 0.13043038547039032, + 8.813483145786449e-05, + 0.0011766973184421659 + ], + [ + 0.00031847349600866437, + 0.02057075686752796, + 0.00031840638257563114, + -0.002512782346457243, + -0.0002628941729199141, + -0.00024718698114156723, + 0.0005524033331312239, + -0.00043131023994646966, + 0.00025715501396916807, + 0.008090951479971409, + -0.0030689111445099115, + -0.0004238593974150717 + ], + [ + 0.000976699055172503, + 0.00039251212729141116, + 0.0017534669023007154, + 0.022595642134547234, + -4.4805787183577195e-05, + 0.00014220383309293538, + 0.009584981948137283, + -0.0003157213795930147, + 0.0015271222218871117, + 0.0011813960736617446, + -0.010774029418826103, + 0.00936581939458847 + ], + [ + 0.006314125377684832, + -0.0010949057759717107, + 0.011662023141980171, + 0.0013481340138241649, + -0.02918696030974388, + 0.0038333951961249113, + -0.04409456625580788, + -0.005032042507082224, + 0.00482167350128293, + 0.2766477167606354, + -3.164933150401339e-05, + -0.0006618167390115559 + ], + [ + 0.0953889712691307, + 0.02506939135491848, + 0.014239178970456123, + 0.014754998497664928, + 9.890835644910112e-05, + -8.977938705356792e-05, + 0.05082912743091583, + -0.5051022171974182, + 0.00014696970174554735, + -0.0016026375815272331, + 0.06883199512958527, + 0.002327115274965763 + ], + [ + 0.0013425961369648576, + 0.009630928747355938, + -0.07776415348052979, + -0.007728713098913431, + -0.0005726079107262194, + -0.002957182005047798, + -0.0049475994892418385, + 0.00045916702947579324, + -0.0006328188464976847, + -0.006520198658108711, + -0.3204910457134247, + -0.002473111730068922 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + 
"rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "coloraxis": "coloraxis", - "hovertemplate": "Position: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "x": [ - "<|endoftext|>_0", - "When_1", - " John_2", - " and_3", - " Mary_4", - " went_5", - " to_6", - " the_7", - " shops_8", - ",_9", - " John_10", - " gave_11", - " the_12", - " bag_13", - " to_14" - ], - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.8507890701293945, - -0.00027843358111567795, - -0.00007293107046280056, - -0.00047373308916576207, - 0.000040039929444901645 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.008863994851708412, - 0.000222149450564757, - 0.00014938619278836995, - -0.00004853121208725497, - 0.000304041663184762 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.013550343923270702, - 0.0000586334899708163, - -0.0003296833310741931, - -0.0006382559076882899, - 0.0007730424986220896 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.0019468198297545314, - 0.0004995090421289206, - 0.00017318192112725228, - 0.00016871812113095075, - 0.00040764876757748425 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.019787074998021126, - 0.004128609783947468, - -0.0000486990247736685, - -0.00017019486404024065, - 0.0007914346642792225 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.09652391821146011, - -0.0018826150335371494, - -0.0004844730719923973, - 0.0007094081956893206, - -0.00018335132335778326 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.015900013968348503, - -0.0008501688134856522, - 0.00012337534280959517, - 0.000027521158699528314, - -0.007238299585878849 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.010360540822148323, - 0.0031509376130998135, - 0.0005309234256856143, - 0.0002361114020459354, - 0.008496351540088654 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.012533102184534073, - 0.00002201692586822901, - -0.00035374757135286927, - 0.00008615465048933402, - -0.021631328389048576 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, 
- 0, - 0, - 0, - 0, - -0.00033465056912973523, - 0.0008094912045635283, - 0.000016244195649051107, - 0.00012924875773023814, - 0.03162466362118721 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.0013599144294857979, - -0.00019499746849760413, - -0.00009934466652339324, - -0.00014217027637641877, - 0.028764141723513603 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.02044912613928318 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 
0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - 
], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - 
"ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 
0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", 
- "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Logit Difference From Patched MLP Layer" - }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Position" - } - }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "imshow(\n", - " patched_mlp_diff,\n", - " x=prompt_position_labels,\n", - " title=\"Logit Difference From Patched MLP Layer\",\n", - " labels={\"x\": \"Position\", \"y\": \"Layer\"},\n", - ")" + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": 
"#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + 
"histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { 
+ "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 
Heads\n", - "\n", - "We can refine the above analysis by patching in individual heads! This is somewhat more annoying, because there are now three dimensions (head_index, position and layer), so for now lets patch in a head's output across all positions.\n", - "\n", - "The easiest way to do this is to patch in the activation `z`, the \"mixed value\" of the attention head. That is, the average of all previous values weighted by the attention pattern, ie the activation that is then multiplied by `W_O`, the output weights. " + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [], - "source": [ - "def patch_head_vector(\n", - " corrupted_head_vector: Float[torch.Tensor, \"batch pos head_index d_head\"],\n", - " hook,\n", - " head_index,\n", - " clean_cache,\n", - "):\n", - " corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][\n", - " :, :, head_index, :\n", - " ]\n", - " return corrupted_head_vector\n", - "\n", - "\n", - "patched_head_z_diff = torch.zeros(\n", - " model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32\n", - ")\n", - "for layer in range(model.cfg.n_layers):\n", - " for head_index in range(model.cfg.n_heads):\n", - " hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)\n", - " patched_logits = model.run_with_hooks(\n", - " corrupted_tokens,\n", - " fwd_hooks=[(utils.get_act_name(\"z\", layer, \"attn\"), hook_fn)],\n", - " return_type=\"logits\",\n", - " )\n", - " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", - "\n", 
- " patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(\n", - " patched_logit_diff\n", - " )" + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { 
+ "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can now see that, in addition to the name mover heads identified before, in mid-late layers the heads L8H6, L8H10, L7H9 matter and are presumably responsible for moving information from the second subject to the final token. And heads L5H5, L6H9, L3H0 also matter a lot, and are presumably involved in detecting duplicated tokens." - ] + "title": { + "text": "Logit Difference From Patched Head Pattern" }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0.0009487751522101462, - 0.016124747693538666, - 0.0018548924708738923, - 0.0034389030188322067, - -0.00982347596436739, - 0.011058605276048183, - -0.004063969012349844, - -0.0015792781487107277, - -0.0012082795146852732, - 0.003828897839412093, - -0.004256919026374817, - -0.0011422622483223677 - ], - [ - -0.0010771177476271987, - -0.00037898647133260965, - 0.0000025171791548928013, - -0.00026067905128002167, - -0.00014146546891424805, - 0.0038321535103023052, - -0.0004293300735298544, - -0.00142992555629462, - -0.0009228314156644046, - 0.0006944393389858305, - 0.00043302192352712154, - -0.0035714071709662676 - ], - [ - -0.0004967569257132709, - 0.0008057993836700916, - 0.0005424688570201397, - -0.0005309234256856143, - -0.0007159864180721343, - -0.0010389237431809306, - -0.0009490771917626262, - -0.00008649027586216107, - 0.0002766547549981624, - 0.0021084228064864874, - -0.0001975146442418918, - -0.0016405630158260465 - ], - [ - 0.1162627637386322, - 0.0002507446042727679, - -0.0014675153652206063, - -0.00039680811460129917, - 0.018962211906909943, - -0.00018764731066767126, - 0.011170871555805206, - -0.0013301445869728923, - -0.0007356539717875421, - -0.00030253134900704026, - -0.00014683544577565044, - -0.00022228369198273867 - ], - [ - -0.001650598249398172, - 0.0002927311579696834, - -0.00143563118763268, - 0.03084198758006096, - -0.007432155776768923, - -0.00028236035723239183, - 0.006017433945089579, - -0.011007187888026237, - -0.001266107545234263, - 0.0014901700196787715, - -0.0001800622121663764, - 0.002944394713267684 - ], - [ - -0.004211106337606907, - 0.0029597999528050423, - 0.002045023487880826, - 0.0013397098518908024, - -0.0012190865818411112, - 0.34349915385246277, - 0.0005632104002870619, - -0.0001262281439267099, - -0.00515326950699091, - 0.016240738332271576, - 0.01709030382335186, - -0.004175194539129734 - ], - [ - 0.039775289595127106, - 
0.015226684510707855, - -0.0010229480685666203, - 0.0008072761120274663, - -0.004935584031045437, - -0.002123525831848383, - -0.014274083077907562, - 0.0013746818294748664, - 0.0014838266652077436, - 0.1302703619003296, - -0.00033616088330745697, - 0.0012919505825266242 - ], - [ - 0.00037177055492065847, - 0.019514480605721474, - 0.00022255218937061727, - 0.124249167740345, - -0.00040352059295400977, - -0.007652895525097847, - 0.0013010123511776328, - -0.0011253133416175842, - -0.007449474185705185, - 0.19224143028259277, - -0.003275118535384536, - -0.0005017912480980158 - ], - [ - -0.001007912098430097, - 0.00003091096004936844, - -0.0008595998515374959, - 0.012359987013041973, - -0.0004041247011628002, - -0.004328910261392593, - 0.3185553252696991, - 0.002330605871975422, - 0.0021182901691645384, - 0.0001405928487656638, - 0.2779357433319092, - 0.005738262087106705 - ], - [ - 0.0058898297138512135, - -0.0009689796715974808, - 0.00912561360746622, - 0.020675739273428917, - -0.03700518235564232, - 0.014263041317462921, - -0.04828466475009918, - 0.05834139883518219, - 0.0006514795240946114, - 0.26360899209976196, - 0.0004918567719869316, - -0.00261044898070395 - ], - [ - 0.08374208211898804, - 0.020676210522651672, - -0.003743582172319293, - 0.01085072010755539, - -0.001096583902835846, - 0.00047430366976186633, - 0.04818058758974075, - -0.4799128472805023, - 0.00018429107149131596, - 0.011861988343298435, - 0.06088569387793541, - 0.0008461413672193885 - ], - [ - 0.005328264087438583, - -0.011493473313748837, - -0.11350836604833603, - 0.006329597905278206, - 0.00031669469899497926, - -0.0011600167490541935, - -0.022669579833745956, - 0.004070379305630922, - 0.0073160636238753796, - -0.00834545586258173, - -0.27817651629447937, - 0.0036344374530017376 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - 
[ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 
0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 
0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 
0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - 
"#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Logit Difference From Patched Head Output" - }, - "xaxis": { - 
"anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } - }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "imshow(\n", - " patched_head_z_diff,\n", - " title=\"Logit Difference From Patched Head Output\",\n", - " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", - ")" - ] + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Decomposing Heads" - ] + "hovertemplate": "%{hovertext}

Attention Patch=%{x}
Output Patch=%{y}", + "hovertext": [ + "L0H0", + "L0H1", + "L0H2", + "L0H3", + "L0H4", + "L0H5", + "L0H6", + "L0H7", + "L0H8", + "L0H9", + "L0H10", + "L0H11", + "L1H0", + "L1H1", + "L1H2", + "L1H3", + "L1H4", + "L1H5", + "L1H6", + "L1H7", + "L1H8", + "L1H9", + "L1H10", + "L1H11", + "L2H0", + "L2H1", + "L2H2", + "L2H3", + "L2H4", + "L2H5", + "L2H6", + "L2H7", + "L2H8", + "L2H9", + "L2H10", + "L2H11", + "L3H0", + "L3H1", + "L3H2", + "L3H3", + "L3H4", + "L3H5", + "L3H6", + "L3H7", + "L3H8", + "L3H9", + "L3H10", + "L3H11", + "L4H0", + "L4H1", + "L4H2", + "L4H3", + "L4H4", + "L4H5", + "L4H6", + "L4H7", + "L4H8", + "L4H9", + "L4H10", + "L4H11", + "L5H0", + "L5H1", + "L5H2", + "L5H3", + "L5H4", + "L5H5", + "L5H6", + "L5H7", + "L5H8", + "L5H9", + "L5H10", + "L5H11", + "L6H0", + "L6H1", + "L6H2", + "L6H3", + "L6H4", + "L6H5", + "L6H6", + "L6H7", + "L6H8", + "L6H9", + "L6H10", + "L6H11", + "L7H0", + "L7H1", + "L7H2", + "L7H3", + "L7H4", + "L7H5", + "L7H6", + "L7H7", + "L7H8", + "L7H9", + "L7H10", + "L7H11", + "L8H0", + "L8H1", + "L8H2", + "L8H3", + "L8H4", + "L8H5", + "L8H6", + "L8H7", + "L8H8", + "L8H9", + "L8H10", + "L8H11", + "L9H0", + "L9H1", + "L9H2", + "L9H3", + "L9H4", + "L9H5", + "L9H6", + "L9H7", + "L9H8", + "L9H9", + "L9H10", + "L9H11", + "L10H0", + "L10H1", + "L10H2", + "L10H3", + "L10H4", + "L10H5", + "L10H6", + "L10H7", + "L10H8", + "L10H9", + "L10H10", + "L10H11", + "L11H0", + "L11H1", + "L11H2", + "L11H3", + "L11H4", + "L11H5", + "L11H6", + "L11H7", + "L11H8", + "L11H9", + "L11H10", + "L11H11" + ], + "legendgroup": "", + "marker": { + "color": "#636efa", + "symbol": "circle" + }, + "mode": "markers", + "name": "", + "orientation": "v", + "showlegend": false, + "type": "scatter", + "x": [ + 0.0006401354330591857, + 0.005318799521774054, + 0.0011584057938307524, + -5.920405237702653e-05, + -0.00106671336106956, + 0.005079298280179501, + -0.0030818663071841, + -0.0020521720871329308, + -0.0014405983965843916, + 0.003492669900879264, + -0.002568227471783757, + 
-0.0009168237447738647, + -0.0007600873941555619, + 0.0001683824957581237, + 0.00012246915139257908, + -0.00034914951538667083, + 1.4901700524205808e-05, + 0.0050090523436665535, + -0.0002975976967718452, + -0.0014448943547904491, + -0.001099134678952396, + 0.00047447148244827986, + 5.195457561057992e-05, + -0.0034954219590872526, + -0.0007243098807521164, + 0.0017458146903663874, + -0.00015556166181340814, + 5.7626621128292754e-05, + -9.7398049547337e-05, + -0.0004238593974150717, + -0.0007917031762190163, + 0.00027222454082220793, + 0.00010179472155869007, + 0.0004223826399538666, + 0.00015193692524917424, + -0.0007437760941684246, + 0.11458104848861694, + 0.00021140948229003698, + -0.0009424989693798125, + 0.000429833511589095, + 0.02004295401275158, + 0.002104730810970068, + 7.628730963915586e-05, + -0.001543701975606382, + -0.0008484235731884837, + -0.0005819046637043357, + 0.00011921360419364646, + -1.899631206470076e-05, + -0.001127125695347786, + 0.001237143180333078, + -0.0012324444251134992, + -0.0005952289211563766, + -0.0007541133090853691, + -0.0005842540413141251, + 0.004813014063984156, + 0.00018187458044849336, + -0.0005361591465771198, + 0.0008579217828810215, + -0.0002985374303534627, + -1.144477391790133e-05, + -0.004241178277879953, + 0.0029509058222174644, + 0.0005218615406192839, + 0.0009535074350424111, + 0.0001622070267330855, + 0.34350839257240295, + -0.0003052163519896567, + 0.00010293584637111053, + -0.005300541408360004, + 0.024864863604307175, + 0.014383262023329735, + -0.0023285921197384596, + -0.0023893399629741907, + -0.002172795357182622, + -0.00047614958020858467, + 0.00043188079143874347, + -0.004675475414842367, + 0.0018583494238555431, + -0.0026542814448475838, + 0.0014367386465892196, + 0.00030326974228955805, + 0.13043038547039032, + 8.813483145786449e-05, + 0.0011766973184421659, + 0.00031847349600866437, + 0.02057075686752796, + 0.00031840638257563114, + -0.002512782346457243, + -0.0002628941729199141, + 
-0.00024718698114156723, + 0.0005524033331312239, + -0.00043131023994646966, + 0.00025715501396916807, + 0.008090951479971409, + -0.0030689111445099115, + -0.0004238593974150717, + 0.000976699055172503, + 0.00039251212729141116, + 0.0017534669023007154, + 0.022595642134547234, + -4.4805787183577195e-05, + 0.00014220383309293538, + 0.009584981948137283, + -0.0003157213795930147, + 0.0015271222218871117, + 0.0011813960736617446, + -0.010774029418826103, + 0.00936581939458847, + 0.006314125377684832, + -0.0010949057759717107, + 0.011662023141980171, + 0.0013481340138241649, + -0.02918696030974388, + 0.0038333951961249113, + -0.04409456625580788, + -0.005032042507082224, + 0.00482167350128293, + 0.2766477167606354, + -3.164933150401339e-05, + -0.0006618167390115559, + 0.0953889712691307, + 0.02506939135491848, + 0.014239178970456123, + 0.014754998497664928, + 9.890835644910112e-05, + -8.977938705356792e-05, + 0.05082912743091583, + -0.5051022171974182, + 0.00014696970174554735, + -0.0016026375815272331, + 0.06883199512958527, + 0.002327115274965763, + 0.0013425961369648576, + 0.009630928747355938, + -0.07776415348052979, + -0.007728713098913431, + -0.0005726079107262194, + -0.002957182005047798, + -0.0049475994892418385, + 0.00045916702947579324, + -0.0006328188464976847, + -0.006520198658108711, + -0.3204910457134247, + -0.002473111730068922 + ], + "xaxis": "x", + "y": [ + 0.0009487751522101462, + 0.016124747693538666, + 0.0018548924708738923, + 0.0034389030188322067, + -0.00982347596436739, + 0.011058605276048183, + -0.004063969012349844, + -0.0015792781487107277, + -0.0012082795146852732, + 0.003828897839412093, + -0.004256919026374817, + -0.0011422622483223677, + -0.0010771177476271987, + -0.00037898647133260965, + 2.5171791548928013e-06, + -0.00026067905128002167, + -0.00014146546891424805, + 0.0038321535103023052, + -0.0004293300735298544, + -0.00142992555629462, + -0.0009228314156644046, + 0.0006944393389858305, + 0.00043302192352712154, + 
-0.0035714071709662676, + -0.0004967569257132709, + 0.0008057993836700916, + 0.0005424688570201397, + -0.0005309234256856143, + -0.0007159864180721343, + -0.0010389237431809306, + -0.0009490771917626262, + -8.649027586216107e-05, + 0.0002766547549981624, + 0.0021084228064864874, + -0.0001975146442418918, + -0.0016405630158260465, + 0.1162627637386322, + 0.0002507446042727679, + -0.0014675153652206063, + -0.00039680811460129917, + 0.018962211906909943, + -0.00018764731066767126, + 0.011170871555805206, + -0.0013301445869728923, + -0.0007356539717875421, + -0.00030253134900704026, + -0.00014683544577565044, + -0.00022228369198273867, + -0.001650598249398172, + 0.0002927311579696834, + -0.00143563118763268, + 0.03084198758006096, + -0.007432155776768923, + -0.00028236035723239183, + 0.006017433945089579, + -0.011007187888026237, + -0.001266107545234263, + 0.0014901700196787715, + -0.0001800622121663764, + 0.002944394713267684, + -0.004211106337606907, + 0.0029597999528050423, + 0.002045023487880826, + 0.0013397098518908024, + -0.0012190865818411112, + 0.34349915385246277, + 0.0005632104002870619, + -0.0001262281439267099, + -0.00515326950699091, + 0.016240738332271576, + 0.01709030382335186, + -0.004175194539129734, + 0.039775289595127106, + 0.015226684510707855, + -0.0010229480685666203, + 0.0008072761120274663, + -0.004935584031045437, + -0.002123525831848383, + -0.014274083077907562, + 0.0013746818294748664, + 0.0014838266652077436, + 0.1302703619003296, + -0.00033616088330745697, + 0.0012919505825266242, + 0.00037177055492065847, + 0.019514480605721474, + 0.00022255218937061727, + 0.124249167740345, + -0.00040352059295400977, + -0.007652895525097847, + 0.0013010123511776328, + -0.0011253133416175842, + -0.007449474185705185, + 0.19224143028259277, + -0.003275118535384536, + -0.0005017912480980158, + -0.001007912098430097, + 3.091096004936844e-05, + -0.0008595998515374959, + 0.012359987013041973, + -0.0004041247011628002, + -0.004328910261392593, + 
0.3185553252696991, + 0.002330605871975422, + 0.0021182901691645384, + 0.0001405928487656638, + 0.2779357433319092, + 0.005738262087106705, + 0.0058898297138512135, + -0.0009689796715974808, + 0.00912561360746622, + 0.020675739273428917, + -0.03700518235564232, + 0.014263041317462921, + -0.04828466475009918, + 0.05834139883518219, + 0.0006514795240946114, + 0.26360899209976196, + 0.0004918567719869316, + -0.00261044898070395, + 0.08374208211898804, + 0.020676210522651672, + -0.003743582172319293, + 0.01085072010755539, + -0.001096583902835846, + 0.00047430366976186633, + 0.04818058758974075, + -0.4799128472805023, + 0.00018429107149131596, + 0.011861988343298435, + 0.06088569387793541, + 0.0008461413672193885, + 0.005328264087438583, + -0.011493473313748837, + -0.11350836604833603, + 0.006329597905278206, + 0.00031669469899497926, + -0.0011600167490541935, + -0.022669579833745956, + 0.004070379305630922, + 0.0073160636238753796, + -0.00834545586258173, + -0.27817651629447937, + 0.0036344374530017376 + ], + "yaxis": "y" + } + ], + "layout": { + "legend": { + "tracegroupgap": 0 }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Decomposing attention layers into patching in individual heads has already helped us localise the behaviour a lot. But we can understand it further by decomposing heads. An attention head consists of two semi-independent operations - calculating *where* to move information from and to (represented by the attention pattern and implemented via the QK-circuit) and calculating *what* information to move (represented by the value vectors and implemented by the OV circuit). We can disentangle which of these is important by patching in just the attention pattern *or* the value vectors. (See [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) or [my walkthrough video](https://www.youtube.com/watch?v=KV5gbOmHbjU) for more on this decomposition. 
If you're not familiar with the details of how attention is implemented, I recommend checking out [my clean transformer implementation](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/clean-transformer-demo/Clean_Transformer_Demo.ipynb#scrollTo=3Pb0NYbZ900e) to see how the code works))" + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + 
"colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 
0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 
0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First let's patch in the value vectors, to measure when figuring out what to move is important. . This has the same shape as z ([batch, pos, head_index, d_head]) so we can reuse the same hook." 
+ ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [], - "source": [ - "patched_head_v_diff = torch.zeros(\n", - " model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32\n", - ")\n", - "for layer in range(model.cfg.n_layers):\n", - " for head_index in range(model.cfg.n_heads):\n", - " hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)\n", - " patched_logits = model.run_with_hooks(\n", - " corrupted_tokens,\n", - " fwd_hooks=[(utils.get_act_name(\"v\", layer, \"attn\"), hook_fn)],\n", - " return_type=\"logits\",\n", - " )\n", - " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", - "\n", - " patched_head_v_diff[layer, head_index] = normalize_patched_logit_diff(\n", - " patched_logit_diff\n", - " )" + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": 
"white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can plot this as a heatmap and it's initially hard to interpret." 
- ] + "title": { + "text": "Scatter plot of output patching vs attention patching" }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - -0.00019892427371814847, - 0.005339574534446001, - 0.0006527548539452255, - 0.003504416672512889, - -0.00898387935012579, - 0.0034814265090972185, - -0.0008631910313852131, - -0.00003406582254683599, - 0.0005166929331608117, - 0.00044255363172851503, - -0.0039068968035280704, - -0.0001880836207419634 - ], - [ - -0.0004399022145662457, - -0.00044510437874123454, - -0.0000673597096465528, - 0.00007242763240355998, - -0.000036549441574607044, - -0.0019323208834975958, - -0.0001572397886775434, - 0.000016143509128596634, - 0.00020593880617525429, - 0.000336798548232764, - 0.0003515324497129768, - -0.0005669358652085066 - ], - [ - 0.00021013410878367722, - -0.0007199132232926786, - 0.0004868560063187033, - -0.0005974104860797524, - -0.0005921411793678999, - -0.0005443819100037217, - -0.000227552984142676, - -0.0004809825913980603, - 0.00020570388005580753, - 0.001183376181870699, - -0.0003574058646336198, - -0.0009104468626901507 - ], - [ - 0.0010395278222858906, - -0.00012042184971505776, - -0.00007762980385450646, - -0.0007275318494066596, - -0.001310007064603269, - -0.0023108376190066338, - 0.010987084358930588, - -0.000050712766096694395, - 0.00014314358122646809, - 0.00015069512301124632, - -0.00007957642083056271, - -0.000020238119759596884 - ], - [ - -0.0005373673629947007, - -0.0008137872209772468, - -0.00013334336108528078, - 0.030609702691435814, - -0.007185807917267084, - 0.000148916311445646, - 0.0013340713921934366, - -0.01142292469739914, - -0.0005336419562809169, - 0.0005126654868945479, - 0.00037344868178479373, - 0.0029547319281846285 - ], - [ - 0.00000822278525447473, - 0.000006477540864580078, - 0.0015973682748153806, - 0.00034015480196103454, - -0.0012577504385262728, - -0.00005450531898532063, - 0.0006331544718705118, - -0.00027081489679403603, - 0.00007427356467815116, - -0.006704355590045452, - 0.003175975289195776, - -0.0017300404142588377 - ], - [ - 
0.04863045737147331, - 0.015314852818846703, - -0.0004648726317100227, - -0.00011676354915834963, - -0.00004930314753437415, - -0.003952810075134039, - -0.01737578585743904, - -0.00015421917487401515, - 0.0012194222072139382, - -0.00018090127559844404, - -0.00042647725786082447, - 0.00012334177154116333 - ], - [ - -0.00002956846401502844, - -0.0013855225406587124, - -0.00012129446986364201, - 0.1332160234451294, - -0.00024490474606864154, - -0.007315828464925289, - 0.00033297244226559997, - -0.000795092957559973, - -0.007938209921121597, - 0.208413764834404, - -0.00019127204723190516, - -0.00020650937221944332 - ], - [ - -0.0020483459811657667, - -0.0003764357534237206, - -0.0033135139383375645, - -0.009666135534644127, - -0.00031723169377073646, - -0.005141589790582657, - 0.31717124581336975, - 0.0028427678626030684, - 0.0004723234742414206, - -0.0011529687326401472, - 0.2726709246635437, - -0.003175639547407627 - ], - [ - -0.00043929810635745525, - 0.000057089622714556754, - -0.0020629793871194124, - 0.020066648721694946, - -0.007871017791330814, - 0.011316264048218727, - 0.003056862158700824, - 0.06856372952461243, - -0.002747517777606845, - -0.009279227815568447, - 0.000506624230183661, - -0.0013159140944480896 - ], - [ - -0.012957162223756313, - -0.0030454176012426615, - -0.01792328804731369, - -0.0043589151464402676, - -0.0011521632550284266, - 0.0004999117809347808, - -0.0031131464056670666, - 0.019585633650422096, - 0.0000434632929682266, - 0.01297028549015522, - -0.007695754989981651, - -0.0009146086522378027 - ], - [ - 0.004100752994418144, - -0.020459463819861412, - -0.035875942558050156, - 0.014656225219368935, - 0.0008441276149824262, - 0.0017804511589929461, - -0.01804223284125328, - 0.003519016318023205, - 0.008253024891018867, - -0.0017665562918409705, - 0.044167667627334595, - 0.006474285386502743 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 
0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, 
- "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ 
- 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 
0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - 
"#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Logit Difference From 
Patched Head Value" - }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } - }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "imshow(\n", - " patched_head_v_diff,\n", - " title=\"Logit Difference From Patched Head Value\",\n", - " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", - ")" - ] + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Attention Patch" + } }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Output Patch" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow(\n", + " patched_head_attn_diff,\n", + " title=\"Logit Difference From Patched Head Pattern\",\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", + ")\n", + "head_labels = [\n", + " f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)\n", + "]\n", + "scatter(\n", + " x=utils.to_numpy(patched_head_attn_diff.flatten()),\n", + " y=utils.to_numpy(patched_head_z_diff.flatten()),\n", + " hover_name=head_labels,\n", + " xaxis=\"Attention Patch\",\n", + " yaxis=\"Output Patch\",\n", + " title=\"Scatter plot of output patching vs attention patching\",\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Consolidating Understanding\n", + "\n", + "OK, let's zoom out and reconsolidate. At a high-level, we find that all the action is on the second subject token until layer 7 and then transitions to the final token. And that attention layers matter a lot, MLP layers not so much (apart from MLP0, likely as an extended embedding).\n", + "\n", + "We've further localised important behaviour to several categories of heads. 
We've found 3 categories of heads that matter a lot - early heads (L5H5, L6H9, L3H0) whose output matters on the second subject and whose behaviour is determined by their attention patterns, mid-late heads (L8H6, L8H10, L7H9, L7H3) whose output matters on the final token and whose behaviour is determined by their value vectors, and late heads (L9H9, L10H7, L11H10) whose output matters on the final token and whose behaviour is determined by their attention patterns.\n", + "\n", + "A natural speculation is that early heads detect both that the second subject is a repeated token and *which* is repeated (ie the \" John\" token is repeated), middle heads compose with this and move this duplicated token information from the second subject token to the final token, and the late heads compose with this to *inhibit* their attention to the duplicated token, and then attend to the correct indirect object name and copy that directly to the logits." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualizing Attention Patterns\n", + "\n", + "We can validate this by looking at the attention patterns of these heads! Let's take the top 10 heads by output patching (in absolute value) and split it into early, middle and late.\n", + "\n", + "We see that middle heads attend from the final token to the second subject, and late heads attend from the final token to the indirect object, which is completely consistent with the above speculation! But weirdly, while *one* early head attends from the second subject to its first copy, the other two mysteriously attend to the word *after* the first copy." + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

Top Early Heads


\n", + "

Top Middle Heads


\n", + "

Top Late Heads


\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "top_k = 10\n", + "top_heads_by_output_patch = torch.topk(\n", + " patched_head_z_diff.abs().flatten(), k=top_k\n", + ").indices\n", + "first_mid_layer = 7\n", + "first_late_layer = 9\n", + "early_heads = top_heads_by_output_patch[\n", + " top_heads_by_output_patch < model.cfg.n_heads * first_mid_layer\n", + "]\n", + "mid_heads = top_heads_by_output_patch[\n", + " torch.logical_and(\n", + " model.cfg.n_heads * first_mid_layer <= top_heads_by_output_patch,\n", + " top_heads_by_output_patch < model.cfg.n_heads * first_late_layer,\n", + " )\n", + "]\n", + "late_heads = top_heads_by_output_patch[\n", + " model.cfg.n_heads * first_late_layer <= top_heads_by_output_patch\n", + "]\n", + "\n", + "early = visualize_attention_patterns(\n", + " early_heads, cache, tokens[0], title=f\"Top Early Heads\"\n", + ")\n", + "mid = visualize_attention_patterns(\n", + " mid_heads, cache, tokens[0], title=f\"Top Middle Heads\"\n", + ")\n", + "late = visualize_attention_patterns(\n", + " late_heads, cache, tokens[0], title=f\"Top Late Heads\"\n", + ")\n", + "\n", + "HTML(early + mid + late)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Comparing to the Paper\n", + "\n", + "We can now refer to the (far, far more rigorous and detailed) analysis in the paper to compare our results! Here's the diagram they give of their results. \n", + "\n", + "![IOI1](https://pbs.twimg.com/media/FghGkTAWAAAmkhm.jpg)\n", + "\n", + "(Head 1.2 in their notation is L1H2 in my notation etc. And note - in the [latest version of the paper](https://arxiv.org/pdf/2211.00593.pdf) they add 9.0 as a backup name mover, and remove 11.3)\n", + "\n", + "The heads form three categories corresponding to the early, middle and late categories we found and we did fairly well! 
Definitely not perfect, but with some fairly generic techniques and some a priori reasoning, we found the broad strokes of the circuit and what it looks like. We focused on the most important heads, so we didn't find all relevant heads in each category (especially not the heads in brackets, which are more minor), but this serves as a good base for doing more rigorous and involved analysis, especially for finding the *complete* circuit (ie all of the parts of the model which participate in this behaviour) rather than just a partial and suggestive circuit. Go check out [their paper](https://arxiv.org/abs/2211.00593) or [our interview](https://www.youtube.com/watch?v=gzwj0jWbvbo) to learn more about what they did and what they found!\n", + "\n", + "Breaking down their categories:\n", + "\n", + "* Early: The duplicate token heads, previous token heads and induction heads. These serve the purpose of detecting that the second subject is duplicated and which earlier name is the duplicate.\n", + " * We found a direct duplicate token head which behaves exactly as expected, L3H0. Heads L5H5 and L6H9 are induction heads, which explains why they don't attend directly to the earlier copy of John!\n", + " * Note that the duplicate token heads and induction heads do not compose with each other - both directly add to the S-Inhibition heads. The diagram is somewhat misleading.\n", + "* Middle: They call these S-Inhibition heads - they copy the information about the duplicate token from the second subject to the final token, and their output is used to *inhibit* the attention paid from the name movers to the first subject copy. We found all these heads, and had a decent guess for what they did.\n", + " * In either case they attend to the second subject, so the patch that mattered was their value vectors!\n", + "* Late: They call these name movers, and we found some of them. 
They attend from the final token to the indirect object name and copy that to the logits, using the S-Inhibition heads to inhibit attention to the first copy of the subject token.\n", + " * We did find their surprising result of *negative* name movers - name movers that inhibit the correct answer!\n", + " * They have an entire category of heads we missed called backup name movers - we'll get to these later.\n", + "\n", + "So, now, let's dig into the two anomalies we missed - induction heads and backup name mover heads" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Bonus: Exploring Anomalies" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Early Heads are Induction Heads(?!)\n", + "\n", + "A really weird observation is that some of the early heads detecting duplicated tokens are induction heads, not just direct duplicate token heads. This is very weird! What's up with that? \n", + "\n", + "First off, what's an induction head? An induction head is an important type of attention head that can detect and continue repeated sequences. It is the second head in a two head induction circuit, which looks for previous copies of the current token and attends to the token *after* it, and then copies that to the current position and predicts that it will come next. They're enough of a big deal that [we wrote a whole paper on them](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html).\n", + "\n", + "![Move image demo](https://pbs.twimg.com/media/FNWAzXjVEAEOGRe.jpg)\n", + "\n", + "Second, why is it surprising that they come up here? It's surprising because it feels like overkill. The model doesn't care about *what* token comes after the first copy of the subject, just that it's duplicated. And it already has simpler duplicate token heads. 
My best guess is that it just already had induction heads around and that, in addition to their main function, they *also* only activate on duplicated tokens. So it was useful to repurpose this existing machinery. \n", + "\n", + "This suggests that as we look for circuits in larger models life may get more and more complicated, as components in simpler circuits get repurposed and built upon. " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can verify that these are induction heads by running the model on repeated text and plotting the heads." + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "example_text = \"Research in mechanistic interpretability seeks to explain behaviors of machine learning models in terms of their internal components.\"\n", + "example_repeated_text = example_text + example_text\n", + "example_repeated_tokens = model.to_tokens(example_repeated_text, prepend_bos=True)\n", + "example_repeated_logits, example_repeated_cache = model.run_with_cache(\n", + " example_repeated_tokens\n", + ")\n", + "induction_head_labels = [81, 65]" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

Induction Heads


\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "code = visualize_attention_patterns(\n", + " induction_head_labels,\n", + " example_repeated_cache,\n", + " example_repeated_tokens,\n", + " title=\"Induction Heads\",\n", + " max_width=800,\n", + ")\n", + "HTML(code)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Implications\n", + "\n", + "One implication of this is that it's useful to categories heads according to whether they occur in\n", + "simpler circuits, so that as we look for more complex circuits we can easily look for them. This is\n", + "easy to do here! An interesting fact about induction heads is that they work on a sequence of\n", + "repeated random tokens - notable for being wildly off distribution from the natural language GPT-2\n", + "was trained on. Being able to predict a model's behaviour off distribution is a good mark of success\n", + "for mechanistic interpretability! This is a good sanity check for whether a head is an induction\n", + "head or not. \n", + "\n", + "We can characterise an induction head by just giving a sequence of random tokens repeated once, and\n", + "measuring the average attention paid from the second copy of a token to the token after the first\n", + "copy. At the same time, we can also measure the average attention paid from the second copy of a\n", + "token to the first copy of the token, which is the attention that the induction head would pay if it\n", + "were a duplicate token head, and the average attention paid to the previous token to find previous\n", + "token heads.\n", + "\n", + "Note that this is a superficial study of whether something is an induction head - we totally ignore\n", + "the question of whether it actually does boost the correct token or whether it composes with a\n", + "single previous head and how. 
In particular, we sometimes get anti-induction heads which suppress\n", + "the induction-y token (no clue why!), and this technique will find those too. But given the\n", + "previous rigorous analysis, we can be pretty confident that this picks up on some true signal about\n", + "induction heads." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
Technical Implementation Details \n", + "We can do this again by using hooks, this time just to access the attention patterns rather than to intervene on them. \n", + "\n", + "Our hook function acts on the attention pattern activation. This has the name\n", + "\"blocks.{layer}.{layer_type}.hook_{activation_name}\" in general, here it's\n", + "\"blocks.{layer}.attn.hook_pattern\". And it has shape [batch, head_index, query_pos, token_pos]. Our\n", + "hook function takes in the attention pattern activation, calculates the score for the relevant type\n", + "of head, and writes it to an external cache.\n", + "\n", + "We add in hooks using `model.run_with_hooks(tokens, fwd_hooks=[(names_filter, hook_fn)])` to\n", + "temporarily add in the hooks and run the model, getting the resulting output. Previously\n", + "names_filter was the name of the activation, but here it's a boolean function mapping activation\n", + "names to whether we want to hook them or not. Here it's just whether the name ends with hook_pattern.\n", + "hook_fn must take in the two inputs activation (the activation tensor) and hook (the HookPoint\n", + "object, which contains the name of the activation and some metadata such as the current layer).\n", + "\n", + "Internally our hooks use the function `tensor.diagonal`, this takes the diagonal between two\n", + "dimensions, and allows an arbitrary offset - offset by 1 to get previous tokens, seq_len to get\n", + "duplicate tokens (the distance to earlier copies) and seq_len-1 to get induction heads (the distance\n", + "to the token *after* earlier copies). Different offsets give a different length of output tensor,\n", + "and we can now just average to get a score in [0, 1] for each head\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[0.0390, 0.0000, 0.0310],\n", + " [0.1890, 0.1720, 0.0680],\n", + " [0.1570, 0.0210, 0.4820]])\n", + "tensor([[0.0030, 0.1320, 0.0050],\n", + " [0.0000, 0.0000, 0.0020],\n", + " [0.0020, 0.0090, 0.0000]])\n", + "tensor([[0.0040, 0.0000, 0.0040],\n", + " [0.0010, 0.0000, 0.0020],\n", + " [0.0020, 0.0090, 0.0020]])\n" + ] + } + ], + "source": [ + "seq_len = 100\n", + "batch_size = 2\n", + "\n", + "prev_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)\n", + "\n", + "\n", + "def prev_token_hook(pattern, hook):\n", + " layer = hook.layer()\n", + " diagonal = pattern.diagonal(offset=1, dim1=-1, dim2=-2)\n", + " # print(diagonal)\n", + " # print(pattern)\n", + " prev_token_scores[layer] = einops.reduce(\n", + " diagonal, \"batch head_index diagonal -> head_index\", \"mean\"\n", + " )\n", + "\n", + "\n", + "duplicate_token_scores = torch.zeros(\n", + " (model.cfg.n_layers, model.cfg.n_heads), device=device\n", + ")\n", + "\n", + "\n", + "def duplicate_token_hook(pattern, hook):\n", + " layer = hook.layer()\n", + " diagonal = pattern.diagonal(offset=seq_len, dim1=-1, dim2=-2)\n", + " duplicate_token_scores[layer] = einops.reduce(\n", + " diagonal, \"batch head_index diagonal -> head_index\", \"mean\"\n", + " )\n", + "\n", + "\n", + "induction_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)\n", + "\n", + "\n", + "def induction_hook(pattern, hook):\n", + " layer = hook.layer()\n", + " diagonal = pattern.diagonal(offset=seq_len - 1, dim1=-1, dim2=-2)\n", + " induction_scores[layer] = einops.reduce(\n", + " diagonal, \"batch head_index diagonal -> head_index\", \"mean\"\n", + " )\n", + "\n", + "\n", + "torch.manual_seed(0)\n", + "original_tokens = torch.randint(\n", + " 100, 20000, size=(batch_size, seq_len), device=\"cpu\"\n", + ").to(device)\n", + 
"repeated_tokens = einops.repeat(\n", + " original_tokens, \"batch seq_len -> batch (2 seq_len)\"\n", + ").to(device)\n", + "\n", + "pattern_filter = lambda act_name: act_name.endswith(\"hook_pattern\")\n", + "\n", + "loss = model.run_with_hooks(\n", + " repeated_tokens,\n", + " return_type=\"loss\",\n", + " fwd_hooks=[\n", + " (pattern_filter, prev_token_hook),\n", + " (pattern_filter, duplicate_token_hook),\n", + " (pattern_filter, induction_hook),\n", + " ],\n", + ")\n", + "print(torch.round(utils.get_corner(prev_token_scores).detach().cpu(), decimals=3))\n", + "print(torch.round(utils.get_corner(duplicate_token_scores).detach().cpu(), decimals=3))\n", + "print(torch.round(utils.get_corner(induction_scores).detach().cpu(), decimals=3))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now plot the head scores, and instantly see that the relevant early heads are induction heads or duplicate token heads (though also that there's a lot of induction heads that are *not* use - I have no idea why!). " + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "But it's very easy to interpret if we plot a scatter plot against patching head outputs. Here we see that the earlier heads (L5H5, L6H9, L3H0) and late name movers (L9H9, L10H7, L11H10) don't matter at all now, while the mid-late heads (L8H6, L8H10, L7H9) do. \n", - "\n", - "Meta lesson: Plot things early, often and in diverse ways as you explore a model's internals!" - ] + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.039069853723049164, + 0.0004489101702347398, + 0.03133601322770119, + 0.007519590202718973, + 0.034592196345329285, + 0.00036230171099305153, + 0.034512776881456375, + 0.19740213453769684, + 0.038447845727205276, + 0.04053792357444763, + 0.027628764510154724, + 0.02496313862502575 + ], + [ + 0.1890650987625122, + 0.17219914495944977, + 0.06807752698659897, + 0.04494515433907509, + 0.07908554375171661, + 0.03096739575266838, + 0.028282109647989273, + 0.03644327446818352, + 0.026936717331409454, + 0.018826229497790337, + 0.045100897550582886, + 0.0065726665779948235 + ], + [ + 0.15745528042316437, + 0.020724520087242126, + 0.4817989468574524, + 0.2991352379322052, + 0.10764895379543304, + 0.33004048466682434, + 0.0997551754117012, + 0.04926132410764694, + 0.25493940711021423, + 0.3606453835964203, + 0.1257179230451584, + 0.07931824028491974 + ], + [ + 0.005844001192599535, + 0.15787364542484283, + 0.4189082086086273, + 0.30129021406173706, + 0.014345049858093262, + 0.032344333827495575, + 0.3312888443470001, + 0.5285974144935608, + 0.34242063760757446, + 0.101837158203125, + 0.10516070574522018, + 0.2233113795518875 + ], + [ + 0.10626544803380966, + 0.11930850893259048, + 0.022880680859088898, + 0.22826944291591644, + 0.020003994926810265, + 0.10010036826133728, + 0.1739213615655899, + 0.17407020926475525, + 0.02587701380252838, + 0.10249985754489899, + 0.009514841251075268, + 0.9921423196792603 + ], + [ + 0.019766658544540405, + 0.00528325280174613, + 0.16648508608341217, + 0.12087740004062653, + 0.16500000655651093, + 0.00803269725292921, + 0.41770195960998535, + 0.025827765464782715, + 0.04802601411938667, + 0.016231779009103775, + 0.03110172413289547, + 0.024261215701699257 + ], + [ + 0.2172909826040268, + 0.039100028574466705, + 0.01804858259856701, + 0.059900715947151184, + 0.032934583723545074, + 0.0873451679944992, + 0.026895340532064438, + 0.0943947583436966, + 
0.49925994873046875, + 0.006240115500986576, + 0.027026718482375145, + 0.1278565675020218 + ], + [ + 0.2511657178401947, + 0.01330868061631918, + 0.006663354113698006, + 0.037430502474308014, + 0.02331537753343582, + 0.01740722358226776, + 0.022067422047257423, + 0.022141192108392715, + 0.04502448812127113, + 0.0208425372838974, + 0.008310739882290363, + 0.017167754471302032 + ], + [ + 0.020890623331069946, + 0.016537941992282867, + 0.02158307284116745, + 0.0150058064609766, + 0.02421221323311329, + 0.10198988765478134, + 0.029100384563207626, + 0.22793792188167572, + 0.02781485579907894, + 0.0179410632699728, + 0.024828944355249405, + 0.03806235268712044 + ], + [ + 0.02607586607336998, + 0.015407431870698929, + 0.02044427953660488, + 0.14558182656764984, + 0.01247025839984417, + 0.017151640728116035, + 0.013311829417943954, + 0.024451706558465958, + 0.018111787736415863, + 0.01319331955164671, + 0.0357399508357048, + 0.01879822090268135 + ], + [ + 0.02147812582552433, + 0.018419174477458, + 0.018183622509241104, + 0.02172141708433628, + 0.0315677747130394, + 0.034705750644207, + 0.017550116404891014, + 0.011417553760111332, + 0.01579565554857254, + 0.04592214897274971, + 0.01621554046869278, + 0.03039470687508583 + ], + [ + 0.03320508822798729, + 0.0175714660435915, + 0.015131079591810703, + 0.04148406535387039, + 0.015181189402937889, + 0.01758997142314911, + 0.015148494392633438, + 0.01767607219517231, + 0.06622709333896637, + 0.018451133742928505, + 0.01700744964182377, + 0.029749270528554916 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] }, - 
{ - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "hovertemplate": "%{hovertext}

Value Patch=%{x}
Output Patch=%{y}
Layer=%{marker.color}", - "hovertext": [ - "L0H0", - "L0H1", - "L0H2", - "L0H3", - "L0H4", - "L0H5", - "L0H6", - "L0H7", - "L0H8", - "L0H9", - "L0H10", - "L0H11", - "L1H0", - "L1H1", - "L1H2", - "L1H3", - "L1H4", - "L1H5", - "L1H6", - "L1H7", - "L1H8", - "L1H9", - "L1H10", - "L1H11", - "L2H0", - "L2H1", - "L2H2", - "L2H3", - "L2H4", - "L2H5", - "L2H6", - "L2H7", - "L2H8", - "L2H9", - "L2H10", - "L2H11", - "L3H0", - "L3H1", - "L3H2", - "L3H3", - "L3H4", - "L3H5", - "L3H6", - "L3H7", - "L3H8", - "L3H9", - "L3H10", - "L3H11", - "L4H0", - "L4H1", - "L4H2", - "L4H3", - "L4H4", - "L4H5", - "L4H6", - "L4H7", - "L4H8", - "L4H9", - "L4H10", - "L4H11", - "L5H0", - "L5H1", - "L5H2", - "L5H3", - "L5H4", - "L5H5", - "L5H6", - "L5H7", - "L5H8", - "L5H9", - "L5H10", - "L5H11", - "L6H0", - "L6H1", - "L6H2", - "L6H3", - "L6H4", - "L6H5", - "L6H6", - "L6H7", - "L6H8", - "L6H9", - "L6H10", - "L6H11", - "L7H0", - "L7H1", - "L7H2", - "L7H3", - "L7H4", - "L7H5", - "L7H6", - "L7H7", - "L7H8", - "L7H9", - "L7H10", - "L7H11", - "L8H0", - "L8H1", - "L8H2", - "L8H3", - "L8H4", - "L8H5", - "L8H6", - "L8H7", - "L8H8", - "L8H9", - "L8H10", - "L8H11", - "L9H0", - "L9H1", - "L9H2", - "L9H3", - "L9H4", - "L9H5", - "L9H6", - "L9H7", - "L9H8", - "L9H9", - "L9H10", - "L9H11", - "L10H0", - "L10H1", - "L10H2", - "L10H3", - "L10H4", - "L10H5", - "L10H6", - "L10H7", - "L10H8", - "L10H9", - "L10H10", - "L10H11", - "L11H0", - "L11H1", - "L11H2", - "L11H3", - "L11H4", - "L11H5", - "L11H6", - "L11H7", - "L11H8", - "L11H9", - "L11H10", - "L11H11" - ], - "legendgroup": "", - "marker": { - "color": [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 
6, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11 - ], - "coloraxis": "coloraxis", - "symbol": "circle" - }, - "mode": "markers", - "name": "", - "orientation": "v", - "showlegend": false, - "type": "scatter", - "x": [ - -0.00019892427371814847, - 0.005339574534446001, - 0.0006527548539452255, - 0.003504416672512889, - -0.00898387935012579, - 0.0034814265090972185, - -0.0008631910313852131, - -0.00003406582254683599, - 0.0005166929331608117, - 0.00044255363172851503, - -0.0039068968035280704, - -0.0001880836207419634, - -0.0004399022145662457, - -0.00044510437874123454, - -0.0000673597096465528, - 0.00007242763240355998, - -0.000036549441574607044, - -0.0019323208834975958, - -0.0001572397886775434, - 0.000016143509128596634, - 0.00020593880617525429, - 0.000336798548232764, - 0.0003515324497129768, - -0.0005669358652085066, - 0.00021013410878367722, - -0.0007199132232926786, - 0.0004868560063187033, - -0.0005974104860797524, - -0.0005921411793678999, - -0.0005443819100037217, - -0.000227552984142676, - -0.0004809825913980603, - 0.00020570388005580753, - 0.001183376181870699, - -0.0003574058646336198, - -0.0009104468626901507, - 0.0010395278222858906, - -0.00012042184971505776, - -0.00007762980385450646, - -0.0007275318494066596, - -0.001310007064603269, - -0.0023108376190066338, - 0.010987084358930588, - -0.000050712766096694395, - 0.00014314358122646809, - 0.00015069512301124632, - -0.00007957642083056271, - -0.000020238119759596884, - -0.0005373673629947007, - -0.0008137872209772468, - -0.00013334336108528078, - 0.030609702691435814, - -0.007185807917267084, - 0.000148916311445646, - 0.0013340713921934366, - -0.01142292469739914, - -0.0005336419562809169, - 0.0005126654868945479, - 
0.00037344868178479373, - 0.0029547319281846285, - 0.00000822278525447473, - 0.000006477540864580078, - 0.0015973682748153806, - 0.00034015480196103454, - -0.0012577504385262728, - -0.00005450531898532063, - 0.0006331544718705118, - -0.00027081489679403603, - 0.00007427356467815116, - -0.006704355590045452, - 0.003175975289195776, - -0.0017300404142588377, - 0.04863045737147331, - 0.015314852818846703, - -0.0004648726317100227, - -0.00011676354915834963, - -0.00004930314753437415, - -0.003952810075134039, - -0.01737578585743904, - -0.00015421917487401515, - 0.0012194222072139382, - -0.00018090127559844404, - -0.00042647725786082447, - 0.00012334177154116333, - -0.00002956846401502844, - -0.0013855225406587124, - -0.00012129446986364201, - 0.1332160234451294, - -0.00024490474606864154, - -0.007315828464925289, - 0.00033297244226559997, - -0.000795092957559973, - -0.007938209921121597, - 0.208413764834404, - -0.00019127204723190516, - -0.00020650937221944332, - -0.0020483459811657667, - -0.0003764357534237206, - -0.0033135139383375645, - -0.009666135534644127, - -0.00031723169377073646, - -0.005141589790582657, - 0.31717124581336975, - 0.0028427678626030684, - 0.0004723234742414206, - -0.0011529687326401472, - 0.2726709246635437, - -0.003175639547407627, - -0.00043929810635745525, - 0.000057089622714556754, - -0.0020629793871194124, - 0.020066648721694946, - -0.007871017791330814, - 0.011316264048218727, - 0.003056862158700824, - 0.06856372952461243, - -0.002747517777606845, - -0.009279227815568447, - 0.000506624230183661, - -0.0013159140944480896, - -0.012957162223756313, - -0.0030454176012426615, - -0.01792328804731369, - -0.0043589151464402676, - -0.0011521632550284266, - 0.0004999117809347808, - -0.0031131464056670666, - 0.019585633650422096, - 0.0000434632929682266, - 0.01297028549015522, - -0.007695754989981651, - -0.0009146086522378027, - 0.004100752994418144, - -0.020459463819861412, - -0.035875942558050156, - 0.014656225219368935, - 0.0008441276149824262, - 
0.0017804511589929461, - -0.01804223284125328, - 0.003519016318023205, - 0.008253024891018867, - -0.0017665562918409705, - 0.044167667627334595, - 0.006474285386502743 - ], - "xaxis": "x", - "y": [ - 0.0009487751522101462, - 0.016124747693538666, - 0.0018548924708738923, - 0.0034389030188322067, - -0.00982347596436739, - 0.011058605276048183, - -0.004063969012349844, - -0.0015792781487107277, - -0.0012082795146852732, - 0.003828897839412093, - -0.004256919026374817, - -0.0011422622483223677, - -0.0010771177476271987, - -0.00037898647133260965, - 0.0000025171791548928013, - -0.00026067905128002167, - -0.00014146546891424805, - 0.0038321535103023052, - -0.0004293300735298544, - -0.00142992555629462, - -0.0009228314156644046, - 0.0006944393389858305, - 0.00043302192352712154, - -0.0035714071709662676, - -0.0004967569257132709, - 0.0008057993836700916, - 0.0005424688570201397, - -0.0005309234256856143, - -0.0007159864180721343, - -0.0010389237431809306, - -0.0009490771917626262, - -0.00008649027586216107, - 0.0002766547549981624, - 0.0021084228064864874, - -0.0001975146442418918, - -0.0016405630158260465, - 0.1162627637386322, - 0.0002507446042727679, - -0.0014675153652206063, - -0.00039680811460129917, - 0.018962211906909943, - -0.00018764731066767126, - 0.011170871555805206, - -0.0013301445869728923, - -0.0007356539717875421, - -0.00030253134900704026, - -0.00014683544577565044, - -0.00022228369198273867, - -0.001650598249398172, - 0.0002927311579696834, - -0.00143563118763268, - 0.03084198758006096, - -0.007432155776768923, - -0.00028236035723239183, - 0.006017433945089579, - -0.011007187888026237, - -0.001266107545234263, - 0.0014901700196787715, - -0.0001800622121663764, - 0.002944394713267684, - -0.004211106337606907, - 0.0029597999528050423, - 0.002045023487880826, - 0.0013397098518908024, - -0.0012190865818411112, - 0.34349915385246277, - 0.0005632104002870619, - -0.0001262281439267099, - -0.00515326950699091, - 0.016240738332271576, - 0.01709030382335186, - 
-0.004175194539129734, - 0.039775289595127106, - 0.015226684510707855, - -0.0010229480685666203, - 0.0008072761120274663, - -0.004935584031045437, - -0.002123525831848383, - -0.014274083077907562, - 0.0013746818294748664, - 0.0014838266652077436, - 0.1302703619003296, - -0.00033616088330745697, - 0.0012919505825266242, - 0.00037177055492065847, - 0.019514480605721474, - 0.00022255218937061727, - 0.124249167740345, - -0.00040352059295400977, - -0.007652895525097847, - 0.0013010123511776328, - -0.0011253133416175842, - -0.007449474185705185, - 0.19224143028259277, - -0.003275118535384536, - -0.0005017912480980158, - -0.001007912098430097, - 0.00003091096004936844, - -0.0008595998515374959, - 0.012359987013041973, - -0.0004041247011628002, - -0.004328910261392593, - 0.3185553252696991, - 0.002330605871975422, - 0.0021182901691645384, - 0.0001405928487656638, - 0.2779357433319092, - 0.005738262087106705, - 0.0058898297138512135, - -0.0009689796715974808, - 0.00912561360746622, - 0.020675739273428917, - -0.03700518235564232, - 0.014263041317462921, - -0.04828466475009918, - 0.05834139883518219, - 0.0006514795240946114, - 0.26360899209976196, - 0.0004918567719869316, - -0.00261044898070395, - 0.08374208211898804, - 0.020676210522651672, - -0.003743582172319293, - 0.01085072010755539, - -0.001096583902835846, - 0.00047430366976186633, - 0.04818058758974075, - -0.4799128472805023, - 0.00018429107149131596, - 0.011861988343298435, - 0.06088569387793541, - 0.0008461413672193885, - 0.005328264087438583, - -0.011493473313748837, - -0.11350836604833603, - 0.006329597905278206, - 0.00031669469899497926, - -0.0011600167490541935, - -0.022669579833745956, - 0.004070379305630922, - 0.0073160636238753796, - -0.00834545586258173, - -0.27817651629447937, - 0.0036344374530017376 - ], - "yaxis": "y" - } - ], - "layout": { - "coloraxis": { - "colorbar": { - "title": { - "text": "Layer" - } - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 
0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "legend": { - "tracegroupgap": 0 - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - 
"colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - 
"#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - 
], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - 
"#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, 
- "title": { - "text": "Scatter plot of output patching vs value patching" - }, - "xaxis": { - "anchor": "y", - "domain": [ - 0, - 1 - ], - "range": [ - -0.5, - 0.5 - ], - "title": { - "text": "Value Patch" - } - }, - "yaxis": { - "anchor": "x", - "domain": [ - 0, - 1 - ], - "range": [ - -0.5, - 0.5 - ], - "title": { - "text": "Output Patch" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "head_labels = [\n", - " f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)\n", - "]\n", - "scatter(\n", - " x=utils.to_numpy(patched_head_v_diff.flatten()),\n", - " y=utils.to_numpy(patched_head_z_diff.flatten()),\n", - " xaxis=\"Value Patch\",\n", - " yaxis=\"Output Patch\",\n", - " caxis=\"Layer\",\n", - " hover_name=head_labels,\n", - " color=einops.repeat(\n", - " np.arange(model.cfg.n_layers), \"layer -> (layer head)\", head=model.cfg.n_heads\n", - " ),\n", - " range_x=(-0.5, 0.5),\n", - " range_y=(-0.5, 0.5),\n", - " title=\"Scatter plot of output patching vs value patching\",\n", - ")" + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + 
}, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 
0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + 
"colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "When we patch in attention patterns, we see the opposite effect - early and late heads matter a lot, 
middle heads don't. (In fact, the sum of value patching and pattern patching is approx the same as output patching)" + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [], - "source": [ - "def patch_head_pattern(\n", - " corrupted_head_pattern: Float[torch.Tensor, \"batch head_index query_pos d_head\"],\n", - " hook,\n", - " head_index,\n", - " clean_cache,\n", - "):\n", - " corrupted_head_pattern[:, head_index, :, :] = clean_cache[hook.name][\n", - " :, head_index, :, :\n", - " ]\n", - " return corrupted_head_pattern\n", - "\n", - "\n", - "patched_head_attn_diff = torch.zeros(\n", - " model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32\n", - ")\n", - "for layer in range(model.cfg.n_layers):\n", - " for head_index in range(model.cfg.n_heads):\n", - " hook_fn = partial(patch_head_pattern, head_index=head_index, clean_cache=cache)\n", - " patched_logits = model.run_with_hooks(\n", - " corrupted_tokens,\n", - " fwd_hooks=[(utils.get_act_name(\"attn\", layer, \"attn\"), hook_fn)],\n", - " return_type=\"logits\",\n", - " )\n", - " patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n", - "\n", - " patched_head_attn_diff[layer, head_index] = normalize_patched_logit_diff(\n", - " patched_logit_diff\n", - " )" + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 
0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 
+ }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0.0006401354330591857, - 0.005318799521774054, - 0.0011584057938307524, - -0.00005920405237702653, - -0.00106671336106956, - 0.005079298280179501, - -0.0030818663071841, - -0.0020521720871329308, - -0.0014405983965843916, - 0.003492669900879264, - -0.002568227471783757, - -0.0009168237447738647 - ], - [ - -0.0007600873941555619, - 0.0001683824957581237, - 0.00012246915139257908, - -0.00034914951538667083, - 0.000014901700524205808, - 0.0050090523436665535, - -0.0002975976967718452, - -0.0014448943547904491, - -0.001099134678952396, - 0.00047447148244827986, - 0.00005195457561057992, - -0.0034954219590872526 - ], - [ - -0.0007243098807521164, - 0.0017458146903663874, - -0.00015556166181340814, - 0.000057626621128292754, - -0.000097398049547337, - -0.0004238593974150717, - -0.0007917031762190163, - 0.00027222454082220793, - 0.00010179472155869007, - 0.0004223826399538666, - 0.00015193692524917424, - -0.0007437760941684246 - ], - [ - 0.11458104848861694, - 0.00021140948229003698, - -0.0009424989693798125, - 0.000429833511589095, - 0.02004295401275158, - 0.002104730810970068, - 0.00007628730963915586, - -0.001543701975606382, - -0.0008484235731884837, - -0.0005819046637043357, - 0.00011921360419364646, - -0.00001899631206470076 - ], - [ - -0.001127125695347786, - 0.001237143180333078, - -0.0012324444251134992, - -0.0005952289211563766, - -0.0007541133090853691, - -0.0005842540413141251, - 0.004813014063984156, - 0.00018187458044849336, - -0.0005361591465771198, - 0.0008579217828810215, - -0.0002985374303534627, - -0.00001144477391790133 - ], - [ - -0.004241178277879953, - 0.0029509058222174644, - 0.0005218615406192839, - 0.0009535074350424111, - 0.0001622070267330855, - 0.34350839257240295, - -0.0003052163519896567, - 0.00010293584637111053, - -0.005300541408360004, - 0.024864863604307175, - 0.014383262023329735, - -0.0023285921197384596 - ], - [ - -0.0023893399629741907, - 
-0.002172795357182622, - -0.00047614958020858467, - 0.00043188079143874347, - -0.004675475414842367, - 0.0018583494238555431, - -0.0026542814448475838, - 0.0014367386465892196, - 0.00030326974228955805, - 0.13043038547039032, - 0.00008813483145786449, - 0.0011766973184421659 - ], - [ - 0.00031847349600866437, - 0.02057075686752796, - 0.00031840638257563114, - -0.002512782346457243, - -0.0002628941729199141, - -0.00024718698114156723, - 0.0005524033331312239, - -0.00043131023994646966, - 0.00025715501396916807, - 0.008090951479971409, - -0.0030689111445099115, - -0.0004238593974150717 - ], - [ - 0.000976699055172503, - 0.00039251212729141116, - 0.0017534669023007154, - 0.022595642134547234, - -0.000044805787183577195, - 0.00014220383309293538, - 0.009584981948137283, - -0.0003157213795930147, - 0.0015271222218871117, - 0.0011813960736617446, - -0.010774029418826103, - 0.00936581939458847 - ], - [ - 0.006314125377684832, - -0.0010949057759717107, - 0.011662023141980171, - 0.0013481340138241649, - -0.02918696030974388, - 0.0038333951961249113, - -0.04409456625580788, - -0.005032042507082224, - 0.00482167350128293, - 0.2766477167606354, - -0.00003164933150401339, - -0.0006618167390115559 - ], - [ - 0.0953889712691307, - 0.02506939135491848, - 0.014239178970456123, - 0.014754998497664928, - 0.00009890835644910112, - -0.00008977938705356792, - 0.05082912743091583, - -0.5051022171974182, - 0.00014696970174554735, - -0.0016026375815272331, - 0.06883199512958527, - 0.002327115274965763 - ], - [ - 0.0013425961369648576, - 0.009630928747355938, - -0.07776415348052979, - -0.007728713098913431, - -0.0005726079107262194, - -0.002957182005047798, - -0.0049475994892418385, - 0.00045916702947579324, - -0.0006328188464976847, - -0.006520198658108711, - -0.3204910457134247, - -0.002473111730068922 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 
0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - 
"#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - 
"#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - 
], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - 
"#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Logit Difference From Patched Head Pattern" - }, - 
"xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } - }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "hovertemplate": "%{hovertext}

Attention Patch=%{x}
Output Patch=%{y}", - "hovertext": [ - "L0H0", - "L0H1", - "L0H2", - "L0H3", - "L0H4", - "L0H5", - "L0H6", - "L0H7", - "L0H8", - "L0H9", - "L0H10", - "L0H11", - "L1H0", - "L1H1", - "L1H2", - "L1H3", - "L1H4", - "L1H5", - "L1H6", - "L1H7", - "L1H8", - "L1H9", - "L1H10", - "L1H11", - "L2H0", - "L2H1", - "L2H2", - "L2H3", - "L2H4", - "L2H5", - "L2H6", - "L2H7", - "L2H8", - "L2H9", - "L2H10", - "L2H11", - "L3H0", - "L3H1", - "L3H2", - "L3H3", - "L3H4", - "L3H5", - "L3H6", - "L3H7", - "L3H8", - "L3H9", - "L3H10", - "L3H11", - "L4H0", - "L4H1", - "L4H2", - "L4H3", - "L4H4", - "L4H5", - "L4H6", - "L4H7", - "L4H8", - "L4H9", - "L4H10", - "L4H11", - "L5H0", - "L5H1", - "L5H2", - "L5H3", - "L5H4", - "L5H5", - "L5H6", - "L5H7", - "L5H8", - "L5H9", - "L5H10", - "L5H11", - "L6H0", - "L6H1", - "L6H2", - "L6H3", - "L6H4", - "L6H5", - "L6H6", - "L6H7", - "L6H8", - "L6H9", - "L6H10", - "L6H11", - "L7H0", - "L7H1", - "L7H2", - "L7H3", - "L7H4", - "L7H5", - "L7H6", - "L7H7", - "L7H8", - "L7H9", - "L7H10", - "L7H11", - "L8H0", - "L8H1", - "L8H2", - "L8H3", - "L8H4", - "L8H5", - "L8H6", - "L8H7", - "L8H8", - "L8H9", - "L8H10", - "L8H11", - "L9H0", - "L9H1", - "L9H2", - "L9H3", - "L9H4", - "L9H5", - "L9H6", - "L9H7", - "L9H8", - "L9H9", - "L9H10", - "L9H11", - "L10H0", - "L10H1", - "L10H2", - "L10H3", - "L10H4", - "L10H5", - "L10H6", - "L10H7", - "L10H8", - "L10H9", - "L10H10", - "L10H11", - "L11H0", - "L11H1", - "L11H2", - "L11H3", - "L11H4", - "L11H5", - "L11H6", - "L11H7", - "L11H8", - "L11H9", - "L11H10", - "L11H11" - ], - "legendgroup": "", - "marker": { - "color": "#636efa", - "symbol": "circle" - }, - "mode": "markers", - "name": "", - "orientation": "v", - "showlegend": false, - "type": "scatter", - "x": [ - 0.0006401354330591857, - 0.005318799521774054, - 0.0011584057938307524, - -0.00005920405237702653, - -0.00106671336106956, - 0.005079298280179501, - -0.0030818663071841, - -0.0020521720871329308, - -0.0014405983965843916, - 0.003492669900879264, - -0.002568227471783757, - 
-0.0009168237447738647, - -0.0007600873941555619, - 0.0001683824957581237, - 0.00012246915139257908, - -0.00034914951538667083, - 0.000014901700524205808, - 0.0050090523436665535, - -0.0002975976967718452, - -0.0014448943547904491, - -0.001099134678952396, - 0.00047447148244827986, - 0.00005195457561057992, - -0.0034954219590872526, - -0.0007243098807521164, - 0.0017458146903663874, - -0.00015556166181340814, - 0.000057626621128292754, - -0.000097398049547337, - -0.0004238593974150717, - -0.0007917031762190163, - 0.00027222454082220793, - 0.00010179472155869007, - 0.0004223826399538666, - 0.00015193692524917424, - -0.0007437760941684246, - 0.11458104848861694, - 0.00021140948229003698, - -0.0009424989693798125, - 0.000429833511589095, - 0.02004295401275158, - 0.002104730810970068, - 0.00007628730963915586, - -0.001543701975606382, - -0.0008484235731884837, - -0.0005819046637043357, - 0.00011921360419364646, - -0.00001899631206470076, - -0.001127125695347786, - 0.001237143180333078, - -0.0012324444251134992, - -0.0005952289211563766, - -0.0007541133090853691, - -0.0005842540413141251, - 0.004813014063984156, - 0.00018187458044849336, - -0.0005361591465771198, - 0.0008579217828810215, - -0.0002985374303534627, - -0.00001144477391790133, - -0.004241178277879953, - 0.0029509058222174644, - 0.0005218615406192839, - 0.0009535074350424111, - 0.0001622070267330855, - 0.34350839257240295, - -0.0003052163519896567, - 0.00010293584637111053, - -0.005300541408360004, - 0.024864863604307175, - 0.014383262023329735, - -0.0023285921197384596, - -0.0023893399629741907, - -0.002172795357182622, - -0.00047614958020858467, - 0.00043188079143874347, - -0.004675475414842367, - 0.0018583494238555431, - -0.0026542814448475838, - 0.0014367386465892196, - 0.00030326974228955805, - 0.13043038547039032, - 0.00008813483145786449, - 0.0011766973184421659, - 0.00031847349600866437, - 0.02057075686752796, - 0.00031840638257563114, - -0.002512782346457243, - -0.0002628941729199141, - 
-0.00024718698114156723, - 0.0005524033331312239, - -0.00043131023994646966, - 0.00025715501396916807, - 0.008090951479971409, - -0.0030689111445099115, - -0.0004238593974150717, - 0.000976699055172503, - 0.00039251212729141116, - 0.0017534669023007154, - 0.022595642134547234, - -0.000044805787183577195, - 0.00014220383309293538, - 0.009584981948137283, - -0.0003157213795930147, - 0.0015271222218871117, - 0.0011813960736617446, - -0.010774029418826103, - 0.00936581939458847, - 0.006314125377684832, - -0.0010949057759717107, - 0.011662023141980171, - 0.0013481340138241649, - -0.02918696030974388, - 0.0038333951961249113, - -0.04409456625580788, - -0.005032042507082224, - 0.00482167350128293, - 0.2766477167606354, - -0.00003164933150401339, - -0.0006618167390115559, - 0.0953889712691307, - 0.02506939135491848, - 0.014239178970456123, - 0.014754998497664928, - 0.00009890835644910112, - -0.00008977938705356792, - 0.05082912743091583, - -0.5051022171974182, - 0.00014696970174554735, - -0.0016026375815272331, - 0.06883199512958527, - 0.002327115274965763, - 0.0013425961369648576, - 0.009630928747355938, - -0.07776415348052979, - -0.007728713098913431, - -0.0005726079107262194, - -0.002957182005047798, - -0.0049475994892418385, - 0.00045916702947579324, - -0.0006328188464976847, - -0.006520198658108711, - -0.3204910457134247, - -0.002473111730068922 - ], - "xaxis": "x", - "y": [ - 0.0009487751522101462, - 0.016124747693538666, - 0.0018548924708738923, - 0.0034389030188322067, - -0.00982347596436739, - 0.011058605276048183, - -0.004063969012349844, - -0.0015792781487107277, - -0.0012082795146852732, - 0.003828897839412093, - -0.004256919026374817, - -0.0011422622483223677, - -0.0010771177476271987, - -0.00037898647133260965, - 0.0000025171791548928013, - -0.00026067905128002167, - -0.00014146546891424805, - 0.0038321535103023052, - -0.0004293300735298544, - -0.00142992555629462, - -0.0009228314156644046, - 0.0006944393389858305, - 0.00043302192352712154, - 
-0.0035714071709662676, - -0.0004967569257132709, - 0.0008057993836700916, - 0.0005424688570201397, - -0.0005309234256856143, - -0.0007159864180721343, - -0.0010389237431809306, - -0.0009490771917626262, - -0.00008649027586216107, - 0.0002766547549981624, - 0.0021084228064864874, - -0.0001975146442418918, - -0.0016405630158260465, - 0.1162627637386322, - 0.0002507446042727679, - -0.0014675153652206063, - -0.00039680811460129917, - 0.018962211906909943, - -0.00018764731066767126, - 0.011170871555805206, - -0.0013301445869728923, - -0.0007356539717875421, - -0.00030253134900704026, - -0.00014683544577565044, - -0.00022228369198273867, - -0.001650598249398172, - 0.0002927311579696834, - -0.00143563118763268, - 0.03084198758006096, - -0.007432155776768923, - -0.00028236035723239183, - 0.006017433945089579, - -0.011007187888026237, - -0.001266107545234263, - 0.0014901700196787715, - -0.0001800622121663764, - 0.002944394713267684, - -0.004211106337606907, - 0.0029597999528050423, - 0.002045023487880826, - 0.0013397098518908024, - -0.0012190865818411112, - 0.34349915385246277, - 0.0005632104002870619, - -0.0001262281439267099, - -0.00515326950699091, - 0.016240738332271576, - 0.01709030382335186, - -0.004175194539129734, - 0.039775289595127106, - 0.015226684510707855, - -0.0010229480685666203, - 0.0008072761120274663, - -0.004935584031045437, - -0.002123525831848383, - -0.014274083077907562, - 0.0013746818294748664, - 0.0014838266652077436, - 0.1302703619003296, - -0.00033616088330745697, - 0.0012919505825266242, - 0.00037177055492065847, - 0.019514480605721474, - 0.00022255218937061727, - 0.124249167740345, - -0.00040352059295400977, - -0.007652895525097847, - 0.0013010123511776328, - -0.0011253133416175842, - -0.007449474185705185, - 0.19224143028259277, - -0.003275118535384536, - -0.0005017912480980158, - -0.001007912098430097, - 0.00003091096004936844, - -0.0008595998515374959, - 0.012359987013041973, - -0.0004041247011628002, - -0.004328910261392593, - 
0.3185553252696991, - 0.002330605871975422, - 0.0021182901691645384, - 0.0001405928487656638, - 0.2779357433319092, - 0.005738262087106705, - 0.0058898297138512135, - -0.0009689796715974808, - 0.00912561360746622, - 0.020675739273428917, - -0.03700518235564232, - 0.014263041317462921, - -0.04828466475009918, - 0.05834139883518219, - 0.0006514795240946114, - 0.26360899209976196, - 0.0004918567719869316, - -0.00261044898070395, - 0.08374208211898804, - 0.020676210522651672, - -0.003743582172319293, - 0.01085072010755539, - -0.001096583902835846, - 0.00047430366976186633, - 0.04818058758974075, - -0.4799128472805023, - 0.00018429107149131596, - 0.011861988343298435, - 0.06088569387793541, - 0.0008461413672193885, - 0.005328264087438583, - -0.011493473313748837, - -0.11350836604833603, - 0.006329597905278206, - 0.00031669469899497926, - -0.0011600167490541935, - -0.022669579833745956, - 0.004070379305630922, - 0.0073160636238753796, - -0.00834545586258173, - -0.27817651629447937, - 0.0036344374530017376 - ], - "yaxis": "y" - } - ], - "layout": { - "legend": { - "tracegroupgap": 0 - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - 
"outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ 
- [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": 
[ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" 
- ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - 
"color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Scatter plot of output patching vs attention patching" - }, - "xaxis": { - "anchor": "y", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Attention Patch" - } - }, - "yaxis": { - "anchor": "x", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Output Patch" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "imshow(\n", - " patched_head_attn_diff,\n", - " title=\"Logit Difference From Patched Head Pattern\",\n", - " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", - ")\n", - "head_labels = [\n", - " f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)\n", - "]\n", - "scatter(\n", - " x=utils.to_numpy(patched_head_attn_diff.flatten()),\n", - " y=utils.to_numpy(patched_head_z_diff.flatten()),\n", - " hover_name=head_labels,\n", - " xaxis=\"Attention Patch\",\n", - " yaxis=\"Output Patch\",\n", - " title=\"Scatter plot of output patching vs attention patching\",\n", - ")" - ] + "title": { + "text": "Previous Token Scores" }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Consolidating Understanding\n", - "\n", - "OK, let's zoom out and reconsolidate. 
At a high-level, we find that all the action is on the second subject token until layer 7 and then transitions to the final token. And that attention layers matter a lot, MLP layers not so much (apart from MLP0, likely as an extended embedding).\n", - "\n", - "We've further localised important behaviour to several categories of heads. We've found 3 categories of heads that matter a lot - early heads (L5H5, L6H9, L3H0) whose output matters on the second subject and whose behaviour is determined by their attention patterns, mid-late heads (L8H6, L8H10, L7H9, L7H3) whose output matters on the final token and whose behaviour is determined by their value vectors, and late heads (L9H9, L10H7, L11H10) whose output matters on the final token and whose behaviour is determined by their attention patterns.\n", - "\n", - "A natural speculation is that early heads detect both that the second subject is a repeated token and *which* is repeated (ie the \" John\" token is repeated), middle heads compose with this and move this duplicated token information from the second subject token to the final token, and the late heads compose with this to *inhibit* their attention to the duplicated token, and then attend to the correct indirect object name and copy that directly to the logits." - ] + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Visualizing Attention Patterns\n", - "\n", - "We can validate this by looking at the attention patterns of these heads! 
Let's take the top 10 heads by output patching (in absolute value) and split it into early, middle and late.\n", - "\n", - "We see that middle heads attend from the final token to the second subject, and late heads attend from the final token to the indirect object, which is completely consistent with the above speculation! But weirdly, while *one* early head attends from the second subject to its first copy, the other two mysteriously attend to the word *after* the first copy." - ] + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.0031923248898237944, + 0.13236315548419952, + 0.005006915424019098, + 1.0427449524286203e-05, + 0.0013110184809193015, + 0.7034568786621094, + 0.00426204688847065, + 0.00016496369789820164, + 0.002474633976817131, + 0.0008572910446673632, + 0.01889149099588394, + 0.008690938353538513 + ], + [ + 0.0002916341181844473, + 0.00013782267342321575, + 0.0015036173863336444, + 0.005392482969909906, + 0.0018583914497867227, + 0.009062949568033218, + 0.012414448894560337, + 0.0022405502386391163, + 0.005135662388056517, + 0.005220627877861261, + 0.005546474829316139, + 0.02975049614906311 + ], + [ + 0.0024816279765218496, + 0.009442180395126343, + 0.0003456332196947187, + 0.0002591445227153599, + 0.0052116685546934605, + 0.000570951378904283, + 0.0015209749108180404, + 0.006313100922852755, + 0.001560864970088005, + 0.0004215767839923501, + 0.00015359291865024716, + 0.005160381551831961 + ], + [ + 0.6775657534599304, + 0.002840448170900345, + 0.0007841526530683041, + 0.00471264636144042, + 0.006322895642369986, + 0.006206681486219168, + 0.0005474805948324502, + 0.00037829449865967035, + 0.0020155368838459253, + 0.007952751591801643, + 0.003576782764866948, + 0.002608788898214698 + ], + [ + 0.00860405620187521, + 0.0070286463014781475, + 0.007598803844302893, + 0.003442801535129547, + 0.016561277210712433, + 0.0059797209687530994, + 0.004869826138019562, + 0.0007624455611221492, + 0.006062133703380823, + 0.007536627352237701, + 0.012022900395095348, + 1.055422134237094e-12 + ], + [ + 0.00950299296528101, + 0.00856209360063076, + 0.004162600729614496, + 0.003008665982633829, + 0.006847422569990158, + 0.004358117934316397, + 0.007669268175959587, + 0.009584215469658375, + 0.0076188258826732635, + 0.0043280418030917645, + 0.041402824223041534, + 0.00976183544844389 + ], + [ + 0.004456141032278538, + 0.008873268961906433, + 0.007405205629765987, + 0.0062249391339719296, + 
0.00731915095821023, + 0.005623893812298775, + 0.017349667847156525, + 0.005529467947781086, + 0.002920132130384445, + 0.008636755868792534, + 0.006222263444215059, + 0.00835894700139761 + ], + [ + 0.003699858672916889, + 0.04107949137687683, + 0.04148268699645996, + 0.009313640184700489, + 0.009097025729715824, + 0.008774377405643463, + 0.007298537530004978, + 0.023312218487262726, + 0.008843323215842247, + 0.00987986009567976, + 0.017598601058125496, + 0.006039854139089584 + ], + [ + 0.008986304514110088, + 0.028667239472270012, + 0.008891218341886997, + 0.010114557109773159, + 0.009737391024827957, + 0.007611637003719807, + 0.009763265959918499, + 0.005155472084879875, + 0.009276345372200012, + 0.011895839124917984, + 0.010411946102976799, + 0.007498950231820345 + ], + [ + 0.024409977719187737, + 0.011438451707363129, + 0.02003096230328083, + 0.0051185814663767815, + 0.015081286430358887, + 0.012334450148046017, + 0.015452565625309944, + 0.008602450601756573, + 0.014702522195875645, + 0.020766200497746468, + 0.009192758239805698, + 0.005703347735106945 + ], + [ + 0.017897022888064384, + 0.013280633836984634, + 0.006755237001925707, + 0.012744844891130924, + 0.008020960725843906, + 0.007722244597971439, + 0.017341373488307, + 0.0074546560645103455, + 0.007832515984773636, + 0.00825214572250843, + 0.013642766512930393, + 0.012807483784854412 + ], + [ + 0.004923742264509201, + 0.007951060310006142, + 0.007947920821607113, + 0.004564082249999046, + 0.010363400913774967, + 0.009582078084349632, + 0.0102877551689744, + 0.00832072552293539, + 0.0025700009427964687, + 0.012810997664928436, + 0.008063871413469315, + 0.006558285094797611 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, 
+ "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "

Top Early Heads


\n", - "

Top Middle Heads


\n", - "

Top Late Heads


\n", - "
" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "top_k = 10\n", - "top_heads_by_output_patch = torch.topk(\n", - " patched_head_z_diff.abs().flatten(), k=top_k\n", - ").indices\n", - "first_mid_layer = 7\n", - "first_late_layer = 9\n", - "early_heads = top_heads_by_output_patch[\n", - " top_heads_by_output_patch < model.cfg.n_heads * first_mid_layer\n", - "]\n", - "mid_heads = top_heads_by_output_patch[\n", - " torch.logical_and(\n", - " model.cfg.n_heads * first_mid_layer <= top_heads_by_output_patch,\n", - " top_heads_by_output_patch < model.cfg.n_heads * first_late_layer,\n", - " )\n", - "]\n", - "late_heads = top_heads_by_output_patch[\n", - " model.cfg.n_heads * first_late_layer <= top_heads_by_output_patch\n", - "]\n", - "\n", - "early = visualize_attention_patterns(\n", - " early_heads, cache, tokens[0], title=f\"Top Early Heads\"\n", - ")\n", - "mid = visualize_attention_patterns(\n", - " mid_heads, cache, tokens[0], title=f\"Top Middle Heads\"\n", - ")\n", - "late = visualize_attention_patterns(\n", - " late_heads, cache, tokens[0], title=f\"Top Late Heads\"\n", - ")\n", - "\n", - "HTML(early + mid + late)" + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + 
"minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + 
}, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": 
[ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "attachments": {}, - "cell_type": 
"markdown", - "metadata": {}, - "source": [ - "### Comparing to the Paper\n", - "\n", - "We can now refer to the (far, far more rigorous and detailed) analysis in the paper to compare our results! Here's the diagram they give of their results. \n", - "\n", - "![IOI1](https://pbs.twimg.com/media/FghGkTAWAAAmkhm.jpg)\n", - "\n", - "(Head 1.2 in their notation is L1H2 in my notation etc. And note - in the [latest version of the paper](https://arxiv.org/pdf/2211.00593.pdf) they add 9.0 as a backup name mover, and remove 11.3)\n", - "\n", - "The heads form three categories corresponding to the early, middle and late categories we found and we did fairly well! Definitely not perfect, but with some fairly generic techniques and some a priori reasoning, we found the broad strokes of the circuit and what it looks like. We focused on the most important heads, so we didn't find all relevant heads in each category (especially not the heads in brackets, which are more minor), but this serves as a good base for doing more rigorous and involved analysis, especially for finding the *complete* circuit (ie all of the parts of the model which participate in this behaviour) rather than just a partial and suggestive circuit. Go check out [their paper](https://arxiv.org/abs/2211.00593) or [our interview](https://www.youtube.com/watch?v=gzwj0jWbvbo) to learn more about what they did and what they found!\n", - "\n", - "Breaking down their categories:\n", - "\n", - "* Early: The duplicate token heads, previous token heads and induction heads. These serve the purpose of detecting that the second subject is duplicated and which earlier name is the duplicate.\n", - " * We found a direct duplicate token head which behaves exactly as expected, L3H0. 
Heads L5H0 and L6H9 are induction heads, which explains why they don't attend directly to the earlier copy of John!\n", - " * Note that the duplicate token heads and induction heads do not compose with each other - both directly add to the S-Inhibition heads. The diagram is somewhat misleading.\n", - "* Middle: They call these S-Inhibition heads - they copy the information about the duplicate token from the second subject to the to token, and their output is used to *inhibit* the attention paid from the name movers to the first subject copy. We found all these heads, and had a decent guess for what they did.\n", - " * In either case they attend to the second subject, so the patch that mattered was their value vectors!\n", - "* Late: They call these name movers, and we found some of them. They attend from the final token to the indirect object name and copy that to the logits, using the S-Inhibition heads to inhibit attention to the first copy of the subject token.\n", - " * We did find their surprising result of *negative* name movers - name movers that inhibit the correct answer!\n", - " * They have an entire category of heads we missed called backup name movers - we'll get to these later.\n", - "\n", - "So, now, let's dig into the two anomalies we missed - induction heads and backup name mover heads" + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Bonus: Exploring Anomalies" + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 
0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": 
"white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Early Heads are Induction Heads(?!)\n", - "\n", - "A really weird observation is that some of the early heads detecting duplicated tokens are induction heads, not just direct duplicate token heads. This is very weird! What's up with that? \n", - "\n", - "First off, what's an induction head? An induction head is an important type of attention head that can detect and continue repeated sequences. It is the second head in a two head induction circuit, which looks for previous copies of the current token and attends to the token *after* it, and then copies that to the current position and predicts that it will come next. They're enough of a big deal that [we wrote a whole paper on them](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html).\n", - "\n", - "![Move image demo](https://pbs.twimg.com/media/FNWAzXjVEAEOGRe.jpg)\n", - "\n", - "Second, why is it surprising that they come up here? It's surprising because it feels like overkill. The model doesn't care about *what* token comes after the first copy of the subject, just that it's duplicated. And it already has simpler duplicate token heads. My best guess is that it just already had induction heads around and that, in addition to their main function, they *also* only activate on duplicated tokens. So it was useful to repurpose this existing machinery. \n", - "\n", - "This suggests that as we look for circuits in larger models life may get more and more complicated, as components in simpler circuits get repurposed and built upon. 
" - ] + "title": { + "text": "Duplicate Token Scores" }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can verify that these are induction heads by running the model on repeated text and plotting the heads." - ] + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [], - "source": [ - "example_text = \"Research in mechanistic interpretability seeks to explain behaviors of machine learning models in terms of their internal components.\"\n", - "example_repeated_text = example_text + example_text\n", - "example_repeated_tokens = model.to_tokens(example_repeated_text, prepend_bos=True)\n", - "example_repeated_logits, example_repeated_cache = model.run_with_cache(\n", - " example_repeated_tokens\n", - ")\n", - "induction_head_labels = [81, 65]" - ] + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.004035575315356255, + 3.85937346436549e-05, + 0.003946058917790651, + 1.7428524756724073e-07, + 5.9896130551351234e-05, + 4.0836803236743435e-05, + 0.0035017586778849363, + 0.00024610417312942445, + 0.0031679815147072077, + 0.0030104012694209814, + 0.002093541668727994, + 0.008525434881448746 + ], + [ + 0.000526473973877728, + 0.00015670718858018517, + 0.001507942914031446, + 0.005595325026661158, + 0.0018401180859655142, + 0.0038875630125403404, + 0.005349153187125921, + 0.004649169277399778, + 0.005880181211978197, + 0.007283917628228664, + 0.005552186165004969, + 0.00012677280756179243 + ], + [ + 0.0022015420254319906, + 0.008784863166511059, + 0.002159146359190345, + 0.0010447809472680092, + 0.005142326466739178, + 0.002251626690849662, + 0.0008376616751775146, + 0.006352409720420837, + 0.002618127502501011, + 0.0010309136705473065, + 0.00015219187480397522, + 0.005351166240870953 + ], + [ + 0.007752244360744953, + 0.0030915802344679832, + 0.001362923881970346, + 0.004341960418969393, + 0.011233060620725155, + 0.006535551976412535, + 0.000906877510715276, + 0.0006078600417822599, + 0.002819513902068138, + 0.005254077725112438, + 0.004195652436465025, + 0.00255418848246336 + ], + [ + 0.007342735771089792, + 0.004788339603692293, + 0.007458819076418877, + 0.0033073313534259796, + 0.007871866226196289, + 0.004219769034534693, + 0.004172054585069418, + 0.0005154653917998075, + 0.008124975487589836, + 0.0068268910981714725, + 0.008085492067039013, + 3.761376626831847e-11 + ], + [ + 0.4337766170501709, + 0.9306095838546753, + 0.006382268853485584, + 0.0034730439074337482, + 0.005500996019691229, + 0.9255973696708679, + 0.00538142304867506, + 0.007857315242290497, + 0.00863779615610838, + 0.01576443389058113, + 0.012188379652798176, + 0.008265726268291473 + ], + [ + 0.002507298020645976, + 0.008432027883827686, + 0.008623305708169937, + 0.007653353735804558, + 
0.01105806790292263, + 0.005525435321033001, + 0.017205175012350082, + 0.004794349893927574, + 0.0040976013988256454, + 0.9257788062095642, + 0.020375633612275124, + 0.006313954945653677 + ], + [ + 0.005555536597967148, + 0.18942977488040924, + 0.8509925007820129, + 0.008273146115243435, + 0.008239664137363434, + 0.00864996388554573, + 0.02832852303981781, + 0.08996275067329407, + 0.006617339327931404, + 0.009413909167051315, + 0.9037814736366272, + 0.03037159889936447 + ], + [ + 0.00735454261302948, + 0.3791317641735077, + 0.005602709017693996, + 0.025401461869478226, + 0.008504674769937992, + 0.00623108958825469, + 0.11892436444759369, + 0.005114651285111904, + 0.013350939378142357, + 0.01576736941933632, + 0.025843923911452293, + 0.008429747074842453 + ], + [ + 0.2398916333913803, + 0.14378757774829865, + 0.09330663084983826, + 0.005819779820740223, + 0.07744801044464111, + 0.01644793339073658, + 0.4442836344242096, + 0.011141352355480194, + 0.03619001433253288, + 0.472646564245224, + 0.00803996529430151, + 0.030953049659729004 + ], + [ + 0.3606555163860321, + 0.48201146721839905, + 0.022851115092635155, + 0.1264195442199707, + 0.04125598818063736, + 0.0072374604642391205, + 0.2877156138420105, + 0.3897320628166199, + 0.030060900375247, + 0.006112942937761545, + 0.1655488908290863, + 0.22245149314403534 + ], + [ + 0.007408542558550835, + 0.033737149089574814, + 0.02041277289390564, + 0.002755412133410573, + 0.02518630214035511, + 0.07808877527713776, + 0.033082809299230576, + 0.046440087258815765, + 0.0032543439883738756, + 0.2744256258010864, + 0.3800230026245117, + 0.009483495727181435 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + 
"rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "

Induction Heads


\n", - "
" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "code = visualize_attention_patterns(\n", - " induction_head_labels,\n", - " example_repeated_cache,\n", - " example_repeated_tokens,\n", - " title=\"Induction Heads\",\n", - " max_width=800,\n", - ")\n", - "HTML(code)" + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + 
"outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 
0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 
0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Implications\n", - "\n", - "One implication of this is that it's useful to categories heads according to whether they occur in\n", - "simpler circuits, so that as we look for more complex circuits we can easily look for them. This is\n", - "easy to do here! An interesting fact about induction heads is that they work on a sequence of\n", - "repeated random tokens - notable for being wildly off distribution from the natural language GPT-2\n", - "was trained on. Being able to predict a model's behaviour off distribution is a good mark of success\n", - "for mechanistic interpretability! This is a good sanity check for whether a head is an induction\n", - "head or not. 
\n", - "\n", - "We can characterise an induction head by just giving a sequence of random tokens repeated once, and\n", - "measuring the average attention paid from the second copy of a token to the token after the first\n", - "copy. At the same time, we can also measure the average attention paid from the second copy of a\n", - "token to the first copy of the token, which is the attention that the induction head would pay if it\n", - "were a duplicate token head, and the average attention paid to the previous token to find previous\n", - "token heads.\n", - "\n", - "Note that this is a superficial study of whether something is an induction head - we totally ignore\n", - "the question of whether it actually does boost the correct token or whether it composes with a\n", - "single previous head and how. In particular, we sometimes get anti-induction heads which suppress\n", - "the induction-y token (no clue why!), and this technique will find those too . But given the\n", - "previous rigorous analysis, we can be pretty confident that this picks up on some true signal about\n", - "induction heads." + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "
Technical Implementation Details \n", - "We can do this again by using hooks, this time just to access the attention patterns rather than to intervene on them. \n", - "\n", - "Our hook function acts on the attention pattern activation. This has the name\n", - "\"blocks.{layer}.{layer_type}.hook_{activation_name}\" in general, here it's\n", - "\"blocks.{layer}.attn.hook_attn\". And it has shape [batch, head_index, query_pos, token_pos]. Our\n", - "hook function takes in the attention pattern activation, calculates the score for the relevant type\n", - "of head, and write it to an external cache.\n", - "\n", - "We add in hooks using `model.run_with_hooks(tokens, fwd_hooks=[(names_filter, hook_fn)])` to\n", - "temporarily add in the hooks and run the model, getting the resulting output. Previously\n", - "names_filter was the name of the activation, but here it's a boolean function mapping activation\n", - "names to whether we want to hook them or not. Here it's just whether the name ends with hook_attn.\n", - "hook_fn must take in the two inputs activation (the activation tensor) and hook (the HookPoint\n", - "object, which contains the name of the activation and some metadata such as the current layer).\n", - "\n", - "Internally our hooks use the function `tensor.diagonal`, this takes the diagonal between two\n", - "dimensions, and allows an arbitrary offset - offset by 1 to get previous tokens, seq_len to get\n", - "duplicate tokens (the distance to earlier copies) and seq_len-1 to get induction heads (the distance\n", - "to the token *after* earlier copies). Different offsets give a different length of output tensor,\n", - "and we can now just average to get a score in [0, 1] for each head\n", - "
" + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + 
"gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[0.0390, 0.0000, 0.0310],\n", - " [0.1890, 0.1720, 0.0680],\n", - " [0.1570, 0.0210, 0.4820]])\n", - "tensor([[0.0030, 0.1320, 0.0050],\n", - " [0.0000, 0.0000, 0.0020],\n", - " [0.0020, 0.0090, 0.0000]])\n", - "tensor([[0.0040, 0.0000, 0.0040],\n", - " [0.0010, 0.0000, 0.0020],\n", - " [0.0020, 0.0090, 0.0020]])\n" - ] - } - ], - "source": [ - "seq_len = 100\n", - "batch_size = 2\n", - "\n", - "prev_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)\n", - "\n", - "\n", - "def prev_token_hook(pattern, hook):\n", - " layer = hook.layer()\n", - " diagonal = pattern.diagonal(offset=1, dim1=-1, dim2=-2)\n", - " # print(diagonal)\n", - " # print(pattern)\n", - " prev_token_scores[layer] = einops.reduce(\n", - " diagonal, \"batch head_index diagonal -> head_index\", \"mean\"\n", - " )\n", - "\n", - "\n", - "duplicate_token_scores = torch.zeros(\n", - " (model.cfg.n_layers, model.cfg.n_heads), device=device\n", - ")\n", - "\n", - "\n", - "def duplicate_token_hook(pattern, hook):\n", - " layer = hook.layer()\n", - " diagonal = pattern.diagonal(offset=seq_len, dim1=-1, dim2=-2)\n", - " duplicate_token_scores[layer] = einops.reduce(\n", - " diagonal, \"batch head_index diagonal -> head_index\", \"mean\"\n", - " )\n", - "\n", - "\n", - "induction_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), 
device=device)\n", - "\n", - "\n", - "def induction_hook(pattern, hook):\n", - " layer = hook.layer()\n", - " diagonal = pattern.diagonal(offset=seq_len - 1, dim1=-1, dim2=-2)\n", - " induction_scores[layer] = einops.reduce(\n", - " diagonal, \"batch head_index diagonal -> head_index\", \"mean\"\n", - " )\n", - "\n", - "\n", - "torch.manual_seed(0)\n", - "original_tokens = torch.randint(\n", - " 100, 20000, size=(batch_size, seq_len), device=\"cpu\"\n", - ").to(device)\n", - "repeated_tokens = einops.repeat(\n", - " original_tokens, \"batch seq_len -> batch (2 seq_len)\"\n", - ").to(device)\n", - "\n", - "pattern_filter = lambda act_name: act_name.endswith(\"hook_pattern\")\n", - "\n", - "loss = model.run_with_hooks(\n", - " repeated_tokens,\n", - " return_type=\"loss\",\n", - " fwd_hooks=[\n", - " (pattern_filter, prev_token_hook),\n", - " (pattern_filter, duplicate_token_hook),\n", - " (pattern_filter, induction_hook),\n", - " ],\n", - ")\n", - "print(torch.round(utils.get_corner(prev_token_scores).detach().cpu(), decimals=3))\n", - "print(torch.round(utils.get_corner(duplicate_token_scores).detach().cpu(), decimals=3))\n", - "print(torch.round(utils.get_corner(induction_scores).detach().cpu(), decimals=3))" - ] + "title": { + "text": "Induction Head Scores" }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can now plot the head scores, and instantly see that the relevant early heads are induction heads or duplicate token heads (though also that there's a lot of induction heads that are *not* use - I have no idea why!). 
" - ] + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow(\n", + " prev_token_scores, labels={\"x\": \"Head\", \"y\": \"Layer\"}, title=\"Previous Token Scores\"\n", + ")\n", + "imshow(\n", + " duplicate_token_scores,\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", + " title=\"Duplicate Token Scores\",\n", + ")\n", + "imshow(\n", + " induction_scores, labels={\"x\": \"Head\", \"y\": \"Layer\"}, title=\"Induction Head Scores\"\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The above suggests that it would be a useful bit of infrastructure to have a \"wiki\" for the heads of a model, giving their scores according to some metrics re head functions, like the ones we've seen here. TransformerLens makes this easy to make, as just changing the name input to `HookedTransformer.from_pretrained` gives a different model but in the same architecture, so the same code should work. If you want to make this, I'd love to see it! 
\n", + "\n", + "As a proof of concept, [I made a mosaic of all induction heads across the 40 models then in TransformerLens](https://www.neelnanda.io/mosaic).\n", + "\n", + "![induction scores as proof of concept](https://firebasestorage.googleapis.com/v0/b/firescript-577a2.appspot.com/o/imgs%2Fapp%2FNeelNanda%2F5vtuFmdzt_.png?alt=media&token=4d613de4-9d14-48d6-ba9d-e591c562d429)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Backup Name Mover Heads" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Another fascinating anomaly is that of the **backup name mover heads**. A standard technique to apply when interpreting model internals is ablations, or knock-out. If we run the model but intervene to set a specific head to zero, what happens? If the model is robust to this intervention, then naively we can be confident that the head is not doing anything important, and conversely if the model is much worse at the task this suggests that head was important. There are several conceptual flaws with this approach, making the evidence only suggestive, eg that the average output of the head may be far from zero and so the knockout may send it far from expected activations, breaking internals on *any* task. But it's still an easy technique to apply to give some data.\n", + "\n", + "But a wild finding in the paper is that models have **built in redundancy**. If we knock out one of the name movers, then there are some backup name movers in later layers that *change their behaviour* and do (some of) the job of the original name mover head. This means that naive knock-out will significantly underestimate the importance of the name movers.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's test this! 
Let's ablate the most important name mover (head L9H9) on just the final token using a custom ablation hook and then cache all new activations and compare performance. We focus on the final position because we want to specifically ablate the direct logit effect. When we do this, we see that naively, removing the top name mover should reduce the logit diff massively, from 3.55 to 0.57. **But actually, it only goes down to 2.92!**\n", + "\n", + "
Implementation Details \n", + "Ablating heads is really easy in TransformerLens! We can just define a hook on the z activation in the relevant attention layer (recall, z is the mixed values, and comes immediately before multiplying by the output weights $W_O$). z has a head_index axis, so we can set the component for the relevant head and for position -1 to zero, and return it. (Technically we could just edit in place without returning it, but by convention we always return an edited activation). \n", + "\n", + "We now want to compare all internal activations with a hook, which is hard to do with the nice `run_with_hooks` API. So we can directly access the hook on the z activation with `model.blocks[layer].attn.hook_z` and call its `add_hook` method. This adds in the hook to the *global state* of the model. We can now use run_with_cache, and don't need to care about the global state, because run_with_cache internally adds a bunch of caching hooks, and then removes all hooks after the run, *including* the previously added ablation hook. This can be disabled with the reset_hooks_end flag, but here it's useful! \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top Name Mover to ablate: L9H9\n", + "Original logit diff: 3.55\n", + "Post ablation logit diff: 2.92\n", + "Direct Logit Attribution of top name mover head: 2.99\n", + "Naive prediction of post ablation logit diff: 0.57\n" + ] + } + ], + "source": [ + "top_name_mover = per_head_logit_diffs.flatten().argmax().item()\n", + "top_name_mover_layer = top_name_mover // model.cfg.n_heads\n", + "top_name_mover_head = top_name_mover % model.cfg.n_heads\n", + "print(f\"Top Name Mover to ablate: L{top_name_mover_layer}H{top_name_mover_head}\")\n", + "\n", + "\n", + "def ablate_top_head_hook(z: Float[torch.Tensor, \"batch pos head_index d_head\"], hook):\n", + " z[:, -1, top_name_mover_head, :] = 0\n", + " return z\n", + "\n", + "\n", + "# Adds a hook into global model state\n", + "model.blocks[top_name_mover_layer].attn.hook_z.add_hook(ablate_top_head_hook)\n", + "# Runs the model, temporarily adds caching hooks and then removes *all* hooks after running, including the ablation hook.\n", + "ablated_logits, ablated_cache = model.run_with_cache(tokens)\n", + "print(f\"Original logit diff: {original_average_logit_diff:.2f}\")\n", + "print(\n", + " f\"Post ablation logit diff: {logits_to_ave_logit_diff(ablated_logits, answer_tokens).item():.2f}\"\n", + ")\n", + "print(\n", + " f\"Direct Logit Attribution of top name mover head: {per_head_logit_diffs.flatten()[top_name_mover].item():.2f}\"\n", + ")\n", + "print(\n", + " f\"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item():.2f}\"\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So what's up with this? As before, we can look at the direct logit attribution of each head to see what's going on. 
It's easiest to interpret if plotted as a scatter plot against the initial per head logit difference.\n", + "\n", + "And we can see a *really* big difference in a few heads! (Hover to see labels) In particular the negative name mover L10H7 decreases its negative effect a lot, adding +1 to the logit diff, and the backup name mover L10H10 adjusts its effect to be more positive, adding +0.8 to the logit diff (with several other marginal changes). (And obviously the ablated head has gone down to zero!)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tried to stack head results when they weren't cached. Computing head results now\n" + ] + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0.039069853723049164, - 0.0004489101702347398, - 0.03133601322770119, - 0.007519590202718973, - 0.034592196345329285, - 0.00036230171099305153, - 0.034512776881456375, - 0.19740213453769684, - 0.038447845727205276, - 0.04053792357444763, - 0.027628764510154724, - 0.02496313862502575 - ], - [ - 0.1890650987625122, - 0.17219914495944977, - 0.06807752698659897, - 0.04494515433907509, - 0.07908554375171661, - 0.03096739575266838, - 0.028282109647989273, - 0.03644327446818352, - 0.026936717331409454, - 0.018826229497790337, - 0.045100897550582886, - 0.0065726665779948235 - ], - [ - 0.15745528042316437, - 0.020724520087242126, - 0.4817989468574524, - 0.2991352379322052, - 0.10764895379543304, - 0.33004048466682434, - 0.0997551754117012, - 0.04926132410764694, - 0.25493940711021423, - 0.3606453835964203, - 0.1257179230451584, - 0.07931824028491974 - ], - [ - 0.005844001192599535, - 0.15787364542484283, - 0.4189082086086273, - 0.30129021406173706, - 0.014345049858093262, - 0.032344333827495575, - 0.3312888443470001, - 0.5285974144935608, - 0.34242063760757446, - 0.101837158203125, - 0.10516070574522018, - 0.2233113795518875 - ], - [ - 0.10626544803380966, - 0.11930850893259048, - 0.022880680859088898, - 0.22826944291591644, - 0.020003994926810265, - 0.10010036826133728, - 0.1739213615655899, - 0.17407020926475525, - 0.02587701380252838, - 0.10249985754489899, - 0.009514841251075268, - 0.9921423196792603 - ], - [ - 0.019766658544540405, - 0.00528325280174613, - 0.16648508608341217, - 0.12087740004062653, - 0.16500000655651093, - 0.00803269725292921, - 0.41770195960998535, - 0.025827765464782715, - 0.04802601411938667, - 0.016231779009103775, - 0.03110172413289547, - 0.024261215701699257 - ], - [ - 0.2172909826040268, - 0.039100028574466705, - 0.01804858259856701, - 0.059900715947151184, - 0.032934583723545074, - 0.0873451679944992, - 0.026895340532064438, - 0.0943947583436966, - 
0.49925994873046875, - 0.006240115500986576, - 0.027026718482375145, - 0.1278565675020218 - ], - [ - 0.2511657178401947, - 0.01330868061631918, - 0.006663354113698006, - 0.037430502474308014, - 0.02331537753343582, - 0.01740722358226776, - 0.022067422047257423, - 0.022141192108392715, - 0.04502448812127113, - 0.0208425372838974, - 0.008310739882290363, - 0.017167754471302032 - ], - [ - 0.020890623331069946, - 0.016537941992282867, - 0.02158307284116745, - 0.0150058064609766, - 0.02421221323311329, - 0.10198988765478134, - 0.029100384563207626, - 0.22793792188167572, - 0.02781485579907894, - 0.0179410632699728, - 0.024828944355249405, - 0.03806235268712044 - ], - [ - 0.02607586607336998, - 0.015407431870698929, - 0.02044427953660488, - 0.14558182656764984, - 0.01247025839984417, - 0.017151640728116035, - 0.013311829417943954, - 0.024451706558465958, - 0.018111787736415863, - 0.01319331955164671, - 0.0357399508357048, - 0.01879822090268135 - ], - [ - 0.02147812582552433, - 0.018419174477458, - 0.018183622509241104, - 0.02172141708433628, - 0.0315677747130394, - 0.034705750644207, - 0.017550116404891014, - 0.011417553760111332, - 0.01579565554857254, - 0.04592214897274971, - 0.01621554046869278, - 0.03039470687508583 - ], - [ - 0.03320508822798729, - 0.0175714660435915, - 0.015131079591810703, - 0.04148406535387039, - 0.015181189402937889, - 0.01758997142314911, - 0.015148494392633438, - 0.01767607219517231, - 0.06622709333896637, - 0.018451133742928505, - 0.01700744964182377, - 0.029749270528554916 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] - }, 
- "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 
0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - 
"parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": 
{ - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": 
"closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Previous Token Scores" - }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } - }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - 
}, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0.0031923248898237944, - 0.13236315548419952, - 0.005006915424019098, - 0.000010427449524286203, - 0.0013110184809193015, - 0.7034568786621094, - 0.00426204688847065, - 0.00016496369789820164, - 0.002474633976817131, - 0.0008572910446673632, - 0.01889149099588394, - 0.008690938353538513 - ], - [ - 0.0002916341181844473, - 0.00013782267342321575, - 0.0015036173863336444, - 0.005392482969909906, - 0.0018583914497867227, - 0.009062949568033218, - 0.012414448894560337, - 0.0022405502386391163, - 0.005135662388056517, - 0.005220627877861261, - 0.005546474829316139, - 0.02975049614906311 - ], - [ - 0.0024816279765218496, - 0.009442180395126343, - 0.0003456332196947187, - 0.0002591445227153599, - 0.0052116685546934605, - 0.000570951378904283, - 0.0015209749108180404, - 0.006313100922852755, - 0.001560864970088005, - 0.0004215767839923501, - 0.00015359291865024716, - 0.005160381551831961 - ], - [ - 0.6775657534599304, - 0.002840448170900345, - 0.0007841526530683041, - 0.00471264636144042, - 0.006322895642369986, - 0.006206681486219168, - 0.0005474805948324502, - 0.00037829449865967035, - 0.0020155368838459253, - 0.007952751591801643, - 0.003576782764866948, - 0.002608788898214698 - ], - [ - 0.00860405620187521, - 0.0070286463014781475, - 0.007598803844302893, - 0.003442801535129547, - 0.016561277210712433, - 0.0059797209687530994, - 0.004869826138019562, - 0.0007624455611221492, - 0.006062133703380823, - 0.007536627352237701, - 0.012022900395095348, - 1.055422134237094e-12 - ], - [ - 0.00950299296528101, - 0.00856209360063076, - 0.004162600729614496, - 0.003008665982633829, - 0.006847422569990158, - 0.004358117934316397, - 0.007669268175959587, - 0.009584215469658375, - 0.0076188258826732635, - 0.0043280418030917645, - 0.041402824223041534, - 0.00976183544844389 - ], - [ - 0.004456141032278538, - 0.008873268961906433, - 0.007405205629765987, - 0.0062249391339719296, - 
0.00731915095821023, - 0.005623893812298775, - 0.017349667847156525, - 0.005529467947781086, - 0.002920132130384445, - 0.008636755868792534, - 0.006222263444215059, - 0.00835894700139761 - ], - [ - 0.003699858672916889, - 0.04107949137687683, - 0.04148268699645996, - 0.009313640184700489, - 0.009097025729715824, - 0.008774377405643463, - 0.007298537530004978, - 0.023312218487262726, - 0.008843323215842247, - 0.00987986009567976, - 0.017598601058125496, - 0.006039854139089584 - ], - [ - 0.008986304514110088, - 0.028667239472270012, - 0.008891218341886997, - 0.010114557109773159, - 0.009737391024827957, - 0.007611637003719807, - 0.009763265959918499, - 0.005155472084879875, - 0.009276345372200012, - 0.011895839124917984, - 0.010411946102976799, - 0.007498950231820345 - ], - [ - 0.024409977719187737, - 0.011438451707363129, - 0.02003096230328083, - 0.0051185814663767815, - 0.015081286430358887, - 0.012334450148046017, - 0.015452565625309944, - 0.008602450601756573, - 0.014702522195875645, - 0.020766200497746468, - 0.009192758239805698, - 0.005703347735106945 - ], - [ - 0.017897022888064384, - 0.013280633836984634, - 0.006755237001925707, - 0.012744844891130924, - 0.008020960725843906, - 0.007722244597971439, - 0.017341373488307, - 0.0074546560645103455, - 0.007832515984773636, - 0.00825214572250843, - 0.013642766512930393, - 0.012807483784854412 - ], - [ - 0.004923742264509201, - 0.007951060310006142, - 0.007947920821607113, - 0.004564082249999046, - 0.010363400913774967, - 0.009582078084349632, - 0.0102877551689744, - 0.00832072552293539, - 0.0025700009427964687, - 0.012810997664928436, - 0.008063871413469315, - 0.006558285094797611 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, 
- "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 
0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": 
"histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 
0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": 
"#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Duplicate Token Scores" - }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } - }, - "yaxis": { - 
"anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0.004035575315356255, - 0.0000385937346436549, - 0.003946058917790651, - 1.7428524756724073e-7, - 0.000059896130551351234, - 0.000040836803236743435, - 0.0035017586778849363, - 0.00024610417312942445, - 0.0031679815147072077, - 0.0030104012694209814, - 0.002093541668727994, - 0.008525434881448746 - ], - [ - 0.000526473973877728, - 0.00015670718858018517, - 0.001507942914031446, - 0.005595325026661158, - 0.0018401180859655142, - 0.0038875630125403404, - 0.005349153187125921, - 0.004649169277399778, - 0.005880181211978197, - 0.007283917628228664, - 0.005552186165004969, - 0.00012677280756179243 - ], - [ - 0.0022015420254319906, - 0.008784863166511059, - 0.002159146359190345, - 0.0010447809472680092, - 0.005142326466739178, - 0.002251626690849662, - 0.0008376616751775146, - 0.006352409720420837, - 0.002618127502501011, - 0.0010309136705473065, - 0.00015219187480397522, - 0.005351166240870953 - ], - [ - 0.007752244360744953, - 0.0030915802344679832, - 0.001362923881970346, - 0.004341960418969393, - 0.011233060620725155, - 0.006535551976412535, - 0.000906877510715276, - 0.0006078600417822599, - 0.002819513902068138, - 0.005254077725112438, - 0.004195652436465025, - 0.00255418848246336 - ], - [ - 0.007342735771089792, - 0.004788339603692293, - 0.007458819076418877, - 0.0033073313534259796, - 0.007871866226196289, - 0.004219769034534693, - 0.004172054585069418, - 0.0005154653917998075, - 0.008124975487589836, - 0.0068268910981714725, - 0.008085492067039013, - 3.761376626831847e-11 - ], - [ - 0.4337766170501709, - 0.9306095838546753, - 0.006382268853485584, - 0.0034730439074337482, - 0.005500996019691229, - 0.9255973696708679, - 0.00538142304867506, - 0.007857315242290497, - 0.00863779615610838, - 0.01576443389058113, - 0.012188379652798176, - 0.008265726268291473 - ], - [ - 0.002507298020645976, - 0.008432027883827686, - 0.008623305708169937, - 0.007653353735804558, - 
0.01105806790292263, - 0.005525435321033001, - 0.017205175012350082, - 0.004794349893927574, - 0.0040976013988256454, - 0.9257788062095642, - 0.020375633612275124, - 0.006313954945653677 - ], - [ - 0.005555536597967148, - 0.18942977488040924, - 0.8509925007820129, - 0.008273146115243435, - 0.008239664137363434, - 0.00864996388554573, - 0.02832852303981781, - 0.08996275067329407, - 0.006617339327931404, - 0.009413909167051315, - 0.9037814736366272, - 0.03037159889936447 - ], - [ - 0.00735454261302948, - 0.3791317641735077, - 0.005602709017693996, - 0.025401461869478226, - 0.008504674769937992, - 0.00623108958825469, - 0.11892436444759369, - 0.005114651285111904, - 0.013350939378142357, - 0.01576736941933632, - 0.025843923911452293, - 0.008429747074842453 - ], - [ - 0.2398916333913803, - 0.14378757774829865, - 0.09330663084983826, - 0.005819779820740223, - 0.07744801044464111, - 0.01644793339073658, - 0.4442836344242096, - 0.011141352355480194, - 0.03619001433253288, - 0.472646564245224, - 0.00803996529430151, - 0.030953049659729004 - ], - [ - 0.3606555163860321, - 0.48201146721839905, - 0.022851115092635155, - 0.1264195442199707, - 0.04125598818063736, - 0.0072374604642391205, - 0.2877156138420105, - 0.3897320628166199, - 0.030060900375247, - 0.006112942937761545, - 0.1655488908290863, - 0.22245149314403534 - ], - [ - 0.007408542558550835, - 0.033737149089574814, - 0.02041277289390564, - 0.002755412133410573, - 0.02518630214035511, - 0.07808877527713776, - 0.033082809299230576, - 0.046440087258815765, - 0.0032543439883738756, - 0.2744256258010864, - 0.3800230026245117, - 0.009483495727181435 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - 
"rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 
0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - 
"colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - 
"#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - 
"subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Induction Head Scores" - }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": "Head" - } - }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": 
"domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "imshow(\n", - " prev_token_scores, labels={\"x\": \"Head\", \"y\": \"Layer\"}, title=\"Previous Token Scores\"\n", - ")\n", - "imshow(\n", - " duplicate_token_scores,\n", - " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", - " title=\"Duplicate Token Scores\",\n", - ")\n", - "imshow(\n", - " induction_scores, labels={\"x\": \"Head\", \"y\": \"Layer\"}, title=\"Induction Head Scores\"\n", - ")" - ] + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.002156503964215517, + -0.0004650682385545224, + 0.00024167183437384665, + 0.0002806585980579257, + -0.0004162999684922397, + -0.0004892416181974113, + -0.002620948012918234, + -0.002935677068307996, + 0.00042561208829283714, + 0.0005418329383246601, + 0.00023754138965159655, + -7.48957390896976e-05 + ], + [ + -0.000658505829051137, + 0.0004060641804244369, + -0.0009330413886345923, + 0.0008937822422012687, + -0.0009785268921405077, + -0.000533820129930973, + -0.0027988189831376076, + -0.004214101936668158, + 0.002578593324869871, + 0.0024506838526576757, + 0.0005351756699383259, + 0.0012349633034318686 + ], + [ + 0.0009405204327777028, + -0.0011168691562488675, + -0.0011541967978700995, + -0.0015697095077484846, + -0.0005699327448382974, + 0.001451514894142747, + 0.002439911477267742, + 0.003158293664455414, + 0.000923738582059741, + -0.003578126197680831, + -0.0010650777257978916, + -0.0003558753523975611 + ], + [ + -0.0005624951445497572, + -1.1960582924075425e-05, + 0.0011531109921634197, + 0.0007360265008173883, + 0.0016493839211761951, + 0.0008800819050520658, + -0.0006905529880896211, + -0.003031972097232938, + 0.0008080147090367973, + 0.00010368914809077978, + -0.0005807994166389108, + -0.0011067037703469396 + ], + [ + -0.0026375530287623405, + 0.0002691895351745188, + -0.0016417437000200152, + -0.003406986128538847, + 0.0017449699807912111, + 0.00046454701805487275, + -0.0007899806369096041, + 0.0018328562146052718, + -0.00086324627045542, + -0.0003978293389081955, + 0.0007879206677898765, + -0.00012048585631418973 + ], + [ + 0.0008688560919836164, + 0.0009473530226387084, + -0.0022812988609075546, + -0.0011803123634308577, + 0.0002407809515716508, + -0.0004318578285165131, + -0.0003728170122485608, + -0.000738416681997478, + 0.0008113418589346111, + -0.00040444196201860905, + -0.007074396125972271, + 0.003946478478610516 + ], + [ + -0.014917617663741112, + 
-0.0022801742888987064, + 0.0022679336834698915, + -8.302251808345318e-05, + -0.004980948753654957, + 0.0027670026756823063, + 0.006266288459300995, + -0.003485947148874402, + -0.0013348984066396952, + -0.0017918883822858334, + -0.0012231896398589015, + 0.00040514359716326 + ], + [ + -0.0002460568503011018, + -0.005790225230157375, + -0.0004975841729901731, + 0.142182856798172, + -0.0014961492270231247, + -0.019006317481398582, + 0.003133433870971203, + -0.001858205534517765, + -0.011305196210741997, + 0.1922595500946045, + -0.0011892566690221429, + -0.0010282933944836259 + ], + [ + -0.0038003993686288595, + -0.0008570950012654066, + -0.013956742361187935, + 0.00828910805284977, + 0.004315475933253765, + -0.009073829278349876, + -0.08315148949623108, + 0.0034569751005619764, + -0.01805492490530014, + 0.002178061753511429, + 0.29780513048171997, + 0.02409379370510578 + ], + [ + 0.08904723823070526, + -0.0007931794971227646, + 0.07247699797153473, + 0.015016308054327965, + -0.02120928093791008, + 0.05205465108156204, + 1.4411165714263916, + 0.04743674397468567, + -0.03229031339287758, + 0, + 0.0019993737805634737, + -0.00807223655283451 + ], + [ + 0.8600788116455078, + 0.3260062038898468, + 0.16344408690929413, + 0.07133537530899048, + -0.00444837287068367, + 0.000681330740917474, + 0.36613449454307556, + -0.7105098962783813, + -0.002031375654041767, + -0.032143525779247284, + 1.2294330596923828, + 0.0018453558441251516 + ], + [ + 0.016877274960279465, + -0.001730365096591413, + -0.5010868310928345, + 0.02749764919281006, + -0.0059662917628884315, + -0.004944110754877329, + -0.08855228126049042, + 0.006622308399528265, + 0.044124361127614975, + -0.02726735547184944, + -1.134916067123413, + 0.02287953346967697 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 
0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The above suggests that it would be a useful bit of infrastructure to have a \"wiki\" for the heads of a model, giving their scores according to some metrics re head functions, like the ones we've seen here. TransformerLens makes this easy to make, as just changing the name input to `HookedTransformer.from_pretrained` gives a different model but in the same architecture, so the same code should work. If you want to make this, I'd love to see it! \n", - "\n", - "As a proof of concept, [I made a mosaic of all induction heads across the 40 models then in TransformerLens](https://www.neelnanda.io/mosaic).\n", - "\n", - "![induction scores as proof of concept](https://firebasestorage.googleapis.com/v0/b/firescript-577a2.appspot.com/o/imgs%2Fapp%2FNeelNanda%2F5vtuFmdzt_.png?alt=media&token=4d613de4-9d14-48d6-ba9d-e591c562d429)" - ] + "margin": { + "t": 60 }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Backup Name Mover Heads" + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + 
"gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": 
"overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + 
}, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + 
"#276419" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Another fascinating anomaly is that of the **backup name mover heads**. A standard technique to apply when interpreting model internals is ablations, or knock-out. If we run the model but intervene to set a specific head to zero, what happens? If the model is robust to this intervention, then naively we can be confident that the head is not doing anything important, and conversely if the model is much worse at the task this suggests that head was important. There are several conceptual flaws with this approach, making the evidence only suggestive, eg that the average output of the head may be far from zero and so the knockout may send it far from expected activations, breaking internals on *any* task. But it's still an easy technique to apply to give some data.\n", - "\n", - "But a wild finding in the paper is that models have **built in redundancy**. If we knock out one of the name movers, then there are some backup name movers in later layers that *change their behaviour* and do (some of) the job of the original name mover head. This means that naive knock-out will significantly underestimate the importance of the name movers.\n" + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's test this! Let's ablate the most important name mover (head L9H9) on just the final token using a custom ablation hook and then cache all new activations and compared performance. 
We focus on the final position because we want to specifically ablate the direct logit effect. When we do this, we see that naively, removing the top name mover should reduce the logit diff massively, from 3.55 to 0.57. **But actually, it only goes down to 2.99!**\n", - "\n", - "
Implementation Details \n", - "Ablating heads is really easy in TransformerLens! We can just define a hook on the z activation in the relevant attention layer (recall, z is the mixed values, and comes immediately before multiplying by the output weights $W_O$). z has a head_index axis, so we can set the component for the relevant head and for position -1 to zero, and return it. (Technically we could just edit in place without returning it, but by convention we always return an edited activation). \n", - "\n", - "We now want to compare all internal activations with a hook, which is hard to do with the nice `run_with_hooks` API. So we can directly access the hook on the z activation with `model.blocks[layer].attn.hook_z` and call its `add_hook` method. This adds in the hook to the *global state* of the model. We can now use run_with_cache, and don't need to care about the global state, because run_with_cache internally adds a bunch of caching hooks, and then removes all hooks after the run, *including* the previously added ablation hook. This can be disabled with the reset_hooks_end flag, but here it's useful! \n", - "
" + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + 
"gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Top Name Mover to ablate: L9H9\n", - "Original logit diff: 3.55\n", - "Post ablation logit diff: 2.92\n", - "Direct Logit Attribution of top name mover head: 2.99\n", - "Naive prediction of post ablation logit diff: 0.57\n" - ] - } - ], - "source": [ - "top_name_mover = per_head_logit_diffs.flatten().argmax().item()\n", - "top_name_mover_layer = top_name_mover // model.cfg.n_heads\n", - "top_name_mover_head = top_name_mover % model.cfg.n_heads\n", - "print(f\"Top Name Mover to ablate: L{top_name_mover_layer}H{top_name_mover_head}\")\n", - "\n", - "\n", - "def ablate_top_head_hook(z: Float[torch.Tensor, \"batch pos head_index d_head\"], hook):\n", - " z[:, -1, top_name_mover_head, :] = 0\n", - " return z\n", - "\n", - "\n", - "# Adds a hook into global model state\n", - "model.blocks[top_name_mover_layer].attn.hook_z.add_hook(ablate_top_head_hook)\n", - "# Runs the model, temporarily adds caching hooks and then removes *all* hooks after running, including the ablation hook.\n", - "ablated_logits, ablated_cache = model.run_with_cache(tokens)\n", - "print(f\"Original logit diff: {original_average_logit_diff:.2f}\")\n", - "print(\n", - " f\"Post ablation logit diff: {logits_to_ave_logit_diff(ablated_logits, answer_tokens).item():.2f}\"\n", - ")\n", - "print(\n", - " f\"Direct Logit Attribution of top name mover head: 
{per_head_logit_diffs.flatten()[top_name_mover].item():.2f}\"\n", - ")\n", - "print(\n", - " f\"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item():.2f}\"\n", - ")" - ] + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "So what's up with this? As before, we can look at the direct logit attribution of each head to see what's going on. It's easiest to interpret if plotted as a scatter plot against the initial per head logit difference.\n", - "\n", - "And we can see a *really* big difference in a few heads! (Hover to see labels) In particular the negative name mover L10H7 decreases its negative effect a lot, adding +1 to the logit diff, and the backup name mover L10H10 adjusts its effect to be more positive, adding +0.8 to the logit diff (with several other marginal changes). (And obviously the ablated head has gone down to zero!)" - ] + "hovertemplate": "%{hovertext}

Ablated=%{x}
Original=%{y}", + "hovertext": [ + "L0H0", + "L0H1", + "L0H2", + "L0H3", + "L0H4", + "L0H5", + "L0H6", + "L0H7", + "L0H8", + "L0H9", + "L0H10", + "L0H11", + "L1H0", + "L1H1", + "L1H2", + "L1H3", + "L1H4", + "L1H5", + "L1H6", + "L1H7", + "L1H8", + "L1H9", + "L1H10", + "L1H11", + "L2H0", + "L2H1", + "L2H2", + "L2H3", + "L2H4", + "L2H5", + "L2H6", + "L2H7", + "L2H8", + "L2H9", + "L2H10", + "L2H11", + "L3H0", + "L3H1", + "L3H2", + "L3H3", + "L3H4", + "L3H5", + "L3H6", + "L3H7", + "L3H8", + "L3H9", + "L3H10", + "L3H11", + "L4H0", + "L4H1", + "L4H2", + "L4H3", + "L4H4", + "L4H5", + "L4H6", + "L4H7", + "L4H8", + "L4H9", + "L4H10", + "L4H11", + "L5H0", + "L5H1", + "L5H2", + "L5H3", + "L5H4", + "L5H5", + "L5H6", + "L5H7", + "L5H8", + "L5H9", + "L5H10", + "L5H11", + "L6H0", + "L6H1", + "L6H2", + "L6H3", + "L6H4", + "L6H5", + "L6H6", + "L6H7", + "L6H8", + "L6H9", + "L6H10", + "L6H11", + "L7H0", + "L7H1", + "L7H2", + "L7H3", + "L7H4", + "L7H5", + "L7H6", + "L7H7", + "L7H8", + "L7H9", + "L7H10", + "L7H11", + "L8H0", + "L8H1", + "L8H2", + "L8H3", + "L8H4", + "L8H5", + "L8H6", + "L8H7", + "L8H8", + "L8H9", + "L8H10", + "L8H11", + "L9H0", + "L9H1", + "L9H2", + "L9H3", + "L9H4", + "L9H5", + "L9H6", + "L9H7", + "L9H8", + "L9H9", + "L9H10", + "L9H11", + "L10H0", + "L10H1", + "L10H2", + "L10H3", + "L10H4", + "L10H5", + "L10H6", + "L10H7", + "L10H8", + "L10H9", + "L10H10", + "L10H11", + "L11H0", + "L11H1", + "L11H2", + "L11H3", + "L11H4", + "L11H5", + "L11H6", + "L11H7", + "L11H8", + "L11H9", + "L11H10", + "L11H11" + ], + "legendgroup": "", + "marker": { + "color": "#636efa", + "symbol": "circle" + }, + "mode": "markers", + "name": "", + "orientation": "v", + "showlegend": false, + "type": "scatter", + "x": [ + -0.002156503964215517, + -0.0004650682385545224, + 0.00024167183437384665, + 0.0002806585980579257, + -0.0004162999684922397, + -0.0004892416181974113, + -0.002620948012918234, + -0.002935677068307996, + 0.00042561208829283714, + 0.0005418329383246601, + 0.00023754138965159655, 
+ -7.48957390896976e-05, + -0.000658505829051137, + 0.0004060641804244369, + -0.0009330413886345923, + 0.0008937822422012687, + -0.0009785268921405077, + -0.000533820129930973, + -0.0027988189831376076, + -0.004214101936668158, + 0.002578593324869871, + 0.0024506838526576757, + 0.0005351756699383259, + 0.0012349633034318686, + 0.0009405204327777028, + -0.0011168691562488675, + -0.0011541967978700995, + -0.0015697095077484846, + -0.0005699327448382974, + 0.001451514894142747, + 0.002439911477267742, + 0.003158293664455414, + 0.000923738582059741, + -0.003578126197680831, + -0.0010650777257978916, + -0.0003558753523975611, + -0.0005624951445497572, + -1.1960582924075425e-05, + 0.0011531109921634197, + 0.0007360265008173883, + 0.0016493839211761951, + 0.0008800819050520658, + -0.0006905529880896211, + -0.003031972097232938, + 0.0008080147090367973, + 0.00010368914809077978, + -0.0005807994166389108, + -0.0011067037703469396, + -0.0026375530287623405, + 0.0002691895351745188, + -0.0016417437000200152, + -0.003406986128538847, + 0.0017449699807912111, + 0.00046454701805487275, + -0.0007899806369096041, + 0.0018328562146052718, + -0.00086324627045542, + -0.0003978293389081955, + 0.0007879206677898765, + -0.00012048585631418973, + 0.0008688560919836164, + 0.0009473530226387084, + -0.0022812988609075546, + -0.0011803123634308577, + 0.0002407809515716508, + -0.0004318578285165131, + -0.0003728170122485608, + -0.000738416681997478, + 0.0008113418589346111, + -0.00040444196201860905, + -0.007074396125972271, + 0.003946478478610516, + -0.014917617663741112, + -0.0022801742888987064, + 0.0022679336834698915, + -8.302251808345318e-05, + -0.004980948753654957, + 0.0027670026756823063, + 0.006266288459300995, + -0.003485947148874402, + -0.0013348984066396952, + -0.0017918883822858334, + -0.0012231896398589015, + 0.00040514359716326, + -0.0002460568503011018, + -0.005790225230157375, + -0.0004975841729901731, + 0.142182856798172, + -0.0014961492270231247, + -0.019006317481398582, + 
0.003133433870971203, + -0.001858205534517765, + -0.011305196210741997, + 0.1922595500946045, + -0.0011892566690221429, + -0.0010282933944836259, + -0.0038003993686288595, + -0.0008570950012654066, + -0.013956742361187935, + 0.00828910805284977, + 0.004315475933253765, + -0.009073829278349876, + -0.08315148949623108, + 0.0034569751005619764, + -0.01805492490530014, + 0.002178061753511429, + 0.29780513048171997, + 0.02409379370510578, + 0.08904723823070526, + -0.0007931794971227646, + 0.07247699797153473, + 0.015016308054327965, + -0.02120928093791008, + 0.05205465108156204, + 1.4411165714263916, + 0.04743674397468567, + -0.03229031339287758, + 0, + 0.0019993737805634737, + -0.00807223655283451, + 0.8600788116455078, + 0.3260062038898468, + 0.16344408690929413, + 0.07133537530899048, + -0.00444837287068367, + 0.000681330740917474, + 0.36613449454307556, + -0.7105098962783813, + -0.002031375654041767, + -0.032143525779247284, + 1.2294330596923828, + 0.0018453558441251516, + 0.016877274960279465, + -0.001730365096591413, + -0.5010868310928345, + 0.02749764919281006, + -0.0059662917628884315, + -0.004944110754877329, + -0.08855228126049042, + 0.006622308399528265, + 0.044124361127614975, + -0.02726735547184944, + -1.134916067123413, + 0.02287953346967697 + ], + "xaxis": "x", + "y": [ + -0.0020563392899930477, + -0.0005101899732835591, + 0.0004685786843765527, + 0.00012512074317783117, + -0.0006028738571330905, + -0.0002429460291750729, + -0.0023189077619463205, + -0.002758360467851162, + 0.000564602785743773, + 0.0009697531932033598, + -0.0002504526637494564, + 4.737317794933915e-06, + -0.0010070882271975279, + 0.00039470894262194633, + -0.00154874159488827, + 0.0014034928753972054, + -0.0012653048615902662, + -0.0011358022456988692, + -0.00281596090644598, + -0.0029645217582583427, + 0.0029190476052463055, + 0.0025743592996150255, + 0.00036239007022231817, + 0.0017548729665577412, + 0.0005569400964304805, + -0.001126631861552596, + -0.0017353934235870838, + 
-0.0014514457434415817, + -0.00028735760133713484, + 0.0017211002996191382, + 0.0026658899150788784, + 0.00311466702260077, + 0.0005667927907779813, + -0.003666515462100506, + -0.0018847601022571325, + 7.039372576400638e-06, + -0.0007264417363330722, + 0.00011364505917299539, + 0.0014301587361842394, + 0.0007490540738217533, + 0.0020184689201414585, + 0.0007436950691044331, + -0.00046178390039131045, + -0.0039057559333741665, + 0.0011406694538891315, + -4.022853681817651e-05, + -0.0013293239753693342, + -0.0017636751290410757, + -0.0028280913829803467, + 0.00033634810824878514, + -0.0014248639345169067, + -0.003777273464947939, + 0.0015998880844563246, + 0.0002989505883306265, + -0.000804675742983818, + 0.002038792008534074, + -0.0015593919670209289, + -0.0006436670082621276, + 0.0011168173514306545, + -0.00035012533771805465, + 0.0011338205076754093, + 0.0011259170714765787, + -0.002516670385375619, + -0.0014790185960009694, + 0.0003878737334161997, + -6.408110493794084e-05, + -0.0005096744280308485, + -0.0008840755908749998, + 0.0006398351397365332, + -0.0010097370250150561, + -0.006759158335626125, + 0.0033667823299765587, + -0.01514742337167263, + -0.0021350777242332697, + 0.002593174111098051, + -0.00042678468162193894, + -0.005558924749493599, + 0.0026658528950065374, + 0.006411008536815643, + -0.003826778382062912, + -0.0003843410813715309, + -0.0016430341638624668, + -0.0013344454346224666, + -9.20506427064538e-05, + -9.476230479776859e-05, + -0.0057889921590685844, + -0.0006383581785485148, + 0.13493388891220093, + -0.001768707763403654, + -0.018917907029390335, + 0.003873429261147976, + -0.0021450775675475597, + -0.010327338241040707, + 0.18325845897197723, + -0.0007747983909212053, + -0.00104526337236166, + -0.003833949100226164, + -0.0008046097937040031, + -0.012673400342464447, + 0.00804573018103838, + 0.003604492638260126, + -0.009398287162184715, + -0.08272082358598709, + 0.003555194940418005, + -0.018404025584459305, + 0.0017587244510650635, + 
0.2896133363246918, + 0.022854052484035492, + 0.08595258742570877, + -0.0006932877004146576, + 0.06817055493593216, + 0.013111240230500698, + -0.021098043769598007, + 0.05112447217106819, + 1.3844914436340332, + 0.045836858451366425, + -0.03830280900001526, + 2.985445976257324, + 0.0019662054255604744, + -0.008030137047171593, + 0.5608693957328796, + 0.17083050310611725, + -0.03361757844686508, + 0.05821544677019119, + -0.0024530249647796154, + 0.0018771197646856308, + 0.28827205300331116, + -1.8986485004425049, + -0.0015286931302398443, + -0.035129792988300323, + 0.4802178740501404, + -0.0009115453576669097, + 0.016075748950242996, + -0.03986122086644173, + -0.3879126012325287, + 0.011123123578727245, + -0.005477819126099348, + -0.0025129620917141438, + -0.08056175708770752, + 0.007518616039305925, + 0.0430111438035965, + -0.040082238614559174, + -0.9702364802360535, + 0.011862239800393581 + ], + "yaxis": "y" + } + ], + "layout": { + "legend": { + "tracegroupgap": 0 }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tried to stack head results when they weren't cached. Computing head results now\n" - ] - }, - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "coloraxis": "coloraxis", - "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - -0.002156503964215517, - -0.0004650682385545224, - 0.00024167183437384665, - 0.0002806585980579257, - -0.0004162999684922397, - -0.0004892416181974113, - -0.002620948012918234, - -0.002935677068307996, - 0.00042561208829283714, - 0.0005418329383246601, - 0.00023754138965159655, - -0.0000748957390896976 - ], - [ - -0.000658505829051137, - 0.0004060641804244369, - -0.0009330413886345923, - 0.0008937822422012687, - -0.0009785268921405077, - -0.000533820129930973, - -0.0027988189831376076, - -0.004214101936668158, - 0.002578593324869871, - 0.0024506838526576757, - 0.0005351756699383259, - 0.0012349633034318686 - ], - [ - 0.0009405204327777028, - -0.0011168691562488675, - -0.0011541967978700995, - -0.0015697095077484846, - -0.0005699327448382974, - 0.001451514894142747, - 0.002439911477267742, - 0.003158293664455414, - 0.000923738582059741, - -0.003578126197680831, - -0.0010650777257978916, - -0.0003558753523975611 - ], - [ - -0.0005624951445497572, - -0.000011960582924075425, - 0.0011531109921634197, - 0.0007360265008173883, - 0.0016493839211761951, - 0.0008800819050520658, - -0.0006905529880896211, - -0.003031972097232938, - 0.0008080147090367973, - 0.00010368914809077978, - -0.0005807994166389108, - -0.0011067037703469396 - ], - [ - -0.0026375530287623405, - 0.0002691895351745188, - -0.0016417437000200152, - -0.003406986128538847, - 0.0017449699807912111, - 0.00046454701805487275, - -0.0007899806369096041, - 0.0018328562146052718, - -0.00086324627045542, - -0.0003978293389081955, - 0.0007879206677898765, - -0.00012048585631418973 - ], - [ - 0.0008688560919836164, - 0.0009473530226387084, - -0.0022812988609075546, - -0.0011803123634308577, - 0.0002407809515716508, - -0.0004318578285165131, - -0.0003728170122485608, - -0.000738416681997478, - 0.0008113418589346111, - -0.00040444196201860905, - -0.007074396125972271, - 0.003946478478610516 - ], - [ - -0.014917617663741112, - 
-0.0022801742888987064, - 0.0022679336834698915, - -0.00008302251808345318, - -0.004980948753654957, - 0.0027670026756823063, - 0.006266288459300995, - -0.003485947148874402, - -0.0013348984066396952, - -0.0017918883822858334, - -0.0012231896398589015, - 0.00040514359716326 - ], - [ - -0.0002460568503011018, - -0.005790225230157375, - -0.0004975841729901731, - 0.142182856798172, - -0.0014961492270231247, - -0.019006317481398582, - 0.003133433870971203, - -0.001858205534517765, - -0.011305196210741997, - 0.1922595500946045, - -0.0011892566690221429, - -0.0010282933944836259 - ], - [ - -0.0038003993686288595, - -0.0008570950012654066, - -0.013956742361187935, - 0.00828910805284977, - 0.004315475933253765, - -0.009073829278349876, - -0.08315148949623108, - 0.0034569751005619764, - -0.01805492490530014, - 0.002178061753511429, - 0.29780513048171997, - 0.02409379370510578 - ], - [ - 0.08904723823070526, - -0.0007931794971227646, - 0.07247699797153473, - 0.015016308054327965, - -0.02120928093791008, - 0.05205465108156204, - 1.4411165714263916, - 0.04743674397468567, - -0.03229031339287758, - 0, - 0.0019993737805634737, - -0.00807223655283451 - ], - [ - 0.8600788116455078, - 0.3260062038898468, - 0.16344408690929413, - 0.07133537530899048, - -0.00444837287068367, - 0.000681330740917474, - 0.36613449454307556, - -0.7105098962783813, - -0.002031375654041767, - -0.032143525779247284, - 1.2294330596923828, - 0.0018453558441251516 - ], - [ - 0.016877274960279465, - -0.001730365096591413, - -0.5010868310928345, - 0.02749764919281006, - -0.0059662917628884315, - -0.004944110754877329, - -0.08855228126049042, - 0.006622308399528265, - 0.044124361127614975, - -0.02726735547184944, - -1.134916067123413, - 0.02287953346967697 - ] - ] - } - ], - "layout": { - "coloraxis": { - "cmid": 0, - "colorscale": [ - [ - 0, - "rgb(103,0,31)" - ], - [ - 0.1, - "rgb(178,24,43)" - ], - [ - 0.2, - "rgb(214,96,77)" - ], - [ - 0.3, - "rgb(244,165,130)" - ], - [ - 0.4, - "rgb(253,219,199)" - ], - [ - 
0.5, - "rgb(247,247,247)" - ], - [ - 0.6, - "rgb(209,229,240)" - ], - [ - 0.7, - "rgb(146,197,222)" - ], - [ - 0.8, - "rgb(67,147,195)" - ], - [ - 0.9, - "rgb(33,102,172)" - ], - [ - 1, - "rgb(5,48,97)" - ] - ] - }, - "margin": { - "t": 60 - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - 
"#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - 
"#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - 
], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": 
"#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "xaxis": { - "anchor": "y", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "scaleanchor": "y", - "title": { - "text": 
"Head" - } - }, - "yaxis": { - "anchor": "x", - "autorange": "reversed", - "constrain": "domain", - "domain": [ - 0, - 1 - ], - "title": { - "text": "Layer" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "hovertemplate": "%{hovertext}

Ablated=%{x}
Original=%{y}", - "hovertext": [ - "L0H0", - "L0H1", - "L0H2", - "L0H3", - "L0H4", - "L0H5", - "L0H6", - "L0H7", - "L0H8", - "L0H9", - "L0H10", - "L0H11", - "L1H0", - "L1H1", - "L1H2", - "L1H3", - "L1H4", - "L1H5", - "L1H6", - "L1H7", - "L1H8", - "L1H9", - "L1H10", - "L1H11", - "L2H0", - "L2H1", - "L2H2", - "L2H3", - "L2H4", - "L2H5", - "L2H6", - "L2H7", - "L2H8", - "L2H9", - "L2H10", - "L2H11", - "L3H0", - "L3H1", - "L3H2", - "L3H3", - "L3H4", - "L3H5", - "L3H6", - "L3H7", - "L3H8", - "L3H9", - "L3H10", - "L3H11", - "L4H0", - "L4H1", - "L4H2", - "L4H3", - "L4H4", - "L4H5", - "L4H6", - "L4H7", - "L4H8", - "L4H9", - "L4H10", - "L4H11", - "L5H0", - "L5H1", - "L5H2", - "L5H3", - "L5H4", - "L5H5", - "L5H6", - "L5H7", - "L5H8", - "L5H9", - "L5H10", - "L5H11", - "L6H0", - "L6H1", - "L6H2", - "L6H3", - "L6H4", - "L6H5", - "L6H6", - "L6H7", - "L6H8", - "L6H9", - "L6H10", - "L6H11", - "L7H0", - "L7H1", - "L7H2", - "L7H3", - "L7H4", - "L7H5", - "L7H6", - "L7H7", - "L7H8", - "L7H9", - "L7H10", - "L7H11", - "L8H0", - "L8H1", - "L8H2", - "L8H3", - "L8H4", - "L8H5", - "L8H6", - "L8H7", - "L8H8", - "L8H9", - "L8H10", - "L8H11", - "L9H0", - "L9H1", - "L9H2", - "L9H3", - "L9H4", - "L9H5", - "L9H6", - "L9H7", - "L9H8", - "L9H9", - "L9H10", - "L9H11", - "L10H0", - "L10H1", - "L10H2", - "L10H3", - "L10H4", - "L10H5", - "L10H6", - "L10H7", - "L10H8", - "L10H9", - "L10H10", - "L10H11", - "L11H0", - "L11H1", - "L11H2", - "L11H3", - "L11H4", - "L11H5", - "L11H6", - "L11H7", - "L11H8", - "L11H9", - "L11H10", - "L11H11" - ], - "legendgroup": "", - "marker": { - "color": "#636efa", - "symbol": "circle" - }, - "mode": "markers", - "name": "", - "orientation": "v", - "showlegend": false, - "type": "scatter", - "x": [ - -0.002156503964215517, - -0.0004650682385545224, - 0.00024167183437384665, - 0.0002806585980579257, - -0.0004162999684922397, - -0.0004892416181974113, - -0.002620948012918234, - -0.002935677068307996, - 0.00042561208829283714, - 0.0005418329383246601, - 0.00023754138965159655, 
- -0.0000748957390896976, - -0.000658505829051137, - 0.0004060641804244369, - -0.0009330413886345923, - 0.0008937822422012687, - -0.0009785268921405077, - -0.000533820129930973, - -0.0027988189831376076, - -0.004214101936668158, - 0.002578593324869871, - 0.0024506838526576757, - 0.0005351756699383259, - 0.0012349633034318686, - 0.0009405204327777028, - -0.0011168691562488675, - -0.0011541967978700995, - -0.0015697095077484846, - -0.0005699327448382974, - 0.001451514894142747, - 0.002439911477267742, - 0.003158293664455414, - 0.000923738582059741, - -0.003578126197680831, - -0.0010650777257978916, - -0.0003558753523975611, - -0.0005624951445497572, - -0.000011960582924075425, - 0.0011531109921634197, - 0.0007360265008173883, - 0.0016493839211761951, - 0.0008800819050520658, - -0.0006905529880896211, - -0.003031972097232938, - 0.0008080147090367973, - 0.00010368914809077978, - -0.0005807994166389108, - -0.0011067037703469396, - -0.0026375530287623405, - 0.0002691895351745188, - -0.0016417437000200152, - -0.003406986128538847, - 0.0017449699807912111, - 0.00046454701805487275, - -0.0007899806369096041, - 0.0018328562146052718, - -0.00086324627045542, - -0.0003978293389081955, - 0.0007879206677898765, - -0.00012048585631418973, - 0.0008688560919836164, - 0.0009473530226387084, - -0.0022812988609075546, - -0.0011803123634308577, - 0.0002407809515716508, - -0.0004318578285165131, - -0.0003728170122485608, - -0.000738416681997478, - 0.0008113418589346111, - -0.00040444196201860905, - -0.007074396125972271, - 0.003946478478610516, - -0.014917617663741112, - -0.0022801742888987064, - 0.0022679336834698915, - -0.00008302251808345318, - -0.004980948753654957, - 0.0027670026756823063, - 0.006266288459300995, - -0.003485947148874402, - -0.0013348984066396952, - -0.0017918883822858334, - -0.0012231896398589015, - 0.00040514359716326, - -0.0002460568503011018, - -0.005790225230157375, - -0.0004975841729901731, - 0.142182856798172, - -0.0014961492270231247, - 
-0.019006317481398582, - 0.003133433870971203, - -0.001858205534517765, - -0.011305196210741997, - 0.1922595500946045, - -0.0011892566690221429, - -0.0010282933944836259, - -0.0038003993686288595, - -0.0008570950012654066, - -0.013956742361187935, - 0.00828910805284977, - 0.004315475933253765, - -0.009073829278349876, - -0.08315148949623108, - 0.0034569751005619764, - -0.01805492490530014, - 0.002178061753511429, - 0.29780513048171997, - 0.02409379370510578, - 0.08904723823070526, - -0.0007931794971227646, - 0.07247699797153473, - 0.015016308054327965, - -0.02120928093791008, - 0.05205465108156204, - 1.4411165714263916, - 0.04743674397468567, - -0.03229031339287758, - 0, - 0.0019993737805634737, - -0.00807223655283451, - 0.8600788116455078, - 0.3260062038898468, - 0.16344408690929413, - 0.07133537530899048, - -0.00444837287068367, - 0.000681330740917474, - 0.36613449454307556, - -0.7105098962783813, - -0.002031375654041767, - -0.032143525779247284, - 1.2294330596923828, - 0.0018453558441251516, - 0.016877274960279465, - -0.001730365096591413, - -0.5010868310928345, - 0.02749764919281006, - -0.0059662917628884315, - -0.004944110754877329, - -0.08855228126049042, - 0.006622308399528265, - 0.044124361127614975, - -0.02726735547184944, - -1.134916067123413, - 0.02287953346967697 - ], - "xaxis": "x", - "y": [ - -0.0020563392899930477, - -0.0005101899732835591, - 0.0004685786843765527, - 0.00012512074317783117, - -0.0006028738571330905, - -0.0002429460291750729, - -0.0023189077619463205, - -0.002758360467851162, - 0.000564602785743773, - 0.0009697531932033598, - -0.0002504526637494564, - 0.000004737317794933915, - -0.0010070882271975279, - 0.00039470894262194633, - -0.00154874159488827, - 0.0014034928753972054, - -0.0012653048615902662, - -0.0011358022456988692, - -0.00281596090644598, - -0.0029645217582583427, - 0.0029190476052463055, - 0.0025743592996150255, - 0.00036239007022231817, - 0.0017548729665577412, - 0.0005569400964304805, - -0.001126631861552596, - 
-0.0017353934235870838, - -0.0014514457434415817, - -0.00028735760133713484, - 0.0017211002996191382, - 0.0026658899150788784, - 0.00311466702260077, - 0.0005667927907779813, - -0.003666515462100506, - -0.0018847601022571325, - 0.000007039372576400638, - -0.0007264417363330722, - 0.00011364505917299539, - 0.0014301587361842394, - 0.0007490540738217533, - 0.0020184689201414585, - 0.0007436950691044331, - -0.00046178390039131045, - -0.0039057559333741665, - 0.0011406694538891315, - -0.00004022853681817651, - -0.0013293239753693342, - -0.0017636751290410757, - -0.0028280913829803467, - 0.00033634810824878514, - -0.0014248639345169067, - -0.003777273464947939, - 0.0015998880844563246, - 0.0002989505883306265, - -0.000804675742983818, - 0.002038792008534074, - -0.0015593919670209289, - -0.0006436670082621276, - 0.0011168173514306545, - -0.00035012533771805465, - 0.0011338205076754093, - 0.0011259170714765787, - -0.002516670385375619, - -0.0014790185960009694, - 0.0003878737334161997, - -0.00006408110493794084, - -0.0005096744280308485, - -0.0008840755908749998, - 0.0006398351397365332, - -0.0010097370250150561, - -0.006759158335626125, - 0.0033667823299765587, - -0.01514742337167263, - -0.0021350777242332697, - 0.002593174111098051, - -0.00042678468162193894, - -0.005558924749493599, - 0.0026658528950065374, - 0.006411008536815643, - -0.003826778382062912, - -0.0003843410813715309, - -0.0016430341638624668, - -0.0013344454346224666, - -0.0000920506427064538, - -0.00009476230479776859, - -0.0057889921590685844, - -0.0006383581785485148, - 0.13493388891220093, - -0.001768707763403654, - -0.018917907029390335, - 0.003873429261147976, - -0.0021450775675475597, - -0.010327338241040707, - 0.18325845897197723, - -0.0007747983909212053, - -0.00104526337236166, - -0.003833949100226164, - -0.0008046097937040031, - -0.012673400342464447, - 0.00804573018103838, - 0.003604492638260126, - -0.009398287162184715, - -0.08272082358598709, - 0.003555194940418005, - -0.018404025584459305, 
- 0.0017587244510650635, - 0.2896133363246918, - 0.022854052484035492, - 0.08595258742570877, - -0.0006932877004146576, - 0.06817055493593216, - 0.013111240230500698, - -0.021098043769598007, - 0.05112447217106819, - 1.3844914436340332, - 0.045836858451366425, - -0.03830280900001526, - 2.985445976257324, - 0.0019662054255604744, - -0.008030137047171593, - 0.5608693957328796, - 0.17083050310611725, - -0.03361757844686508, - 0.05821544677019119, - -0.0024530249647796154, - 0.0018771197646856308, - 0.28827205300331116, - -1.8986485004425049, - -0.0015286931302398443, - -0.035129792988300323, - 0.4802178740501404, - -0.0009115453576669097, - 0.016075748950242996, - -0.03986122086644173, - -0.3879126012325287, - 0.011123123578727245, - -0.005477819126099348, - -0.0025129620917141438, - -0.08056175708770752, - 0.007518616039305925, - 0.0430111438035965, - -0.040082238614559174, - -0.9702364802360535, - 0.011862239800393581 - ], - "yaxis": "y" - } - ], - "layout": { - "legend": { - "tracegroupgap": 0 - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - 
"colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 
0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } 
- }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - 
"#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": 
"white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Original vs Post-Ablation Direct Logit Attribution of Heads" - }, - "xaxis": { - "anchor": "y", - "domain": [ - 0, - 1 - ], - "range": [ - -3, - 3 - ], - "title": { - "text": "Ablated" - } - }, - "yaxis": { - "anchor": "x", - "domain": [ - 0, - 1 - ], - "range": [ - -3, - 3 - ], - "title": { - "text": "Original" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "per_head_ablated_residual, labels = ablated_cache.stack_head_results(\n", - " layer=-1, pos_slice=-1, return_labels=True\n", - ")\n", - "per_head_ablated_logit_diffs = residual_stack_to_logit_diff(\n", - " per_head_ablated_residual, ablated_cache\n", - ")\n", - "per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(\n", - " model.cfg.n_layers, model.cfg.n_heads\n", - ")\n", - "imshow(per_head_ablated_logit_diffs, labels={\"x\": \"Head\", \"y\": \"Layer\"})\n", - "scatter(\n", - " y=per_head_logit_diffs.flatten(),\n", - " x=per_head_ablated_logit_diffs.flatten(),\n", - " hover_name=head_labels,\n", - " range_x=(-3, 3),\n", - " range_y=(-3, 3),\n", - " xaxis=\"Ablated\",\n", - " yaxis=\"Original\",\n", - " title=\"Original vs Post-Ablation Direct Logit Attribution of Heads\",\n", - ")" + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + 
"color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + 
"type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], 
+ "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + 
"line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "One natural hypothesis is that this is because the final LayerNorm scaling has changed, which can scale up or down the final residual stream. This is slightly true, and we can see that the typical head is a bit off from the x=y line. But the average LN scaling ratio is 1.04, and this should uniformly change *all* heads by the same factor, so this can't be sufficient" + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Average LN scaling ratio: 1.042\n", - "Ablation LN scale tensor([[18.5200],\n", - " [17.4700],\n", - " [17.8200],\n", - " [17.5100],\n", - " [17.2600],\n", - " [18.2500],\n", - " [16.1800],\n", - " [17.4300]])\n", - "Original LN scale tensor([[19.5700],\n", - " [18.3500],\n", - " [18.2900],\n", - " [18.6800],\n", - " [17.4900],\n", - " 
[18.8700],\n", - " [16.4200],\n", - " [18.6800]])\n" - ] - } - ], - "source": [ - "print(\n", - " \"Average LN scaling ratio:\",\n", - " round(\n", - " (\n", - " cache[\"ln_final.hook_scale\"][:, -1]\n", - " / ablated_cache[\"ln_final.hook_scale\"][:, -1]\n", - " )\n", - " .mean()\n", - " .item(),\n", - " 3,\n", - " ),\n", - ")\n", - "print(\n", - " \"Ablation LN scale\",\n", - " ablated_cache[\"ln_final.hook_scale\"][:, -1].detach().cpu().round(decimals=2),\n", - ")\n", - "print(\n", - " \"Original LN scale\",\n", - " cache[\"ln_final.hook_scale\"][:, -1].detach().cpu().round(decimals=2),\n", - ")" + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + 
"gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Exercise to the reader:** Can you finish off this analysis? What's going on here? Why are the backup name movers changing their behaviour? Why is one negative name mover becoming significantly less important?" 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" + "title": { + "text": "Original vs Post-Ablation Direct Logit Attribution of Heads" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "range": [ + -3, + 3 + ], + "title": { + "text": "Ablated" + } }, - "vscode": { - "interpreter": { - "hash": "eb812820b5094695c8a581672e17220e30dd2c15d704c018326e3cc2e1a566f1" - } + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "range": [ + -3, + 3 + ], + "title": { + "text": "Original" + } } - }, - "nbformat": 4, - "nbformat_minor": 2 -} + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "per_head_ablated_residual, labels = ablated_cache.stack_head_results(\n", + " layer=-1, pos_slice=-1, return_labels=True\n", + ")\n", + "per_head_ablated_logit_diffs = residual_stack_to_logit_diff(\n", + " per_head_ablated_residual, ablated_cache\n", + ")\n", + "per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(\n", + " model.cfg.n_layers, model.cfg.n_heads\n", + ")\n", + "imshow(per_head_ablated_logit_diffs, labels={\"x\": \"Head\", \"y\": \"Layer\"})\n", + "scatter(\n", + " y=per_head_logit_diffs.flatten(),\n", + " x=per_head_ablated_logit_diffs.flatten(),\n", + " hover_name=head_labels,\n", + " range_x=(-3, 3),\n", + " range_y=(-3, 3),\n", + " xaxis=\"Ablated\",\n", + " yaxis=\"Original\",\n", + " title=\"Original vs Post-Ablation Direct Logit Attribution of Heads\",\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One natural hypothesis is that this is because the final LayerNorm scaling has changed, which can scale up or down the final residual 
stream. This is slightly true, and we can see that the typical head is a bit off from the x=y line. But the average LN scaling ratio is 1.04, and this should uniformly change *all* heads by the same factor, so this can't be sufficient" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average LN scaling ratio: 1.042\n", + "Ablation LN scale tensor([[18.5200],\n", + " [17.4700],\n", + " [17.8200],\n", + " [17.5100],\n", + " [17.2600],\n", + " [18.2500],\n", + " [16.1800],\n", + " [17.4300]])\n", + "Original LN scale tensor([[19.5700],\n", + " [18.3500],\n", + " [18.2900],\n", + " [18.6800],\n", + " [17.4900],\n", + " [18.8700],\n", + " [16.4200],\n", + " [18.6800]])\n" + ] + } + ], + "source": [ + "print(\n", + " \"Average LN scaling ratio:\",\n", + " round(\n", + " (\n", + " cache[\"ln_final.hook_scale\"][:, -1]\n", + " / ablated_cache[\"ln_final.hook_scale\"][:, -1]\n", + " )\n", + " .mean()\n", + " .item(),\n", + " 3,\n", + " ),\n", + ")\n", + "print(\n", + " \"Ablation LN scale\",\n", + " ablated_cache[\"ln_final.hook_scale\"][:, -1].detach().cpu().round(decimals=2),\n", + ")\n", + "print(\n", + " \"Original LN scale\",\n", + " cache[\"ln_final.hook_scale\"][:, -1].detach().cpu().round(decimals=2),\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Exercise to the reader:** Can you finish off this analysis? What's going on here? Why are the backup name movers changing their behaviour? Why is one negative name mover becoming significantly less important?" 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + }, + "vscode": { + "interpreter": { + "hash": "eb812820b5094695c8a581672e17220e30dd2c15d704c018326e3cc2e1a566f1" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file From 96c11979f9f6403f6376bb1861c7cd0cee7554a5 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Mon, 2 Mar 2026 14:33:51 -0600 Subject: [PATCH 4/7] Work on exploratory analysis demo updates --- demos/Exploratory_Analysis_Demo.ipynb | 5858 +++++++------------------ 1 file changed, 1471 insertions(+), 4387 deletions(-) diff --git a/demos/Exploratory_Analysis_Demo.ipynb b/demos/Exploratory_Analysis_Demo.ipynb index 0ec844270..af53b6073 100644 --- a/demos/Exploratory_Analysis_Demo.ipynb +++ b/demos/Exploratory_Analysis_Demo.ipynb @@ -63,7 +63,14 @@ { "cell_type": "code", "execution_count": 1, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:00:57.561568Z", + "iopub.status.busy": "2026-03-02T20:00:57.561326Z", + "iopub.status.idle": "2026-03-02T20:00:57.569271Z", + "shell.execute_reply": "2026-03-02T20:00:57.568798Z" + } + }, "outputs": [], "source": [ "\n", @@ -100,8 +107,15 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 2, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:00:57.571199Z", + "iopub.status.busy": "2026-03-02T20:00:57.571053Z", + "iopub.status.idle": "2026-03-02T20:00:59.542791Z", + "shell.execute_reply": "2026-03-02T20:00:59.542488Z" + } + }, "outputs": [], "source": [ "from functools import partial\n", @@ -139,7 +153,14 @@ { "cell_type": "code", "execution_count": 3, - "metadata": {}, 
+ "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:00:59.544042Z", + "iopub.status.busy": "2026-03-02T20:00:59.543912Z", + "iopub.status.idle": "2026-03-02T20:00:59.545310Z", + "shell.execute_reply": "2026-03-02T20:00:59.545099Z" + } + }, "outputs": [ { "name": "stdout", @@ -171,7 +192,14 @@ { "cell_type": "code", "execution_count": 4, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:00:59.558227Z", + "iopub.status.busy": "2026-03-02T20:00:59.558155Z", + "iopub.status.idle": "2026-03-02T20:00:59.559870Z", + "shell.execute_reply": "2026-03-02T20:00:59.559678Z" + } + }, "outputs": [], "source": [ "def imshow(tensor, **kwargs):\n", @@ -251,9 +279,31 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:00:59.560745Z", + "iopub.status.busy": "2026-03-02T20:00:59.560687Z", + "iopub.status.idle": "2026-03-02T20:01:00.877227Z", + "shell.execute_reply": "2026-03-02T20:01:00.876923Z" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "09c32ea4455548fdbb35d27edb7d40c3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading weights: 0%| | 0/148 [00:00 1:\n", + " cache.cache_dict[\"hook_pos_embed\"] = pe.expand(tokens.shape[0], -1, -1)" ] }, { @@ -488,7 +578,14 @@ { "cell_type": "code", "execution_count": 10, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:01.572612Z", + "iopub.status.busy": "2026-03-02T20:01:01.572539Z", + "iopub.status.idle": "2026-03-02T20:01:01.681618Z", + "shell.execute_reply": "2026-03-02T20:01:01.681403Z" + } + }, "outputs": [ { "name": "stdout", @@ -625,7 +722,14 @@ { "cell_type": "code", "execution_count": 11, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:01.682582Z", + "iopub.status.busy": 
"2026-03-02T20:01:01.682527Z", + "iopub.status.idle": "2026-03-02T20:01:01.691522Z", + "shell.execute_reply": "2026-03-02T20:01:01.691328Z" + } + }, "outputs": [ { "name": "stdout", @@ -637,12 +741,18 @@ } ], "source": [ - "answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)\n", + "# TransformerBridge doesn't have tokens_to_residual_directions yet,\n", + "# so we implement it inline using model.unembed.W_U\n", + "W_U = model.unembed.W_U # [d_model, d_vocab]\n", + "answer_residual_directions = W_U[:, answer_tokens]\n", + "answer_residual_directions = einops.rearrange(\n", + " answer_residual_directions, \"d_model batch correct_incorrect -> batch correct_incorrect d_model\"\n", + ")\n", "print(\"Answer residual directions shape:\", answer_residual_directions.shape)\n", "logit_diff_directions = (\n", " answer_residual_directions[:, 0] - answer_residual_directions[:, 1]\n", ")\n", - "print(\"Logit difference directions shape:\", logit_diff_directions.shape)" + "print(\"Logit difference directions shape:\", logit_diff_directions.shape)\n" ] }, { @@ -669,7 +779,14 @@ { "cell_type": "code", "execution_count": 12, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:01.692474Z", + "iopub.status.busy": "2026-03-02T20:01:01.692402Z", + "iopub.status.idle": "2026-03-02T20:01:01.802781Z", + "shell.execute_reply": "2026-03-02T20:01:01.802561Z" + } + }, "outputs": [ { "name": "stdout", @@ -718,7 +835,14 @@ { "cell_type": "code", "execution_count": 13, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:01.803733Z", + "iopub.status.busy": "2026-03-02T20:01:01.803672Z", + "iopub.status.idle": "2026-03-02T20:01:01.805237Z", + "shell.execute_reply": "2026-03-02T20:01:01.805052Z" + } + }, "outputs": [], "source": [ "def residual_stack_to_logit_diff(\n", @@ -756,7 +880,14 @@ { "cell_type": "code", "execution_count": 14, - "metadata": {}, + "metadata": { + "execution": { + 
"iopub.execute_input": "2026-03-02T20:01:01.806104Z", + "iopub.status.busy": "2026-03-02T20:01:01.806045Z", + "iopub.status.idle": "2026-03-02T20:01:02.394986Z", + "shell.execute_reply": "2026-03-02T20:01:02.394746Z" + } + }, "outputs": [ { "data": { @@ -807,61 +938,15 @@ "orientation": "v", "showlegend": false, "type": "scatter", - "x": [ - 0, - 0.5, - 1, - 1.5, - 2, - 2.5, - 3, - 3.5, - 4, - 4.5, - 5, - 5.5, - 6, - 6.5, - 7, - 7.5, - 8, - 8.5, - 9, - 9.5, - 10, - 10.5, - 11, - 11.5, - 12 - ], + "x": { + "bdata": "AAAAAAAAAAAAAAAAAADgPwAAAAAAAPA/AAAAAAAA+D8AAAAAAAAAQAAAAAAAAARAAAAAAAAACEAAAAAAAAAMQAAAAAAAABBAAAAAAAAAEkAAAAAAAAAUQAAAAAAAABZAAAAAAAAAGEAAAAAAAAAaQAAAAAAAABxAAAAAAAAAHkAAAAAAAAAgQAAAAAAAACFAAAAAAAAAIkAAAAAAAAAjQAAAAAAAACRAAAAAAAAAJUAAAAAAAAAmQAAAAAAAACdAAAAAAAAAKEA=", + "dtype": "f8" + }, "xaxis": "x", - "y": [ - 1.2937933206558228e-05, - -0.006643360480666161, - -0.007525032386183739, - -0.009075596928596497, - -0.008736769668757915, - -0.008685456588864326, - -0.006480347365140915, - -0.007939882576465607, - -0.009661720134317875, - -0.015095856040716171, - -0.01419061329215765, - -0.019930001348257065, - -0.00912435818463564, - -0.027298055589199066, - -0.02985510788857937, - 0.2497255504131317, - 0.250558078289032, - 0.45005205273628235, - 0.45996904373168945, - 5.02545166015625, - 5.142900466918945, - 4.730565071105957, - 4.887058258056641, - 3.445383071899414, - 3.5518720149993896 - ], + "y": { + "bdata": "QBlZNxyx2bumlPa7BrIUvKIkD7zfTQ68elnUu+gVAry0TR68V1Z3vNN/aLyOQ6O8oX0VvEef37xPkfS8mLh/PnJJgD6VbeY+ZoHrPprQoEDBkqRA8GCXQPJinEBsgVxAHVJjQA==", + "dtype": "f4" + }, "yaxis": "y" } ], @@ -945,7 +1030,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -981,7 +1066,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -1005,7 +1090,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -1041,64 +1126,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "heatmap" } ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - 
"colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], "histogram": [ { "marker": { @@ -1119,7 +1153,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -1155,7 +1189,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -1170,7 +1204,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -1206,7 +1240,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -1299,6 +1333,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -1351,7 +1396,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -1387,7 +1432,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -1478,7 +1523,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -1514,13 +1559,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -1556,7 +1601,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -1691,8 +1736,8 @@ "xaxis": { "anchor": "y", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "x" @@ -1701,8 +1746,8 @@ "yaxis": { "anchor": "x", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "y" @@ -1749,7 +1794,14 @@ { "cell_type": "code", "execution_count": 15, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:02.395953Z", + "iopub.status.busy": "2026-03-02T20:01:02.395892Z", + "iopub.status.idle": "2026-03-02T20:01:02.416037Z", + "shell.execute_reply": "2026-03-02T20:01:02.415827Z" + } + }, "outputs": [ { "data": { @@ -1801,63 +1853,15 @@ "orientation": "v", "showlegend": false, "type": 
"scatter", - "x": [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 25 - ], + "x": { + "bdata": "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBk=", + "dtype": "i1" + }, "xaxis": "x", - "y": [ - -0.00028366726473905146, - 0.00029660604195669293, - -0.0066563040018081665, - -0.0008816685294732451, - -0.0015505650080740452, - 0.00033882574643939734, - 5.131529178470373e-05, - 0.0022051138803362846, - -0.0014595506945624948, - -0.0017218313878402114, - -0.005434143822640181, - 0.0009052485693246126, - -0.0057394010946154594, - 0.010805649682879448, - -0.018173698335886, - -0.002557049971073866, - 0.27958065271377563, - 0.0008325176313519478, - 0.19949400424957275, - 0.00991708692163229, - 4.565483093261719, - 0.11744903028011322, - -0.4123360514640808, - 0.15649384260177612, - -1.4416757822036743, - 0.10648896545171738 - ], + "y": { + "bdata": "M7mUuRWCmzmcHdq7ahxnulg9y7qIq7E54MNWOJqEEDt3Sb+6Xb7hukwRsrsgaG06kw68u3oJMTx84JS8WJAnu2Iljz78TVo6RkhMPvV5IjyEGJJAbInwPQUd075YQCA+9Ii4vzAW2j0=", + "dtype": "f4" + }, "yaxis": "y" } ], @@ -1941,7 +1945,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -1977,7 +1981,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -2001,7 +2005,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -2037,64 +2041,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "heatmap" } ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], "histogram": [ { "marker": { @@ -2115,7 +2068,7 @@ }, 
"colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -2151,7 +2104,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -2166,7 +2119,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -2202,7 +2155,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -2295,6 +2248,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -2347,7 +2311,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -2383,7 +2347,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -2474,7 +2438,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -2510,13 +2474,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -2552,7 +2516,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -2687,8 +2651,8 @@ "xaxis": { "anchor": "y", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "x" @@ -2697,8 +2661,8 @@ "yaxis": { "anchor": "x", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "y" @@ -2744,7 +2708,14 @@ { "cell_type": "code", "execution_count": 16, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:02.416994Z", + "iopub.status.busy": "2026-03-02T20:01:02.416933Z", + "iopub.status.idle": "2026-03-02T20:01:02.588061Z", + "shell.execute_reply": "2026-03-02T20:01:02.587851Z" + } + }, "outputs": [ { "name": "stdout", @@ -2767,184 +2738,19 @@ "type": "heatmap", "xaxis": "x", "yaxis": "y", - "z": [ - [ - -0.0020563392899930477, - -0.0005101899732835591, - 0.0004685786843765527, - 0.00012512074317783117, - -0.0006028738571330905, - -0.0002429460291750729, - -0.0023189077619463205, - -0.002758360467851162, - 0.000564602785743773, - 0.0009697531932033598, - -0.0002504526637494564, - 4.737317794933915e-06 - ], - [ - -0.0010070882271975279, - 0.00039470894262194633, - -0.00154874159488827, - 0.0014034928753972054, - -0.0012653048615902662, - -0.0011358022456988692, - 
-0.00281596090644598, - -0.0029645217582583427, - 0.0029190476052463055, - 0.0025743592996150255, - 0.00036239007022231817, - 0.0017548729665577412 - ], - [ - 0.0005569400964304805, - -0.001126631861552596, - -0.0017353934235870838, - -0.0014514457434415817, - -0.00028735760133713484, - 0.0017211002996191382, - 0.0026658899150788784, - 0.00311466702260077, - 0.0005667927907779813, - -0.003666515462100506, - -0.0018847601022571325, - 7.039372576400638e-06 - ], - [ - -0.0007264417363330722, - 0.00011364505917299539, - 0.0014301587361842394, - 0.0007490540738217533, - 0.0020184689201414585, - 0.0007436950691044331, - -0.00046178390039131045, - -0.0039057559333741665, - 0.0011406694538891315, - -4.022853681817651e-05, - -0.0013293239753693342, - -0.0017636751290410757 - ], - [ - -0.0028280913829803467, - 0.00033634810824878514, - -0.0014248639345169067, - -0.003777273464947939, - 0.0015998880844563246, - 0.0002989505883306265, - -0.000804675742983818, - 0.002038792008534074, - -0.0015593919670209289, - -0.0006436670082621276, - 0.0011168173514306545, - -0.00035012533771805465 - ], - [ - 0.0011338205076754093, - 0.0011259170714765787, - -0.002516670385375619, - -0.0014790185960009694, - 0.0003878737334161997, - -6.408110493794084e-05, - -0.0005096744280308485, - -0.0008840755908749998, - 0.0006398351397365332, - -0.0010097370250150561, - -0.006759158335626125, - 0.0033667823299765587 - ], - [ - -0.01514742337167263, - -0.0021350777242332697, - 0.002593174111098051, - -0.00042678468162193894, - -0.005558924749493599, - 0.0026658528950065374, - 0.006411008536815643, - -0.003826778382062912, - -0.0003843410813715309, - -0.0016430341638624668, - -0.0013344454346224666, - -9.20506427064538e-05 - ], - [ - -9.476230479776859e-05, - -0.0057889921590685844, - -0.0006383581785485148, - 0.13493388891220093, - -0.001768707763403654, - -0.018917907029390335, - 0.003873429261147976, - -0.0021450775675475597, - -0.010327338241040707, - 0.18325845897197723, - -0.0007747983909212053, - 
-0.00104526337236166 - ], - [ - -0.003833949100226164, - -0.0008046097937040031, - -0.012673400342464447, - 0.00804573018103838, - 0.003604492638260126, - -0.009398287162184715, - -0.08272082358598709, - 0.003555194940418005, - -0.018404025584459305, - 0.0017587244510650635, - 0.2896133363246918, - 0.022854052484035492 - ], - [ - 0.08595258742570877, - -0.0006932877004146576, - 0.06817055493593216, - 0.013111240230500698, - -0.021098043769598007, - 0.05112447217106819, - 1.3844914436340332, - 0.045836858451366425, - -0.03830280900001526, - 2.985445976257324, - 0.0019662054255604744, - -0.008030137047171593 - ], - [ - 0.5608693957328796, - 0.17083050310611725, - -0.03361757844686508, - 0.05821544677019119, - -0.0024530249647796154, - 0.0018771197646856308, - 0.28827205300331116, - -1.8986485004425049, - -0.0015286931302398443, - -0.035129792988300323, - 0.4802178740501404, - -0.0009115453576669097 - ], - [ - 0.016075748950242996, - -0.03986122086644173, - -0.3879126012325287, - 0.011123123578727245, - -0.005477819126099348, - -0.0025129620917141438, - -0.08056175708770752, - 0.007518616039305925, - 0.0430111438035965, - -0.040082238614559174, - -0.9702364802360535, - 0.011862239800393581 - ] - ] + "z": { + "bdata": 
"isMGu+K+BbrkqvU5ijQDOaIJHrr0wn65qfgXu8jFNLttARQ66DZ+OqRNg7kAGZ82SwGEupDyzjm8/sq6avW3OtfXpbrN3pS61Is4u8BIQrs0TT87eLYoO6D9vTn2AuY6J/8ROpaqk7o/d+O6zD++upivlrmylOE6L7YuO3wfTDvFkhQ6aElwu6wH97oAC+w2YG4+usBt7ji2d7s6+ltEOgVIBDsY9EI6Ih7yuW71f7togpU6oAcpuPQ6rrosLOe6qlk5u4tmsDluw7q6Z4x3uxa00To+uJw5nepSupSbBTtHYcy6T74ouvdhkjq6jLe5EJyUOnWTkzqO6iS7VNzBushoyzmwXIa46JwFuhTEZ7r6vSc69liEutN63bvRpFw7Si14vIvrC7uW8Sk7fbrfuX4otrsati47bBTSOxDNerucfsm50VnXuvrnrrrwZcC4xfHGuHuxvbsLVye6USwKPh7X57rl+Zq8NNh9O3aUDLtmNCm8WKg7PiEcS7oa/4i6zkJ7u0vuUrr6o0+8/NEDPLY1bDvG+hm8lmmpvZ78aDtCxJa8xYDmOkhIlD4SObs81AewPQzANbrOnIs9MdFWPBDWrLy4Z1E9GzexP6C/Oz2+4xy9rhE/QAnbADuMkAO8TpUPPzLuLj5msgm9oXNuPerBILvoCfY6f5iTPvIG878XXsi6EOQPvWPf9T5w+m66cLGDPH9GI71vnMa+XD82PDZ/s7s9sSS7av2kvUBf9jt4LDA9Fy0kvbVheL/UWkI8", + "dtype": "f4", + "shape": "12, 12" + } } ], "layout": { "coloraxis": { - "cmid": 0, + "cmid": 0.0, "colorscale": [ [ - 0, + 0.0, "rgb(103,0,31)" ], [ @@ -2984,7 +2790,7 @@ "rgb(33,102,172)" ], [ - 1, + 1.0, "rgb(5,48,97)" ] ] @@ -3065,7 +2871,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -3101,7 +2907,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -3125,7 +2931,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -3161,14 +2967,26 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "heatmap" } ], - "heatmapgl": [ + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ { "colorbar": { "outlinewidth": 0, @@ -3176,7 +2994,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -3212,26 +3030,14 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" + "type": "histogram2d" } ], - "histogram2d": [ + "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, @@ -3239,7 +3045,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ 
-3275,65 +3081,14 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], - "type": "histogram2d" + "type": "histogram2dcontour" } ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ + "mesh3d": [ { "colorbar": { "outlinewidth": 0, @@ -3419,6 +3174,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -3471,7 +3237,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -3507,7 +3273,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -3598,7 +3364,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -3634,13 +3400,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -3676,7 +3442,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -3812,8 +3578,8 @@ "anchor": "y", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "scaleanchor": "y", "title": { @@ -3825,8 +3591,8 @@ "autorange": "reversed", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "Layer" @@ -3840,6 +3606,13 @@ } ], "source": [ + "# hook_result shape: Bridge captures [batch,pos,d_model] not [batch,pos,n_heads,d_head]\n", + "# Remove so compute_head_results can recompute from z + W_O\n", + "for layer in range(model.cfg.n_layers):\n", + " key = f\"blocks.{layer}.attn.hook_result\"\n", + " if key in cache.cache_dict and cache.cache_dict[key].ndim == 3:\n", + " del cache.cache_dict[key]\n", + "\n", 
"per_head_residual, labels = cache.stack_head_results(\n", " layer=-1, pos_slice=-1, return_labels=True\n", ")\n", @@ -3880,7 +3653,14 @@ { "cell_type": "code", "execution_count": 17, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:02.589066Z", + "iopub.status.busy": "2026-03-02T20:01:02.589003Z", + "iopub.status.idle": "2026-03-02T20:01:02.591425Z", + "shell.execute_reply": "2026-03-02T20:01:02.591254Z" + } + }, "outputs": [], "source": [ "def visualize_attention_patterns(\n", @@ -3894,6 +3674,11 @@ " if isinstance(heads, int):\n", " heads = [heads]\n", "\n", + " # Handle empty head list\n", + " if len(heads) == 0:\n", + " title_html = f\"

{title}


\"\n", + " return f\"
{title_html}

No heads in this group

\"\n", + "\n", " # Create the plotting data\n", " labels: List[str] = []\n", " patterns: List[Float[torch.Tensor, \"dest_pos src_pos\"]] = []\n", @@ -3941,26 +3726,33 @@ { "cell_type": "code", "execution_count": 18, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:02.592269Z", + "iopub.status.busy": "2026-03-02T20:01:02.592214Z", + "iopub.status.idle": "2026-03-02T20:01:02.739740Z", + "shell.execute_reply": "2026-03-02T20:01:02.739487Z" + } + }, "outputs": [ { "data": { "text/html": [ - "

Top 3 Positive Logit Attribution Heads


\n", + "

Top 3 Positive Logit Attribution Heads


\n", "

Top 3 Negative Logit Attribution Heads


\n", + "

Top 3 Negative Logit Attribution Heads


\n", "
" ], @@ -4085,7 +3877,14 @@ { "cell_type": "code", "execution_count": 19, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:02.740771Z", + "iopub.status.busy": "2026-03-02T20:01:02.740702Z", + "iopub.status.idle": "2026-03-02T20:01:02.868880Z", + "shell.execute_reply": "2026-03-02T20:01:02.868665Z" + } + }, "outputs": [ { "name": "stdout", @@ -4113,7 +3912,14 @@ { "cell_type": "code", "execution_count": 20, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:02.869830Z", + "iopub.status.busy": "2026-03-02T20:01:02.869770Z", + "iopub.status.idle": "2026-03-02T20:01:02.872051Z", + "shell.execute_reply": "2026-03-02T20:01:02.871854Z" + } + }, "outputs": [ { "data": { @@ -4150,16 +3956,47 @@ { "cell_type": "code", "execution_count": 21, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:02.872883Z", + "iopub.status.busy": "2026-03-02T20:01:02.872823Z", + "iopub.status.idle": "2026-03-02T20:01:05.815789Z", + "shell.execute_reply": "2026-03-02T20:01:05.815468Z" + } + }, "outputs": [], "source": [ + "def _cache_lookup(cache, hook_name, expected_ndim=None):\n", + " \"\"\"Look up cache value, handling bridge hook.name aliasing.\n", + "\n", + " TransformerBridge hook.name may differ from the cache key (e.g. 
hook.name is\n", + " 'blocks.0.attn.hook_result' but cache stores 'blocks.0.hook_attn_out').\n", + " Additionally, compute_head_results may overwrite hook_result with a 4D tensor,\n", + " so we fall back to the block-level alias when the shape doesn't match.\n", + " \"\"\"\n", + " try:\n", + " val = cache[hook_name]\n", + " if expected_ndim is not None and val.ndim != expected_ndim:\n", + " raise KeyError(f\"Shape mismatch: expected {expected_ndim}D, got {val.ndim}D\")\n", + " return val\n", + " except KeyError:\n", + " # Try the block-level alias: blocks.X.attn.hook_result -> blocks.X.hook_attn_out\n", + " parts = hook_name.split(\".\")\n", + " if len(parts) >= 4 and parts[2] == \"attn\":\n", + " alt_key = f\"{parts[0]}.{parts[1]}.hook_attn_out\"\n", + " return cache[alt_key]\n", + " raise\n", + "\n", + "\n", "def patch_residual_component(\n", " corrupted_residual_component: Float[torch.Tensor, \"batch pos d_model\"],\n", " hook,\n", " pos,\n", " clean_cache,\n", "):\n", - " corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]\n", + " clean_value = _cache_lookup(clean_cache, hook.name, expected_ndim=3)[:, pos, :].clone()\n", + " corrupted_residual_component = corrupted_residual_component.clone()\n", + " corrupted_residual_component[:, pos : pos + 1, :] = clean_value.unsqueeze(1)\n", " return corrupted_residual_component\n", "\n", "\n", @@ -4204,7 +4041,14 @@ { "cell_type": "code", "execution_count": 22, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:05.816938Z", + "iopub.status.busy": "2026-03-02T20:01:05.816873Z", + "iopub.status.idle": "2026-03-02T20:01:05.827321Z", + "shell.execute_reply": "2026-03-02T20:01:05.827121Z" + } + }, "outputs": [ { "data": { @@ -4237,220 +4081,19 @@ ], "xaxis": "x", "yaxis": "y", - "z": [ - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.000650405883789, - 
-0.0002469856117386371, - 9.76665523921838e-06, - -0.00036458822432905436, - -4.8967522161547095e-05 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.001051902770996, - -2.7621845219982788e-05, - -1.9768245692830533e-05, - -0.0004596704675350338, - -0.0005947590689174831 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.0002663135528564, - 0.0008680911851115525, - 0.0005157867562957108, - -0.0009929431835189462, - -0.0008658089209347963 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.994907796382904, - 0.005429857410490513, - 0.0016050540143623948, - -0.0006193603039719164, - -0.0016324409516528249 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.9675672054290771, - 0.03134213387966156, - 0.0028418952133506536, - -0.0012302964460104704, - -0.000985861523076892 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.967520534992218, - 0.03100077249109745, - 0.0017823305679485202, - -0.00048668819363228977, - -0.0006467136554419994 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.9228319525718689, - 0.05134531855583191, - 0.004728672094643116, - 0.0009345446596853435, - 0.017046840861439705 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.6565483808517456, - 0.02385685034096241, - 0.002357019344344735, - -1.7183941963594407e-05, - 0.3186916410923004 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.027302566915750504, - 0.03142499923706055, - 0.0018202561186626554, - 0.0007990868762135506, - 0.9383866190910339 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.026841485872864723, - 0.02098155952990055, - 0.0012512058019638062, - 0.00032317222212441266, - 1.0048279762268066 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.005687985569238663, - 0.014263377524912357, - 0.00048709093243815005, - -8.977938705356792e-05, - 0.9914212226867676 - ] - ] + "z": { + "bdata": 
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAFMVgD/tVYK53OcWN3whv7mxRUu4AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHMigD9Yq+C3QR6vt5zS8Lnr2xu6AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAALwIgD+ov2M6vlAHOqEmgrru4mK6AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEyyfj9X67E73IfSOk0LIro/5NW6AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGuydz99YQA9o0U6OxlFobqhNYG6AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGqvdz909v08VqHpOmAT/7nHhSm6AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAK4+bD+tUFI91/maO7LwdDpPpYs8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHoTKD+Tb8M8iWoaO3D0ibepK6M+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAJuo3zw9uAA9K67uOmzRUTo6OnA/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAFjj27z64Ks8dhKkOkN4qTk6noA/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGpeurvVr2k8bED/ORaevbjbzX0/", + "dtype": "f4", + "shape": "12, 15" + } } ], "layout": { "coloraxis": { - "cmid": 0, + "cmid": 0.0, "colorscale": [ [ - 0, + 0.0, "rgb(103,0,31)" ], [ @@ -4490,7 +4133,7 @@ "rgb(33,102,172)" ], [ - 1, + 1.0, "rgb(5,48,97)" ] ] @@ -4571,7 +4214,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -4607,7 +4250,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -4631,7 +4274,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -4667,64 +4310,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "heatmap" } ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": 
"heatmapgl" - } - ], "histogram": [ { "marker": { @@ -4745,7 +4337,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -4781,7 +4373,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -4796,7 +4388,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -4832,7 +4424,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -4925,6 +4517,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -4977,7 +4580,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -5013,7 +4616,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -5104,7 +4707,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -5140,13 +4743,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -5182,7 +4785,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -5318,8 +4921,8 @@ "anchor": "y", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "scaleanchor": "y", "title": { @@ -5331,8 +4934,8 @@ "autorange": "reversed", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "Layer" @@ -5374,7 +4977,14 @@ { "cell_type": "code", "execution_count": 23, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:05.828240Z", + "iopub.status.busy": "2026-03-02T20:01:05.828187Z", + "iopub.status.idle": "2026-03-02T20:01:11.665460Z", + "shell.execute_reply": "2026-03-02T20:01:11.665175Z" + } + }, "outputs": [], "source": [ "patched_attn_diff = torch.zeros(\n", @@ -5421,7 +5031,14 @@ { "cell_type": "code", "execution_count": 24, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:11.666597Z", + "iopub.status.busy": "2026-03-02T20:01:11.666533Z", + "iopub.status.idle": "2026-03-02T20:01:11.676610Z", + "shell.execute_reply": "2026-03-02T20:01:11.676427Z" + } + }, "outputs": [ { "data": { @@ -5454,220 +5071,19 @@ ], "xaxis": "x", 
"yaxis": "y", - "z": [ - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.035456884652376175, - -0.0002469856117386371, - 9.76665523921838e-06, - -0.00036458822432905436, - -4.8967522161547095e-05 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.0029848709236830473, - 7.950929284561425e-05, - 2.0842242520302534e-05, - 8.088535105343908e-05, - -0.0005967392353340983 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.0019131568260490894, - 0.0006668510613963008, - 0.00039482791908085346, - -0.0007051457650959492, - -0.00027282864903099835 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.1546323299407959, - 0.0038019807543605566, - 0.0005171628436073661, - -0.00011964991426793858, - -0.0005599213181994855 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.005406397394835949, - 0.019581740722060204, - 0.001007509301416576, - -0.0002424211270408705, - 0.0007936497568152845 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.3520970046520233, - 0.0010525835677981377, - 0.00022436455765273422, - 0.00013367898645810783, - 8.172441448550671e-05 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.11986024677753448, - 0.021243548020720482, - 0.002727783052250743, - 0.0013409851817414165, - 0.01797366514801979 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.013310473412275314, - 0.011509180068969727, - 0.00037542887730523944, - -4.094611358596012e-05, - 0.29760244488716125 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.0015009435592219234, - 0.017351653426885605, - 0.0005848917062394321, - 0.0010122752282768488, - 0.5697318911552429 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.00012901381705887616, - 0.00630143890157342, - 0.00014156615361571312, - 0.00031229801243171096, - 0.27152299880981445 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.0009373303619213402, - 8.669164526509121e-05, - 0.00033243544748984277, - 
9.73309283835988e-07, - -0.1929796040058136 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.40617984533309937 - ] - ] + "z": { + "bdata": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIg7ET3tVYK53OcWN3whv7mxRUu4AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAmUQ7vzY6Y4evavN+oMqjjEPBy6AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADKb+rrkry46odPOOU7HOLrJDo+5AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAALlXHj6DNHk7mroHOl9G+7hVBBO6AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAFspsbv1a6A8hyKEOj6Vfrn1DlA6AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANGtD7bG4o6HzFrOdRuCzmyAa04AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACN59T2IB648OccyOzDQrzpYPZM8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHoVWjzDkjw86ozEOSTlKriMX5g+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAPa5xLoYJo48fkwZOkyvhDrz2RE/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAJrJBrkaes47Ut0UOUztozlHBYs+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGGgdboP8rU4ASuuOXQomTXmm0W+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAC89s++", + "dtype": "f4", + "shape": "12, 15" + } } ], "layout": { "coloraxis": { - "cmid": 0, + "cmid": 0.0, "colorscale": [ [ - 0, + 0.0, "rgb(103,0,31)" ], [ @@ -5707,7 +5123,7 @@ "rgb(33,102,172)" ], [ - 1, + 1.0, "rgb(5,48,97)" ] ] @@ -5788,7 +5204,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -5824,7 +5240,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -5848,7 +5264,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -5884,64 +5300,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "heatmap" } ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ 
- 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], "histogram": [ { "marker": { @@ -5962,7 +5327,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -5998,7 +5363,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -6013,7 +5378,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -6049,7 +5414,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -6142,6 +5507,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -6194,7 +5570,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -6230,7 +5606,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -6321,7 +5697,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -6357,13 +5733,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -6399,7 +5775,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -6535,8 +5911,8 @@ "anchor": "y", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "scaleanchor": "y", "title": { @@ -6548,8 +5924,8 @@ "autorange": "reversed", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "Layer" @@ -6593,7 +5969,14 @@ { "cell_type": "code", "execution_count": 25, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:11.677475Z", + "iopub.status.busy": "2026-03-02T20:01:11.677420Z", + "iopub.status.idle": "2026-03-02T20:01:11.686943Z", + "shell.execute_reply": "2026-03-02T20:01:11.686752Z" + } + }, "outputs": [ { "data": { @@ -6626,220 +6009,19 @@ ], "xaxis": "x", "yaxis": "y", - "z": [ - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.8507890701293945, - -0.00027843358111567795, - -7.293107046280056e-05, - -0.00047373308916576207, - 4.0039929444901645e-05 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, 
- 0, - 0, - 0, - 0.008863994851708412, - 0.000222149450564757, - 0.00014938619278836995, - -4.853121208725497e-05, - 0.000304041663184762 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.013550343923270702, - 5.86334899708163e-05, - -0.0003296833310741931, - -0.0006382559076882899, - 0.0007730424986220896 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.0019468198297545314, - 0.0004995090421289206, - 0.00017318192112725228, - 0.00016871812113095075, - 0.00040764876757748425 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.019787074998021126, - 0.004128609783947468, - -4.86990247736685e-05, - -0.00017019486404024065, - 0.0007914346642792225 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.09652391821146011, - -0.0018826150335371494, - -0.0004844730719923973, - 0.0007094081956893206, - -0.00018335132335778326 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.015900013968348503, - -0.0008501688134856522, - 0.00012337534280959517, - 2.7521158699528314e-05, - -0.007238299585878849 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.010360540822148323, - 0.0031509376130998135, - 0.0005309234256856143, - 0.0002361114020459354, - 0.008496351540088654 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.012533102184534073, - 2.201692586822901e-05, - -0.00035374757135286927, - 8.615465048933402e-05, - -0.021631328389048576 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -0.00033465056912973523, - 0.0008094912045635283, - 1.6244195649051107e-05, - 0.00012924875773023814, - 0.03162466362118721 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.0013599144294857979, - -0.00019499746849760413, - -9.934466652339324e-05, - -0.00014217027637641877, - 0.028764141723513603 - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.02044912613928318 - ] - ] + "z": { + "bdata": 
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAFrNWT+M8ZG5U6qYuJ5s+LmLpCg4AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAKU9ETw0tWc5i0YdOSP2TLiYXp85AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAKP+XTyMYHQ4HrWsufNUJ7oHqEo6AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACk1/zpP1gI6Efs1OdxnMTlxsNU5AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEQXorwrR4c7xI1LuO2mMbkRik86AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMOuxT1vx/a6/Jj9uaYUOjr6/kC5AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAI8+gryZy166LF4BOeou7jctLO27AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAKa+KTyMgk47xC8LOnHBdzmmNws8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAExVTbzTobw3fnu5uTWCtjgKNbG8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAE9Ur7lTSlQ6eUuGN9/OBzktiQE9AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAC4msjrlC0y54cbPuGMcFbkFous8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADTg6c8", + "dtype": "f4", + "shape": "12, 15" + } } ], "layout": { "coloraxis": { - "cmid": 0, + "cmid": 0.0, "colorscale": [ [ - 0, + 0.0, "rgb(103,0,31)" ], [ @@ -6879,7 +6061,7 @@ "rgb(33,102,172)" ], [ - 1, + 1.0, "rgb(5,48,97)" ] ] @@ -6960,7 +6142,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -6996,7 +6178,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -7020,7 +6202,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -7056,64 +6238,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "heatmap" } ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": 
"heatmapgl" - } - ], "histogram": [ { "marker": { @@ -7134,7 +6265,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -7170,7 +6301,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -7185,7 +6316,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -7221,7 +6352,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -7314,6 +6445,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -7366,7 +6508,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -7402,7 +6544,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -7493,7 +6635,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -7529,13 +6671,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -7571,7 +6713,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -7707,8 +6849,8 @@ "anchor": "y", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "scaleanchor": "y", "title": { @@ -7720,8 +6862,8 @@ "autorange": "reversed", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "Layer" @@ -7757,7 +6899,14 @@ { "cell_type": "code", "execution_count": 26, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:11.687735Z", + "iopub.status.busy": "2026-03-02T20:01:11.687686Z", + "iopub.status.idle": "2026-03-02T20:01:14.032005Z", + "shell.execute_reply": "2026-03-02T20:01:14.031726Z" + } + }, "outputs": [], "source": [ "def patch_head_vector(\n", @@ -7766,9 +6915,11 @@ " head_index,\n", " clean_cache,\n", "):\n", - " corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][\n", + " clean_value = _cache_lookup(clean_cache, hook.name)[\n", " :, :, head_index, :\n", - " ]\n", + " ].clone()\n", + " corrupted_head_vector = corrupted_head_vector.clone()\n", + " corrupted_head_vector[:, :, head_index : head_index + 1, :] = clean_value.unsqueeze(2)\n", " 
return corrupted_head_vector\n", "\n", "\n", @@ -7800,7 +6951,14 @@ { "cell_type": "code", "execution_count": 27, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:14.033126Z", + "iopub.status.busy": "2026-03-02T20:01:14.033063Z", + "iopub.status.idle": "2026-03-02T20:01:14.043209Z", + "shell.execute_reply": "2026-03-02T20:01:14.043010Z" + } + }, "outputs": [ { "data": { @@ -7816,184 +6974,19 @@ "type": "heatmap", "xaxis": "x", "yaxis": "y", - "z": [ - [ - 0.0009487751522101462, - 0.016124747693538666, - 0.0018548924708738923, - 0.0034389030188322067, - -0.00982347596436739, - 0.011058605276048183, - -0.004063969012349844, - -0.0015792781487107277, - -0.0012082795146852732, - 0.003828897839412093, - -0.004256919026374817, - -0.0011422622483223677 - ], - [ - -0.0010771177476271987, - -0.00037898647133260965, - 2.5171791548928013e-06, - -0.00026067905128002167, - -0.00014146546891424805, - 0.0038321535103023052, - -0.0004293300735298544, - -0.00142992555629462, - -0.0009228314156644046, - 0.0006944393389858305, - 0.00043302192352712154, - -0.0035714071709662676 - ], - [ - -0.0004967569257132709, - 0.0008057993836700916, - 0.0005424688570201397, - -0.0005309234256856143, - -0.0007159864180721343, - -0.0010389237431809306, - -0.0009490771917626262, - -8.649027586216107e-05, - 0.0002766547549981624, - 0.0021084228064864874, - -0.0001975146442418918, - -0.0016405630158260465 - ], - [ - 0.1162627637386322, - 0.0002507446042727679, - -0.0014675153652206063, - -0.00039680811460129917, - 0.018962211906909943, - -0.00018764731066767126, - 0.011170871555805206, - -0.0013301445869728923, - -0.0007356539717875421, - -0.00030253134900704026, - -0.00014683544577565044, - -0.00022228369198273867 - ], - [ - -0.001650598249398172, - 0.0002927311579696834, - -0.00143563118763268, - 0.03084198758006096, - -0.007432155776768923, - -0.00028236035723239183, - 0.006017433945089579, - -0.011007187888026237, - -0.001266107545234263, - 
0.0014901700196787715, - -0.0001800622121663764, - 0.002944394713267684 - ], - [ - -0.004211106337606907, - 0.0029597999528050423, - 0.002045023487880826, - 0.0013397098518908024, - -0.0012190865818411112, - 0.34349915385246277, - 0.0005632104002870619, - -0.0001262281439267099, - -0.00515326950699091, - 0.016240738332271576, - 0.01709030382335186, - -0.004175194539129734 - ], - [ - 0.039775289595127106, - 0.015226684510707855, - -0.0010229480685666203, - 0.0008072761120274663, - -0.004935584031045437, - -0.002123525831848383, - -0.014274083077907562, - 0.0013746818294748664, - 0.0014838266652077436, - 0.1302703619003296, - -0.00033616088330745697, - 0.0012919505825266242 - ], - [ - 0.00037177055492065847, - 0.019514480605721474, - 0.00022255218937061727, - 0.124249167740345, - -0.00040352059295400977, - -0.007652895525097847, - 0.0013010123511776328, - -0.0011253133416175842, - -0.007449474185705185, - 0.19224143028259277, - -0.003275118535384536, - -0.0005017912480980158 - ], - [ - -0.001007912098430097, - 3.091096004936844e-05, - -0.0008595998515374959, - 0.012359987013041973, - -0.0004041247011628002, - -0.004328910261392593, - 0.3185553252696991, - 0.002330605871975422, - 0.0021182901691645384, - 0.0001405928487656638, - 0.2779357433319092, - 0.005738262087106705 - ], - [ - 0.0058898297138512135, - -0.0009689796715974808, - 0.00912561360746622, - 0.020675739273428917, - -0.03700518235564232, - 0.014263041317462921, - -0.04828466475009918, - 0.05834139883518219, - 0.0006514795240946114, - 0.26360899209976196, - 0.0004918567719869316, - -0.00261044898070395 - ], - [ - 0.08374208211898804, - 0.020676210522651672, - -0.003743582172319293, - 0.01085072010755539, - -0.001096583902835846, - 0.00047430366976186633, - 0.04818058758974075, - -0.4799128472805023, - 0.00018429107149131596, - 0.011861988343298435, - 0.06088569387793541, - 0.0008461413672193885 - ], - [ - 0.005328264087438583, - -0.011493473313748837, - -0.11350836604833603, - 0.006329597905278206, - 
0.00031669469899497926, - -0.0011600167490541935, - -0.022669579833745956, - 0.004070379305630922, - 0.0073160636238753796, - -0.00834545586258173, - -0.27817651629447937, - 0.0036344374530017376 - ] - ] + "z": { + "bdata": "N8t4Om8VhDy9MvM6SWZhO7/yILwMMjU8pjGFuyrzzrpTWZ66qO96Owh+i7vOwJW6iSqNunCFxrmBogI2D4OIuSIpFLnoK3s7WrTguUR2u7pZuHG6kPY1Oo5c4zkyC2q7IisCuhpyUzoROw46EmgLug+hO7q0HYi6N8t4uuRPtbhTGZE5DCIKO5q4TrkI99a68RruPaAXgzmMUcC6eBzQuVBXmzwaQUW5QAY3PIdBrroz5kC69tueuYdwGblvlmi5Yk3YupGUmTlzKry6Mqj8PPKE87siKZS51SzFO5JXNLwV8aW6l1HDOubpPLmZ+0A7c/2JuxIQQjvpEAY7Z66vOk+1n7oc368++o8TOrp6A7nQ1ai7OAuFPE0BjDyKzIi7CuwiPVV6eTxj+oW6cs5TOt60obsgIQu73t1pvLgwtDpeecI6i2UFPokssLn3SKk6ji3DOSbdnzzfPWo5xHb+PXDF07mrvvq77pKqOm1ek7oqGvS7H9tEPjCfVrvEngO6ZiGEuvRSBDiDTWG6KYJKPOqu07k42o27qBmjPpa1GDsr0Qo7QZ4UOZFNjj6cDLw7CwTBO98dfroXhxU8yV6pPIaRF70ztGk8f8VFvYL3bj0awSo66/eGPt8uATo+Diu7U4GrPTFgqTwrTnW7tcgxPO+ej7q2xvg5MFhFPQC39b7pv0A5ilpCPL1jeT0Psl06PpiuO+1NPLzwdui9Wm3PO8/cpTnsF5i6xbW5vLxlhTtzvO87PboIvAhtjr5KMm47", + "dtype": "f4", + "shape": "12, 12" + } } ], "layout": { "coloraxis": { - "cmid": 0, + "cmid": 0.0, "colorscale": [ [ - 0, + 0.0, "rgb(103,0,31)" ], [ @@ -8033,7 +7026,7 @@ "rgb(33,102,172)" ], [ - 1, + 1.0, "rgb(5,48,97)" ] ] @@ -8114,7 +7107,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -8150,7 +7143,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -8174,7 +7167,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -8210,64 +7203,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "heatmap" } ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - 
"#f0f921" - ] - ], - "type": "heatmapgl" - } - ], "histogram": [ { "marker": { @@ -8288,7 +7230,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -8324,7 +7266,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -8339,7 +7281,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -8375,7 +7317,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -8468,6 +7410,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -8520,7 +7473,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -8556,7 +7509,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -8647,7 +7600,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -8683,13 +7636,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -8725,7 +7678,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -8861,8 +7814,8 @@ "anchor": "y", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "scaleanchor": "y", "title": { @@ -8874,8 +7827,8 @@ "autorange": "reversed", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "Layer" @@ -8920,7 +7873,14 @@ { "cell_type": "code", "execution_count": 28, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:14.044115Z", + "iopub.status.busy": "2026-03-02T20:01:14.044062Z", + "iopub.status.idle": "2026-03-02T20:01:16.388776Z", + "shell.execute_reply": "2026-03-02T20:01:16.388497Z" + } + }, "outputs": [], "source": [ "patched_head_v_diff = torch.zeros(\n", @@ -8951,7 +7911,14 @@ { "cell_type": "code", "execution_count": 29, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:16.389868Z", + "iopub.status.busy": "2026-03-02T20:01:16.389796Z", + "iopub.status.idle": "2026-03-02T20:01:16.399764Z", + "shell.execute_reply": "2026-03-02T20:01:16.399537Z" + } + }, "outputs": [ { "data": { @@ 
-8967,184 +7934,19 @@ "type": "heatmap", "xaxis": "x", "yaxis": "y", - "z": [ - [ - -0.00019892427371814847, - 0.005339574534446001, - 0.0006527548539452255, - 0.003504416672512889, - -0.00898387935012579, - 0.0034814265090972185, - -0.0008631910313852131, - -3.406582254683599e-05, - 0.0005166929331608117, - 0.00044255363172851503, - -0.0039068968035280704, - -0.0001880836207419634 - ], - [ - -0.0004399022145662457, - -0.00044510437874123454, - -6.73597096465528e-05, - 7.242763240355998e-05, - -3.6549441574607044e-05, - -0.0019323208834975958, - -0.0001572397886775434, - 1.6143509128596634e-05, - 0.00020593880617525429, - 0.000336798548232764, - 0.0003515324497129768, - -0.0005669358652085066 - ], - [ - 0.00021013410878367722, - -0.0007199132232926786, - 0.0004868560063187033, - -0.0005974104860797524, - -0.0005921411793678999, - -0.0005443819100037217, - -0.000227552984142676, - -0.0004809825913980603, - 0.00020570388005580753, - 0.001183376181870699, - -0.0003574058646336198, - -0.0009104468626901507 - ], - [ - 0.0010395278222858906, - -0.00012042184971505776, - -7.762980385450646e-05, - -0.0007275318494066596, - -0.001310007064603269, - -0.0023108376190066338, - 0.010987084358930588, - -5.0712766096694395e-05, - 0.00014314358122646809, - 0.00015069512301124632, - -7.957642083056271e-05, - -2.0238119759596884e-05 - ], - [ - -0.0005373673629947007, - -0.0008137872209772468, - -0.00013334336108528078, - 0.030609702691435814, - -0.007185807917267084, - 0.000148916311445646, - 0.0013340713921934366, - -0.01142292469739914, - -0.0005336419562809169, - 0.0005126654868945479, - 0.00037344868178479373, - 0.0029547319281846285 - ], - [ - 8.22278525447473e-06, - 6.477540864580078e-06, - 0.0015973682748153806, - 0.00034015480196103454, - -0.0012577504385262728, - -5.450531898532063e-05, - 0.0006331544718705118, - -0.00027081489679403603, - 7.427356467815116e-05, - -0.006704355590045452, - 0.003175975289195776, - -0.0017300404142588377 - ], - [ - 0.04863045737147331, - 
0.015314852818846703, - -0.0004648726317100227, - -0.00011676354915834963, - -4.930314753437415e-05, - -0.003952810075134039, - -0.01737578585743904, - -0.00015421917487401515, - 0.0012194222072139382, - -0.00018090127559844404, - -0.00042647725786082447, - 0.00012334177154116333 - ], - [ - -2.956846401502844e-05, - -0.0013855225406587124, - -0.00012129446986364201, - 0.1332160234451294, - -0.00024490474606864154, - -0.007315828464925289, - 0.00033297244226559997, - -0.000795092957559973, - -0.007938209921121597, - 0.208413764834404, - -0.00019127204723190516, - -0.00020650937221944332 - ], - [ - -0.0020483459811657667, - -0.0003764357534237206, - -0.0033135139383375645, - -0.009666135534644127, - -0.00031723169377073646, - -0.005141589790582657, - 0.31717124581336975, - 0.0028427678626030684, - 0.0004723234742414206, - -0.0011529687326401472, - 0.2726709246635437, - -0.003175639547407627 - ], - [ - -0.00043929810635745525, - 5.7089622714556754e-05, - -0.0020629793871194124, - 0.020066648721694946, - -0.007871017791330814, - 0.011316264048218727, - 0.003056862158700824, - 0.06856372952461243, - -0.002747517777606845, - -0.009279227815568447, - 0.000506624230183661, - -0.0013159140944480896 - ], - [ - -0.012957162223756313, - -0.0030454176012426615, - -0.01792328804731369, - -0.0043589151464402676, - -0.0011521632550284266, - 0.0004999117809347808, - -0.0031131464056670666, - 0.019585633650422096, - 4.34632929682266e-05, - 0.01297028549015522, - -0.007695754989981651, - -0.0009146086522378027 - ], - [ - 0.004100752994418144, - -0.020459463819861412, - -0.035875942558050156, - 0.014656225219368935, - 0.0008441276149824262, - 0.0017804511589929461, - -0.01804223284125328, - 0.003519016318023205, - 0.008253024891018867, - -0.0017665562918409705, - 0.044167667627334595, - 0.006474285386502743 - ] - ] + "z": { + "bdata": 
"A0VQuZn9rjui4Co6m65lOyAvE7zFK2Q78PphukpkCbhXrwc6VzzoOTILgLs7v0W5+NPmudor6rlxwY24AniXOIdwGbglQf26mxYluaFCfDcRDFg5r7ywOS1JuDk/lRS6P4RcOWLcPLpsQP85Z78cugE8G7qytA667kBuuWRY+7kF31c5OBqbOn1du7lJpm66uziIOmhq+7gYJ6O4U4g+uhy7q7qfaxe77QU0PHGwVbi7aRY54IoeOYukqLg3Lau3NuAMuk4pVbrsyAu55ML6PMp367skwxs58uWuOjoiO7xrxAu6gm8GOjK5wzn6tUE7NYIBN6685TZ0eNE66p2xOY/ppLpinGS4XA4mOvrgjbnw/Jo4Z6/buz81UDtcn+K6XzJHPSfrejxTiPO5RVLzuFzOTbh4jYG7ZFWOvIDmILmFB6A6x3Q8uZaz37kzeQE54dfxt/WOtbonO/64G2oIPp79f7k3r++7LtauOQVOULraDQK8hGpVPvyYSLlAwFi58zQGu5tFxbkhPFm7dFoevP+Qprlfdqi7PmSiPhZOOjv54Pc5KyCXuoibiz4aIlC74HnmuT9zbzhELQe7A2GkPO/wALxAZTk8HWBIO4dqjD3vDjS7NwQYvN7yBDoGW6y6DklUvCWKR7sU0pK8idKOuwYNl7okJQM6cwNMuz1xoDxWADc4IH5UPP4z/LtVxG+6Ll+GO7Oap7xr8hK9vBxwPG84XTriVek6tM2TvKmbZjuHOAc8I3bnupTnND2yJ9Q7", + "dtype": "f4", + "shape": "12, 12" + } } ], "layout": { "coloraxis": { - "cmid": 0, + "cmid": 0.0, "colorscale": [ [ - 0, + 0.0, "rgb(103,0,31)" ], [ @@ -9184,7 +7986,7 @@ "rgb(33,102,172)" ], [ - 1, + 1.0, "rgb(5,48,97)" ] ] @@ -9265,7 +8067,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -9301,74 +8103,23 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "contour" } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" } ], - "heatmapgl": [ + "heatmap": [ { "colorbar": { "outlinewidth": 0, @@ -9376,7 +8127,7 @@ }, 
"colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -9412,11 +8163,11 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], - "type": "heatmapgl" + "type": "heatmap" } ], "histogram": [ @@ -9439,7 +8190,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -9475,7 +8226,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -9490,7 +8241,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -9526,7 +8277,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -9619,6 +8370,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -9671,7 +8433,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -9707,7 +8469,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -9798,7 +8560,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -9834,13 +8596,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -9876,7 +8638,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -10012,8 +8774,8 @@ "anchor": "y", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "scaleanchor": "y", "title": { @@ -10025,8 +8787,8 @@ "autorange": "reversed", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "Layer" @@ -10059,7 +8821,14 @@ { "cell_type": "code", "execution_count": 30, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:16.400703Z", + "iopub.status.busy": "2026-03-02T20:01:16.400646Z", + "iopub.status.idle": "2026-03-02T20:01:16.413102Z", + "shell.execute_reply": "2026-03-02T20:01:16.412894Z" + } + }, "outputs": [ { "data": { @@ -10218,152 +8987,10 @@ ], "legendgroup": "", "marker": { - "color": [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 4, - 4, - 4, 
- 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11, - 11 - ], + "color": { + "bdata": "AAAAAAAAAAAAAAAAAQEBAQEBAQEBAQEBAgICAgICAgICAgICAwMDAwMDAwMDAwMDBAQEBAQEBAQEBAQEBQUFBQUFBQUFBQUFBgYGBgYGBgYGBgYGBwcHBwcHBwcHBwcHCAgICAgICAgICAgICQkJCQkJCQkJCQkJCgoKCgoKCgoKCgoKCwsLCwsLCwsLCwsL", + "dtype": "i1" + }, "coloraxis": "coloraxis", "symbol": "circle" }, @@ -10372,299 +8999,15 @@ "orientation": "v", "showlegend": false, "type": "scatter", - "x": [ - -0.00019892427371814847, - 0.005339574534446001, - 0.0006527548539452255, - 0.003504416672512889, - -0.00898387935012579, - 0.0034814265090972185, - -0.0008631910313852131, - -3.406582254683599e-05, - 0.0005166929331608117, - 0.00044255363172851503, - -0.0039068968035280704, - -0.0001880836207419634, - -0.0004399022145662457, - -0.00044510437874123454, - -6.73597096465528e-05, - 7.242763240355998e-05, - -3.6549441574607044e-05, - -0.0019323208834975958, - -0.0001572397886775434, - 1.6143509128596634e-05, - 0.00020593880617525429, - 0.000336798548232764, - 0.0003515324497129768, - -0.0005669358652085066, - 0.00021013410878367722, - -0.0007199132232926786, - 0.0004868560063187033, - -0.0005974104860797524, - -0.0005921411793678999, - -0.0005443819100037217, - -0.000227552984142676, - -0.0004809825913980603, - 0.00020570388005580753, - 0.001183376181870699, - -0.0003574058646336198, - -0.0009104468626901507, - 0.0010395278222858906, - -0.00012042184971505776, - -7.762980385450646e-05, - -0.0007275318494066596, - -0.001310007064603269, - -0.0023108376190066338, - 0.010987084358930588, - 
-5.0712766096694395e-05, - 0.00014314358122646809, - 0.00015069512301124632, - -7.957642083056271e-05, - -2.0238119759596884e-05, - -0.0005373673629947007, - -0.0008137872209772468, - -0.00013334336108528078, - 0.030609702691435814, - -0.007185807917267084, - 0.000148916311445646, - 0.0013340713921934366, - -0.01142292469739914, - -0.0005336419562809169, - 0.0005126654868945479, - 0.00037344868178479373, - 0.0029547319281846285, - 8.22278525447473e-06, - 6.477540864580078e-06, - 0.0015973682748153806, - 0.00034015480196103454, - -0.0012577504385262728, - -5.450531898532063e-05, - 0.0006331544718705118, - -0.00027081489679403603, - 7.427356467815116e-05, - -0.006704355590045452, - 0.003175975289195776, - -0.0017300404142588377, - 0.04863045737147331, - 0.015314852818846703, - -0.0004648726317100227, - -0.00011676354915834963, - -4.930314753437415e-05, - -0.003952810075134039, - -0.01737578585743904, - -0.00015421917487401515, - 0.0012194222072139382, - -0.00018090127559844404, - -0.00042647725786082447, - 0.00012334177154116333, - -2.956846401502844e-05, - -0.0013855225406587124, - -0.00012129446986364201, - 0.1332160234451294, - -0.00024490474606864154, - -0.007315828464925289, - 0.00033297244226559997, - -0.000795092957559973, - -0.007938209921121597, - 0.208413764834404, - -0.00019127204723190516, - -0.00020650937221944332, - -0.0020483459811657667, - -0.0003764357534237206, - -0.0033135139383375645, - -0.009666135534644127, - -0.00031723169377073646, - -0.005141589790582657, - 0.31717124581336975, - 0.0028427678626030684, - 0.0004723234742414206, - -0.0011529687326401472, - 0.2726709246635437, - -0.003175639547407627, - -0.00043929810635745525, - 5.7089622714556754e-05, - -0.0020629793871194124, - 0.020066648721694946, - -0.007871017791330814, - 0.011316264048218727, - 0.003056862158700824, - 0.06856372952461243, - -0.002747517777606845, - -0.009279227815568447, - 0.000506624230183661, - -0.0013159140944480896, - -0.012957162223756313, - -0.0030454176012426615, 
- -0.01792328804731369, - -0.0043589151464402676, - -0.0011521632550284266, - 0.0004999117809347808, - -0.0031131464056670666, - 0.019585633650422096, - 4.34632929682266e-05, - 0.01297028549015522, - -0.007695754989981651, - -0.0009146086522378027, - 0.004100752994418144, - -0.020459463819861412, - -0.035875942558050156, - 0.014656225219368935, - 0.0008441276149824262, - 0.0017804511589929461, - -0.01804223284125328, - 0.003519016318023205, - 0.008253024891018867, - -0.0017665562918409705, - 0.044167667627334595, - 0.006474285386502743 - ], + "x": { + "bdata": "A0VQuZn9rjui4Co6m65lOyAvE7zFK2Q78PphukpkCbhXrwc6VzzoOTILgLs7v0W5+NPmudor6rlxwY24AniXOIdwGbglQf26mxYluaFCfDcRDFg5r7ywOS1JuDk/lRS6P4RcOWLcPLpsQP85Z78cugE8G7qytA667kBuuWRY+7kF31c5OBqbOn1du7lJpm66uziIOmhq+7gYJ6O4U4g+uhy7q7qfaxe77QU0PHGwVbi7aRY54IoeOYukqLg3Lau3NuAMuk4pVbrsyAu55ML6PMp367skwxs58uWuOjoiO7xrxAu6gm8GOjK5wzn6tUE7NYIBN6685TZ0eNE66p2xOY/ppLpinGS4XA4mOvrgjbnw/Jo4Z6/buz81UDtcn+K6XzJHPSfrejxTiPO5RVLzuFzOTbh4jYG7ZFWOvIDmILmFB6A6x3Q8uZaz37kzeQE54dfxt/WOtbonO/64G2oIPp79f7k3r++7LtauOQVOULraDQK8hGpVPvyYSLlAwFi58zQGu5tFxbkhPFm7dFoevP+Qprlfdqi7PmSiPhZOOjv54Pc5KyCXuoibiz4aIlC74HnmuT9zbzhELQe7A2GkPO/wALxAZTk8HWBIO4dqjD3vDjS7NwQYvN7yBDoGW6y6DklUvCWKR7sU0pK8idKOuwYNl7okJQM6cwNMuz1xoDxWADc4IH5UPP4z/LtVxG+6Ll+GO7Oap7xr8hK9vBxwPG84XTriVek6tM2TvKmbZjuHOAc8I3bnupTnND2yJ9Q7", + "dtype": "f4" + }, "xaxis": "x", - "y": [ - 0.0009487751522101462, - 0.016124747693538666, - 0.0018548924708738923, - 0.0034389030188322067, - -0.00982347596436739, - 0.011058605276048183, - -0.004063969012349844, - -0.0015792781487107277, - -0.0012082795146852732, - 0.003828897839412093, - -0.004256919026374817, - -0.0011422622483223677, - -0.0010771177476271987, - -0.00037898647133260965, - 2.5171791548928013e-06, - -0.00026067905128002167, - -0.00014146546891424805, - 0.0038321535103023052, - -0.0004293300735298544, - -0.00142992555629462, - -0.0009228314156644046, - 0.0006944393389858305, - 0.00043302192352712154, - -0.0035714071709662676, - 
-0.0004967569257132709, - 0.0008057993836700916, - 0.0005424688570201397, - -0.0005309234256856143, - -0.0007159864180721343, - -0.0010389237431809306, - -0.0009490771917626262, - -8.649027586216107e-05, - 0.0002766547549981624, - 0.0021084228064864874, - -0.0001975146442418918, - -0.0016405630158260465, - 0.1162627637386322, - 0.0002507446042727679, - -0.0014675153652206063, - -0.00039680811460129917, - 0.018962211906909943, - -0.00018764731066767126, - 0.011170871555805206, - -0.0013301445869728923, - -0.0007356539717875421, - -0.00030253134900704026, - -0.00014683544577565044, - -0.00022228369198273867, - -0.001650598249398172, - 0.0002927311579696834, - -0.00143563118763268, - 0.03084198758006096, - -0.007432155776768923, - -0.00028236035723239183, - 0.006017433945089579, - -0.011007187888026237, - -0.001266107545234263, - 0.0014901700196787715, - -0.0001800622121663764, - 0.002944394713267684, - -0.004211106337606907, - 0.0029597999528050423, - 0.002045023487880826, - 0.0013397098518908024, - -0.0012190865818411112, - 0.34349915385246277, - 0.0005632104002870619, - -0.0001262281439267099, - -0.00515326950699091, - 0.016240738332271576, - 0.01709030382335186, - -0.004175194539129734, - 0.039775289595127106, - 0.015226684510707855, - -0.0010229480685666203, - 0.0008072761120274663, - -0.004935584031045437, - -0.002123525831848383, - -0.014274083077907562, - 0.0013746818294748664, - 0.0014838266652077436, - 0.1302703619003296, - -0.00033616088330745697, - 0.0012919505825266242, - 0.00037177055492065847, - 0.019514480605721474, - 0.00022255218937061727, - 0.124249167740345, - -0.00040352059295400977, - -0.007652895525097847, - 0.0013010123511776328, - -0.0011253133416175842, - -0.007449474185705185, - 0.19224143028259277, - -0.003275118535384536, - -0.0005017912480980158, - -0.001007912098430097, - 3.091096004936844e-05, - -0.0008595998515374959, - 0.012359987013041973, - -0.0004041247011628002, - -0.004328910261392593, - 0.3185553252696991, - 
0.002330605871975422, - 0.0021182901691645384, - 0.0001405928487656638, - 0.2779357433319092, - 0.005738262087106705, - 0.0058898297138512135, - -0.0009689796715974808, - 0.00912561360746622, - 0.020675739273428917, - -0.03700518235564232, - 0.014263041317462921, - -0.04828466475009918, - 0.05834139883518219, - 0.0006514795240946114, - 0.26360899209976196, - 0.0004918567719869316, - -0.00261044898070395, - 0.08374208211898804, - 0.020676210522651672, - -0.003743582172319293, - 0.01085072010755539, - -0.001096583902835846, - 0.00047430366976186633, - 0.04818058758974075, - -0.4799128472805023, - 0.00018429107149131596, - 0.011861988343298435, - 0.06088569387793541, - 0.0008461413672193885, - 0.005328264087438583, - -0.011493473313748837, - -0.11350836604833603, - 0.006329597905278206, - 0.00031669469899497926, - -0.0011600167490541935, - -0.022669579833745956, - 0.004070379305630922, - 0.0073160636238753796, - -0.00834545586258173, - -0.27817651629447937, - 0.0036344374530017376 - ], + "y": { + "bdata": "N8t4Om8VhDy9MvM6SWZhO7/yILwMMjU8pjGFuyrzzrpTWZ66qO96Owh+i7vOwJW6iSqNunCFxrmBogI2D4OIuSIpFLnoK3s7WrTguUR2u7pZuHG6kPY1Oo5c4zkyC2q7IisCuhpyUzoROw46EmgLug+hO7q0HYi6N8t4uuRPtbhTGZE5DCIKO5q4TrkI99a68RruPaAXgzmMUcC6eBzQuVBXmzwaQUW5QAY3PIdBrroz5kC69tueuYdwGblvlmi5Yk3YupGUmTlzKry6Mqj8PPKE87siKZS51SzFO5JXNLwV8aW6l1HDOubpPLmZ+0A7c/2JuxIQQjvpEAY7Z66vOk+1n7oc368++o8TOrp6A7nQ1ai7OAuFPE0BjDyKzIi7CuwiPVV6eTxj+oW6cs5TOt60obsgIQu73t1pvLgwtDpeecI6i2UFPokssLn3SKk6ji3DOSbdnzzfPWo5xHb+PXDF07mrvvq77pKqOm1ek7oqGvS7H9tEPjCfVrvEngO6ZiGEuvRSBDiDTWG6KYJKPOqu07k42o27qBmjPpa1GDsr0Qo7QZ4UOZFNjj6cDLw7CwTBO98dfroXhxU8yV6pPIaRF70ztGk8f8VFvYL3bj0awSo66/eGPt8uATo+Diu7U4GrPTFgqTwrTnW7tcgxPO+ej7q2xvg5MFhFPQC39b7pv0A5ilpCPL1jeT0Psl06PpiuO+1NPLzwdui9Wm3PO8/cpTnsF5i6xbW5vLxlhTtzvO87PboIvAhtjr5KMm47", + "dtype": "f4" + }, "yaxis": "y" } ], @@ -10677,7 +9020,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -10713,7 +9056,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -10797,7 +9140,7 @@ }, 
"colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -10833,7 +9176,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -10857,7 +9200,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -10893,64 +9236,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "heatmap" } ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], "histogram": [ { "marker": { @@ -10971,7 +9263,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -11007,7 +9299,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -11022,7 +9314,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -11058,7 +9350,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -11151,6 +9443,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -11203,7 +9506,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -11239,7 +9542,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -11330,7 +9633,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -11366,13 +9669,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -11408,7 +9711,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -11543,8 +9846,8 @@ "xaxis": { "anchor": "y", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "range": [ -0.5, @@ -11557,8 +9860,8 @@ "yaxis": { "anchor": "x", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "range": [ -0.5, @@ -11606,7 +9909,14 @@ { "cell_type": "code", "execution_count": 31, - "metadata": {}, + 
"metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:16.413943Z", + "iopub.status.busy": "2026-03-02T20:01:16.413886Z", + "iopub.status.idle": "2026-03-02T20:01:18.814428Z", + "shell.execute_reply": "2026-03-02T20:01:18.814149Z" + } + }, "outputs": [], "source": [ "def patch_head_pattern(\n", @@ -11615,9 +9925,11 @@ " head_index,\n", " clean_cache,\n", "):\n", - " corrupted_head_pattern[:, head_index, :, :] = clean_cache[hook.name][\n", + " clean_value = _cache_lookup(clean_cache, hook.name)[\n", " :, head_index, :, :\n", - " ]\n", + " ].clone()\n", + " corrupted_head_pattern = corrupted_head_pattern.clone()\n", + " corrupted_head_pattern[:, head_index : head_index + 1, :, :] = clean_value.unsqueeze(1)\n", " return corrupted_head_pattern\n", "\n", "\n", @@ -11642,7 +9954,14 @@ { "cell_type": "code", "execution_count": 32, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:18.815552Z", + "iopub.status.busy": "2026-03-02T20:01:18.815472Z", + "iopub.status.idle": "2026-03-02T20:01:18.835004Z", + "shell.execute_reply": "2026-03-02T20:01:18.834811Z" + } + }, "outputs": [ { "data": { @@ -11658,184 +9977,19 @@ "type": "heatmap", "xaxis": "x", "yaxis": "y", - "z": [ - [ - 0.0006401354330591857, - 0.005318799521774054, - 0.0011584057938307524, - -5.920405237702653e-05, - -0.00106671336106956, - 0.005079298280179501, - -0.0030818663071841, - -0.0020521720871329308, - -0.0014405983965843916, - 0.003492669900879264, - -0.002568227471783757, - -0.0009168237447738647 - ], - [ - -0.0007600873941555619, - 0.0001683824957581237, - 0.00012246915139257908, - -0.00034914951538667083, - 1.4901700524205808e-05, - 0.0050090523436665535, - -0.0002975976967718452, - -0.0014448943547904491, - -0.001099134678952396, - 0.00047447148244827986, - 5.195457561057992e-05, - -0.0034954219590872526 - ], - [ - -0.0007243098807521164, - 0.0017458146903663874, - -0.00015556166181340814, - 5.7626621128292754e-05, - -9.7398049547337e-05, - 
-0.0004238593974150717, - -0.0007917031762190163, - 0.00027222454082220793, - 0.00010179472155869007, - 0.0004223826399538666, - 0.00015193692524917424, - -0.0007437760941684246 - ], - [ - 0.11458104848861694, - 0.00021140948229003698, - -0.0009424989693798125, - 0.000429833511589095, - 0.02004295401275158, - 0.002104730810970068, - 7.628730963915586e-05, - -0.001543701975606382, - -0.0008484235731884837, - -0.0005819046637043357, - 0.00011921360419364646, - -1.899631206470076e-05 - ], - [ - -0.001127125695347786, - 0.001237143180333078, - -0.0012324444251134992, - -0.0005952289211563766, - -0.0007541133090853691, - -0.0005842540413141251, - 0.004813014063984156, - 0.00018187458044849336, - -0.0005361591465771198, - 0.0008579217828810215, - -0.0002985374303534627, - -1.144477391790133e-05 - ], - [ - -0.004241178277879953, - 0.0029509058222174644, - 0.0005218615406192839, - 0.0009535074350424111, - 0.0001622070267330855, - 0.34350839257240295, - -0.0003052163519896567, - 0.00010293584637111053, - -0.005300541408360004, - 0.024864863604307175, - 0.014383262023329735, - -0.0023285921197384596 - ], - [ - -0.0023893399629741907, - -0.002172795357182622, - -0.00047614958020858467, - 0.00043188079143874347, - -0.004675475414842367, - 0.0018583494238555431, - -0.0026542814448475838, - 0.0014367386465892196, - 0.00030326974228955805, - 0.13043038547039032, - 8.813483145786449e-05, - 0.0011766973184421659 - ], - [ - 0.00031847349600866437, - 0.02057075686752796, - 0.00031840638257563114, - -0.002512782346457243, - -0.0002628941729199141, - -0.00024718698114156723, - 0.0005524033331312239, - -0.00043131023994646966, - 0.00025715501396916807, - 0.008090951479971409, - -0.0030689111445099115, - -0.0004238593974150717 - ], - [ - 0.000976699055172503, - 0.00039251212729141116, - 0.0017534669023007154, - 0.022595642134547234, - -4.4805787183577195e-05, - 0.00014220383309293538, - 0.009584981948137283, - -0.0003157213795930147, - 0.0015271222218871117, - 0.0011813960736617446, - 
-0.010774029418826103, - 0.00936581939458847 - ], - [ - 0.006314125377684832, - -0.0010949057759717107, - 0.011662023141980171, - 0.0013481340138241649, - -0.02918696030974388, - 0.0038333951961249113, - -0.04409456625580788, - -0.005032042507082224, - 0.00482167350128293, - 0.2766477167606354, - -3.164933150401339e-05, - -0.0006618167390115559 - ], - [ - 0.0953889712691307, - 0.02506939135491848, - 0.014239178970456123, - 0.014754998497664928, - 9.890835644910112e-05, - -8.977938705356792e-05, - 0.05082912743091583, - -0.5051022171974182, - 0.00014696970174554735, - -0.0016026375815272331, - 0.06883199512958527, - 0.002327115274965763 - ], - [ - 0.0013425961369648576, - 0.009630928747355938, - -0.07776415348052979, - -0.007728713098913431, - -0.0005726079107262194, - -0.002957182005047798, - -0.0049475994892418385, - 0.00045916702947579324, - -0.0006328188464976847, - -0.006520198658108711, - -0.3204910457134247, - -0.002473111730068922 - ] - ] + "z": { + "bdata": "wIgnOklIrjta1Jc6HEp6uE3Mi7qXcqY7+PRJu9d/Brti3Ly61udkOzE/KLuoDnC6plRHusIEMTlzW/84WQm3uTqDfjfaHqQ7KdWbuSNXvbovFZC6l1H4ObSsVji/CmW7Jt09ujn05DrfTiK5/XZuONLDy7gyOd65A1RPuuuMjzl71NU4eFzdOfHJHjmECUO6UKnqPa8rXjk68na6g03hOfsypDzP7gk7VUSgOHFSyrqPp166CYQYumRY+zhbI6S3Z8CTul5Kojone6G6OhQcurixRbo0Jhm697CdOxqwPTndgwy6cAVhOmzRnLmLtUq3hvyKu72CQTsm3Qg6xvZ5OhrBKjly4K8+NMafuTml2DjYrq27UrXLPH+oazwdkhi7t4Mcu1thDrv2ufm5D2HjOa0smbtYmvM6AvEtu19WvDrz0p45dY8FPiMHujjDUZo6qS6nOS+DqDwQ0KY5fqoku092ibkpVYG5iO4QOjYP4rkxH4c5M48EPGAiSbsmDN65m/R/OmsEzjkZ5OU6yBq5PGHxOrhS3RQ5iwwdPKdDpbljOsg6b/iaOtCDMLzschk8B+nOO6Nvj7ogEz88qqqwOi8X77y7N3s7AJs0vWzfpLv9BZ472qSNPm5aArhTey26PFvDPXNezTwzSmk8+r1xPMVazzhvJ7u43jFQPTdOAb+3JBo5NPPRuqr3jD3pghg7uO+vOpbKHTxWQp+9/UH9u6ghFrpPxkG7IyCiuw2Y8DlaBSa6fqPVu0YXpL6uDiK7", + "dtype": "f4", + "shape": "12, 12" + } } ], "layout": { "coloraxis": { - "cmid": 0, + "cmid": 0.0, "colorscale": [ [ - 0, + 0.0, "rgb(103,0,31)" ], [ @@ -11875,7 +10029,7 @@ "rgb(33,102,172)" ], [ - 1, + 1.0, "rgb(5,48,97)" ] ] @@ -11956,7 
+10110,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -11992,7 +10146,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -12016,7 +10170,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -12052,64 +10206,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "heatmap" } ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], "histogram": [ { "marker": { @@ -12130,7 +10233,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -12166,7 +10269,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -12181,7 +10284,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -12217,7 +10320,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -12310,6 +10413,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -12362,7 +10476,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -12398,7 +10512,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -12489,7 +10603,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -12525,13 +10639,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -12567,7 +10681,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -12703,8 +10817,8 @@ "anchor": "y", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "scaleanchor": "y", "title": { @@ -12716,8 +10830,8 @@ "autorange": "reversed", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "Layer" @@ 
-12894,299 +11008,15 @@ "orientation": "v", "showlegend": false, "type": "scatter", - "x": [ - 0.0006401354330591857, - 0.005318799521774054, - 0.0011584057938307524, - -5.920405237702653e-05, - -0.00106671336106956, - 0.005079298280179501, - -0.0030818663071841, - -0.0020521720871329308, - -0.0014405983965843916, - 0.003492669900879264, - -0.002568227471783757, - -0.0009168237447738647, - -0.0007600873941555619, - 0.0001683824957581237, - 0.00012246915139257908, - -0.00034914951538667083, - 1.4901700524205808e-05, - 0.0050090523436665535, - -0.0002975976967718452, - -0.0014448943547904491, - -0.001099134678952396, - 0.00047447148244827986, - 5.195457561057992e-05, - -0.0034954219590872526, - -0.0007243098807521164, - 0.0017458146903663874, - -0.00015556166181340814, - 5.7626621128292754e-05, - -9.7398049547337e-05, - -0.0004238593974150717, - -0.0007917031762190163, - 0.00027222454082220793, - 0.00010179472155869007, - 0.0004223826399538666, - 0.00015193692524917424, - -0.0007437760941684246, - 0.11458104848861694, - 0.00021140948229003698, - -0.0009424989693798125, - 0.000429833511589095, - 0.02004295401275158, - 0.002104730810970068, - 7.628730963915586e-05, - -0.001543701975606382, - -0.0008484235731884837, - -0.0005819046637043357, - 0.00011921360419364646, - -1.899631206470076e-05, - -0.001127125695347786, - 0.001237143180333078, - -0.0012324444251134992, - -0.0005952289211563766, - -0.0007541133090853691, - -0.0005842540413141251, - 0.004813014063984156, - 0.00018187458044849336, - -0.0005361591465771198, - 0.0008579217828810215, - -0.0002985374303534627, - -1.144477391790133e-05, - -0.004241178277879953, - 0.0029509058222174644, - 0.0005218615406192839, - 0.0009535074350424111, - 0.0001622070267330855, - 0.34350839257240295, - -0.0003052163519896567, - 0.00010293584637111053, - -0.005300541408360004, - 0.024864863604307175, - 0.014383262023329735, - -0.0023285921197384596, - -0.0023893399629741907, - -0.002172795357182622, - -0.00047614958020858467, - 
0.00043188079143874347, - -0.004675475414842367, - 0.0018583494238555431, - -0.0026542814448475838, - 0.0014367386465892196, - 0.00030326974228955805, - 0.13043038547039032, - 8.813483145786449e-05, - 0.0011766973184421659, - 0.00031847349600866437, - 0.02057075686752796, - 0.00031840638257563114, - -0.002512782346457243, - -0.0002628941729199141, - -0.00024718698114156723, - 0.0005524033331312239, - -0.00043131023994646966, - 0.00025715501396916807, - 0.008090951479971409, - -0.0030689111445099115, - -0.0004238593974150717, - 0.000976699055172503, - 0.00039251212729141116, - 0.0017534669023007154, - 0.022595642134547234, - -4.4805787183577195e-05, - 0.00014220383309293538, - 0.009584981948137283, - -0.0003157213795930147, - 0.0015271222218871117, - 0.0011813960736617446, - -0.010774029418826103, - 0.00936581939458847, - 0.006314125377684832, - -0.0010949057759717107, - 0.011662023141980171, - 0.0013481340138241649, - -0.02918696030974388, - 0.0038333951961249113, - -0.04409456625580788, - -0.005032042507082224, - 0.00482167350128293, - 0.2766477167606354, - -3.164933150401339e-05, - -0.0006618167390115559, - 0.0953889712691307, - 0.02506939135491848, - 0.014239178970456123, - 0.014754998497664928, - 9.890835644910112e-05, - -8.977938705356792e-05, - 0.05082912743091583, - -0.5051022171974182, - 0.00014696970174554735, - -0.0016026375815272331, - 0.06883199512958527, - 0.002327115274965763, - 0.0013425961369648576, - 0.009630928747355938, - -0.07776415348052979, - -0.007728713098913431, - -0.0005726079107262194, - -0.002957182005047798, - -0.0049475994892418385, - 0.00045916702947579324, - -0.0006328188464976847, - -0.006520198658108711, - -0.3204910457134247, - -0.002473111730068922 - ], + "x": { + "bdata": 
"wIgnOklIrjta1Jc6HEp6uE3Mi7qXcqY7+PRJu9d/Brti3Ly61udkOzE/KLuoDnC6plRHusIEMTlzW/84WQm3uTqDfjfaHqQ7KdWbuSNXvbovFZC6l1H4ObSsVji/CmW7Jt09ujn05DrfTiK5/XZuONLDy7gyOd65A1RPuuuMjzl71NU4eFzdOfHJHjmECUO6UKnqPa8rXjk68na6g03hOfsypDzP7gk7VUSgOHFSyrqPp166CYQYumRY+zhbI6S3Z8CTul5Kojone6G6OhQcurixRbo0Jhm697CdOxqwPTndgwy6cAVhOmzRnLmLtUq3hvyKu72CQTsm3Qg6xvZ5OhrBKjly4K8+NMafuTml2DjYrq27UrXLPH+oazwdkhi7t4Mcu1thDrv2ufm5D2HjOa0smbtYmvM6AvEtu19WvDrz0p45dY8FPiMHujjDUZo6qS6nOS+DqDwQ0KY5fqoku092ibkpVYG5iO4QOjYP4rkxH4c5M48EPGAiSbsmDN65m/R/OmsEzjkZ5OU6yBq5PGHxOrhS3RQ5iwwdPKdDpbljOsg6b/iaOtCDMLzschk8B+nOO6Nvj7ogEz88qqqwOi8X77y7N3s7AJs0vWzfpLv9BZ472qSNPm5aArhTey26PFvDPXNezTwzSmk8+r1xPMVazzhvJ7u43jFQPTdOAb+3JBo5NPPRuqr3jD3pghg7uO+vOpbKHTxWQp+9/UH9u6ghFrpPxkG7IyCiuw2Y8DlaBSa6fqPVu0YXpL6uDiK7", + "dtype": "f4" + }, "xaxis": "x", - "y": [ - 0.0009487751522101462, - 0.016124747693538666, - 0.0018548924708738923, - 0.0034389030188322067, - -0.00982347596436739, - 0.011058605276048183, - -0.004063969012349844, - -0.0015792781487107277, - -0.0012082795146852732, - 0.003828897839412093, - -0.004256919026374817, - -0.0011422622483223677, - -0.0010771177476271987, - -0.00037898647133260965, - 2.5171791548928013e-06, - -0.00026067905128002167, - -0.00014146546891424805, - 0.0038321535103023052, - -0.0004293300735298544, - -0.00142992555629462, - -0.0009228314156644046, - 0.0006944393389858305, - 0.00043302192352712154, - -0.0035714071709662676, - -0.0004967569257132709, - 0.0008057993836700916, - 0.0005424688570201397, - -0.0005309234256856143, - -0.0007159864180721343, - -0.0010389237431809306, - -0.0009490771917626262, - -8.649027586216107e-05, - 0.0002766547549981624, - 0.0021084228064864874, - -0.0001975146442418918, - -0.0016405630158260465, - 0.1162627637386322, - 0.0002507446042727679, - -0.0014675153652206063, - -0.00039680811460129917, - 0.018962211906909943, - -0.00018764731066767126, - 0.011170871555805206, - -0.0013301445869728923, - -0.0007356539717875421, - -0.00030253134900704026, - 
-0.00014683544577565044, - -0.00022228369198273867, - -0.001650598249398172, - 0.0002927311579696834, - -0.00143563118763268, - 0.03084198758006096, - -0.007432155776768923, - -0.00028236035723239183, - 0.006017433945089579, - -0.011007187888026237, - -0.001266107545234263, - 0.0014901700196787715, - -0.0001800622121663764, - 0.002944394713267684, - -0.004211106337606907, - 0.0029597999528050423, - 0.002045023487880826, - 0.0013397098518908024, - -0.0012190865818411112, - 0.34349915385246277, - 0.0005632104002870619, - -0.0001262281439267099, - -0.00515326950699091, - 0.016240738332271576, - 0.01709030382335186, - -0.004175194539129734, - 0.039775289595127106, - 0.015226684510707855, - -0.0010229480685666203, - 0.0008072761120274663, - -0.004935584031045437, - -0.002123525831848383, - -0.014274083077907562, - 0.0013746818294748664, - 0.0014838266652077436, - 0.1302703619003296, - -0.00033616088330745697, - 0.0012919505825266242, - 0.00037177055492065847, - 0.019514480605721474, - 0.00022255218937061727, - 0.124249167740345, - -0.00040352059295400977, - -0.007652895525097847, - 0.0013010123511776328, - -0.0011253133416175842, - -0.007449474185705185, - 0.19224143028259277, - -0.003275118535384536, - -0.0005017912480980158, - -0.001007912098430097, - 3.091096004936844e-05, - -0.0008595998515374959, - 0.012359987013041973, - -0.0004041247011628002, - -0.004328910261392593, - 0.3185553252696991, - 0.002330605871975422, - 0.0021182901691645384, - 0.0001405928487656638, - 0.2779357433319092, - 0.005738262087106705, - 0.0058898297138512135, - -0.0009689796715974808, - 0.00912561360746622, - 0.020675739273428917, - -0.03700518235564232, - 0.014263041317462921, - -0.04828466475009918, - 0.05834139883518219, - 0.0006514795240946114, - 0.26360899209976196, - 0.0004918567719869316, - -0.00261044898070395, - 0.08374208211898804, - 0.020676210522651672, - -0.003743582172319293, - 0.01085072010755539, - -0.001096583902835846, - 0.00047430366976186633, - 0.04818058758974075, - 
-0.4799128472805023, - 0.00018429107149131596, - 0.011861988343298435, - 0.06088569387793541, - 0.0008461413672193885, - 0.005328264087438583, - -0.011493473313748837, - -0.11350836604833603, - 0.006329597905278206, - 0.00031669469899497926, - -0.0011600167490541935, - -0.022669579833745956, - 0.004070379305630922, - 0.0073160636238753796, - -0.00834545586258173, - -0.27817651629447937, - 0.0036344374530017376 - ], + "y": { + "bdata": "N8t4Om8VhDy9MvM6SWZhO7/yILwMMjU8pjGFuyrzzrpTWZ66qO96Owh+i7vOwJW6iSqNunCFxrmBogI2D4OIuSIpFLnoK3s7WrTguUR2u7pZuHG6kPY1Oo5c4zkyC2q7IisCuhpyUzoROw46EmgLug+hO7q0HYi6N8t4uuRPtbhTGZE5DCIKO5q4TrkI99a68RruPaAXgzmMUcC6eBzQuVBXmzwaQUW5QAY3PIdBrroz5kC69tueuYdwGblvlmi5Yk3YupGUmTlzKry6Mqj8PPKE87siKZS51SzFO5JXNLwV8aW6l1HDOubpPLmZ+0A7c/2JuxIQQjvpEAY7Z66vOk+1n7oc368++o8TOrp6A7nQ1ai7OAuFPE0BjDyKzIi7CuwiPVV6eTxj+oW6cs5TOt60obsgIQu73t1pvLgwtDpeecI6i2UFPokssLn3SKk6ji3DOSbdnzzfPWo5xHb+PXDF07mrvvq77pKqOm1ek7oqGvS7H9tEPjCfVrvEngO6ZiGEuvRSBDiDTWG6KYJKPOqu07k42o27qBmjPpa1GDsr0Qo7QZ4UOZFNjj6cDLw7CwTBO98dfroXhxU8yV6pPIaRF70ztGk8f8VFvYL3bj0awSo66/eGPt8uATo+Diu7U4GrPTFgqTwrTnW7tcgxPO+ej7q2xvg5MFhFPQC39b7pv0A5ilpCPL1jeT0Psl06PpiuO+1NPLzwdui9Wm3PO8/cpTnsF5i6xbW5vLxlhTtzvO87PboIvAhtjr5KMm47", + "dtype": "f4" + }, "yaxis": "y" } ], @@ -13243,86 +11073,26 @@ "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - 
], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" } ], - "contourcarpet": [ + "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, - "type": "contourcarpet" + "type": "choropleth" } ], - "heatmap": [ + "contour": [ { "colorbar": { "outlinewidth": 0, @@ -13330,7 +11100,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -13366,14 +11136,23 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], - "type": "heatmap" + "type": "contour" } ], - "heatmapgl": [ + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ { "colorbar": { "outlinewidth": 0, @@ -13381,7 +11160,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -13417,11 +11196,11 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], - "type": "heatmapgl" + "type": "heatmap" } ], "histogram": [ @@ -13444,7 +11223,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -13480,7 +11259,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -13495,7 +11274,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -13531,7 +11310,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -13624,6 +11403,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -13676,7 +11466,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -13712,7 +11502,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -13803,7 +11593,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -13839,13 +11629,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -13881,7 +11671,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -14016,8 +11806,8 @@ 
"xaxis": { "anchor": "y", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "Attention Patch" @@ -14026,8 +11816,8 @@ "yaxis": { "anchor": "x", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "Output Patch" @@ -14088,34 +11878,41 @@ { "cell_type": "code", "execution_count": 33, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:18.835948Z", + "iopub.status.busy": "2026-03-02T20:01:18.835888Z", + "iopub.status.idle": "2026-03-02T20:01:19.756609Z", + "shell.execute_reply": "2026-03-02T20:01:19.756394Z" + } + }, "outputs": [ { "data": { "text/html": [ - "

Top Early Heads


\n", + "

Top Early Heads


\n", "

Top Middle Heads


\n", + "

Top Middle Heads


\n", "

Top Late Heads


\n", + "

Top Late Heads


\n", "
" ], @@ -14226,7 +12023,14 @@ { "cell_type": "code", "execution_count": 34, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:19.757694Z", + "iopub.status.busy": "2026-03-02T20:01:19.757623Z", + "iopub.status.idle": "2026-03-02T20:01:19.927861Z", + "shell.execute_reply": "2026-03-02T20:01:19.927582Z" + } + }, "outputs": [], "source": [ "example_text = \"Research in mechanistic interpretability seeks to explain behaviors of machine learning models in terms of their internal components.\"\n", @@ -14241,18 +12045,25 @@ { "cell_type": "code", "execution_count": 35, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:19.928921Z", + "iopub.status.busy": "2026-03-02T20:01:19.928853Z", + "iopub.status.idle": "2026-03-02T20:01:19.933411Z", + "shell.execute_reply": "2026-03-02T20:01:19.933218Z" + } + }, "outputs": [ { "data": { "text/html": [ - "

Induction Heads


\n", + "

Induction Heads


\n", "
" ], @@ -14337,7 +12148,14 @@ { "cell_type": "code", "execution_count": 36, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:19.934342Z", + "iopub.status.busy": "2026-03-02T20:01:19.934270Z", + "iopub.status.idle": "2026-03-02T20:01:20.784258Z", + "shell.execute_reply": "2026-03-02T20:01:20.784016Z" + } + }, "outputs": [ { "name": "stdout", @@ -14431,7 +12249,14 @@ { "cell_type": "code", "execution_count": 37, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:20.785307Z", + "iopub.status.busy": "2026-03-02T20:01:20.785212Z", + "iopub.status.idle": "2026-03-02T20:01:20.812155Z", + "shell.execute_reply": "2026-03-02T20:01:20.811960Z" + } + }, "outputs": [ { "data": { @@ -14447,184 +12272,19 @@ "type": "heatmap", "xaxis": "x", "yaxis": "y", - "z": [ - [ - 0.039069853723049164, - 0.0004489101702347398, - 0.03133601322770119, - 0.007519590202718973, - 0.034592196345329285, - 0.00036230171099305153, - 0.034512776881456375, - 0.19740213453769684, - 0.038447845727205276, - 0.04053792357444763, - 0.027628764510154724, - 0.02496313862502575 - ], - [ - 0.1890650987625122, - 0.17219914495944977, - 0.06807752698659897, - 0.04494515433907509, - 0.07908554375171661, - 0.03096739575266838, - 0.028282109647989273, - 0.03644327446818352, - 0.026936717331409454, - 0.018826229497790337, - 0.045100897550582886, - 0.0065726665779948235 - ], - [ - 0.15745528042316437, - 0.020724520087242126, - 0.4817989468574524, - 0.2991352379322052, - 0.10764895379543304, - 0.33004048466682434, - 0.0997551754117012, - 0.04926132410764694, - 0.25493940711021423, - 0.3606453835964203, - 0.1257179230451584, - 0.07931824028491974 - ], - [ - 0.005844001192599535, - 0.15787364542484283, - 0.4189082086086273, - 0.30129021406173706, - 0.014345049858093262, - 0.032344333827495575, - 0.3312888443470001, - 0.5285974144935608, - 0.34242063760757446, - 0.101837158203125, - 0.10516070574522018, - 0.2233113795518875 - ], - [ - 
0.10626544803380966, - 0.11930850893259048, - 0.022880680859088898, - 0.22826944291591644, - 0.020003994926810265, - 0.10010036826133728, - 0.1739213615655899, - 0.17407020926475525, - 0.02587701380252838, - 0.10249985754489899, - 0.009514841251075268, - 0.9921423196792603 - ], - [ - 0.019766658544540405, - 0.00528325280174613, - 0.16648508608341217, - 0.12087740004062653, - 0.16500000655651093, - 0.00803269725292921, - 0.41770195960998535, - 0.025827765464782715, - 0.04802601411938667, - 0.016231779009103775, - 0.03110172413289547, - 0.024261215701699257 - ], - [ - 0.2172909826040268, - 0.039100028574466705, - 0.01804858259856701, - 0.059900715947151184, - 0.032934583723545074, - 0.0873451679944992, - 0.026895340532064438, - 0.0943947583436966, - 0.49925994873046875, - 0.006240115500986576, - 0.027026718482375145, - 0.1278565675020218 - ], - [ - 0.2511657178401947, - 0.01330868061631918, - 0.006663354113698006, - 0.037430502474308014, - 0.02331537753343582, - 0.01740722358226776, - 0.022067422047257423, - 0.022141192108392715, - 0.04502448812127113, - 0.0208425372838974, - 0.008310739882290363, - 0.017167754471302032 - ], - [ - 0.020890623331069946, - 0.016537941992282867, - 0.02158307284116745, - 0.0150058064609766, - 0.02421221323311329, - 0.10198988765478134, - 0.029100384563207626, - 0.22793792188167572, - 0.02781485579907894, - 0.0179410632699728, - 0.024828944355249405, - 0.03806235268712044 - ], - [ - 0.02607586607336998, - 0.015407431870698929, - 0.02044427953660488, - 0.14558182656764984, - 0.01247025839984417, - 0.017151640728116035, - 0.013311829417943954, - 0.024451706558465958, - 0.018111787736415863, - 0.01319331955164671, - 0.0357399508357048, - 0.01879822090268135 - ], - [ - 0.02147812582552433, - 0.018419174477458, - 0.018183622509241104, - 0.02172141708433628, - 0.0315677747130394, - 0.034705750644207, - 0.017550116404891014, - 0.011417553760111332, - 0.01579565554857254, - 0.04592214897274971, - 0.01621554046869278, - 0.03039470687508583 - ], - 
[ - 0.03320508822798729, - 0.0175714660435915, - 0.015131079591810703, - 0.04148406535387039, - 0.015181189402937889, - 0.01758997142314911, - 0.015148494392633438, - 0.01767607219517231, - 0.06622709333896637, - 0.018451133742928505, - 0.01700744964182377, - 0.029749270528554916 - ] - ] + "z": { + "bdata": "tAcgPdRb6zkvWgA952b2O4ywDT0y8705RF0NPckjSj6Aex09FgsmPb1V4jx+f8w8QJpBPvRUMD4zbIs9Zxg4PZ73oT1Rr/084K/nPIpFFT1cqtw8dTmaPLO7OD2GX9c76TshPnTGqTxprvY+PiiZPv123D0m+6g+cUzMPUDGST1Ah4I+h6a4Piy8AD6qcaI9AX+/O5SpIT4le9Y+JkOaPncHazx4ewQ9t56pPi1SBz9rUa8+wo/QPVBe1z3Vq2Q+x6HZPaFX9D1KcLs8sL9pPoPfozzZAc09mhgyPng/Mj7Q+9M8JevRPQ/kGzwN/X0/vu2hPB0frTvJeio+Z473PY/1KD5ymwM8otzVPq+U0zyntkQ9eviEPPnI/jxOv8Y8o4FePrMmID2j2pM8n1p1PWXmBj0G4rI9q1PcPD9SwT0Jn/8+9HnMOzFn3Tz27AI+3piAPqwMWjxeWNo7EVEZPWkAvzyhmY48r8a0PHthtTxwazg9RL6qPNspCDx0o4w80iKrPJh6hzwUz7A8sdp1PDVZxjwX4NA93WPuPL5oaT4K3OM8MvmSPIRmyzz75hs9y5zVPCVvfDyxeqc8iBMVPvhPTDyEgYw8uhlaPAtPyDwCX5Q8qyhYPAhkEj21/pk8y/KvPLXjljzM9ZQ8FvGxPCpNAT2EJw49KsWPPHsQOzzLZYE8KBk8PWPWhDwt/vg8CgIIPQTyjzwt6Hc8KespPX66eDzQGJA8QzF4PF7NkDz9oYc9xSaXPDFTizzBtPM8", + "dtype": "f4", + "shape": "12, 12" + } } ], "layout": { "coloraxis": { - "cmid": 0, + "cmid": 0.0, "colorscale": [ [ - 0, + 0.0, "rgb(103,0,31)" ], [ @@ -14664,7 +12324,7 @@ "rgb(33,102,172)" ], [ - 1, + 1.0, "rgb(5,48,97)" ] ] @@ -14745,7 +12405,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -14781,7 +12441,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -14805,7 +12465,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -14841,64 +12501,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "heatmap" } ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 
0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], "histogram": [ { "marker": { @@ -14919,7 +12528,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -14955,7 +12564,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -14970,7 +12579,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -15006,7 +12615,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -15099,6 +12708,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -15151,7 +12771,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -15187,7 +12807,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -15278,7 +12898,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -15314,13 +12934,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -15356,7 +12976,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -15492,8 +13112,8 @@ "anchor": "y", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "scaleanchor": "y", "title": { @@ -15505,8 +13125,8 @@ "autorange": "reversed", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "Layer" @@ -15532,184 +13152,19 @@ "type": "heatmap", "xaxis": "x", "yaxis": "y", - "z": [ - [ - 0.0031923248898237944, - 0.13236315548419952, - 0.005006915424019098, - 1.0427449524286203e-05, - 0.0013110184809193015, - 0.7034568786621094, - 0.00426204688847065, - 0.00016496369789820164, - 0.002474633976817131, - 0.0008572910446673632, - 0.01889149099588394, - 0.008690938353538513 - ], - [ - 0.0002916341181844473, - 0.00013782267342321575, - 0.0015036173863336444, - 0.005392482969909906, - 0.0018583914497867227, - 0.009062949568033218, - 0.012414448894560337, - 0.0022405502386391163, - 0.005135662388056517, - 0.005220627877861261, - 0.005546474829316139, - 0.02975049614906311 - ], - 
[ - 0.0024816279765218496, - 0.009442180395126343, - 0.0003456332196947187, - 0.0002591445227153599, - 0.0052116685546934605, - 0.000570951378904283, - 0.0015209749108180404, - 0.006313100922852755, - 0.001560864970088005, - 0.0004215767839923501, - 0.00015359291865024716, - 0.005160381551831961 - ], - [ - 0.6775657534599304, - 0.002840448170900345, - 0.0007841526530683041, - 0.00471264636144042, - 0.006322895642369986, - 0.006206681486219168, - 0.0005474805948324502, - 0.00037829449865967035, - 0.0020155368838459253, - 0.007952751591801643, - 0.003576782764866948, - 0.002608788898214698 - ], - [ - 0.00860405620187521, - 0.0070286463014781475, - 0.007598803844302893, - 0.003442801535129547, - 0.016561277210712433, - 0.0059797209687530994, - 0.004869826138019562, - 0.0007624455611221492, - 0.006062133703380823, - 0.007536627352237701, - 0.012022900395095348, - 1.055422134237094e-12 - ], - [ - 0.00950299296528101, - 0.00856209360063076, - 0.004162600729614496, - 0.003008665982633829, - 0.006847422569990158, - 0.004358117934316397, - 0.007669268175959587, - 0.009584215469658375, - 0.0076188258826732635, - 0.0043280418030917645, - 0.041402824223041534, - 0.00976183544844389 - ], - [ - 0.004456141032278538, - 0.008873268961906433, - 0.007405205629765987, - 0.0062249391339719296, - 0.00731915095821023, - 0.005623893812298775, - 0.017349667847156525, - 0.005529467947781086, - 0.002920132130384445, - 0.008636755868792534, - 0.006222263444215059, - 0.00835894700139761 - ], - [ - 0.003699858672916889, - 0.04107949137687683, - 0.04148268699645996, - 0.009313640184700489, - 0.009097025729715824, - 0.008774377405643463, - 0.007298537530004978, - 0.023312218487262726, - 0.008843323215842247, - 0.00987986009567976, - 0.017598601058125496, - 0.006039854139089584 - ], - [ - 0.008986304514110088, - 0.028667239472270012, - 0.008891218341886997, - 0.010114557109773159, - 0.009737391024827957, - 0.007611637003719807, - 0.009763265959918499, - 0.005155472084879875, - 
0.009276345372200012, - 0.011895839124917984, - 0.010411946102976799, - 0.007498950231820345 - ], - [ - 0.024409977719187737, - 0.011438451707363129, - 0.02003096230328083, - 0.0051185814663767815, - 0.015081286430358887, - 0.012334450148046017, - 0.015452565625309944, - 0.008602450601756573, - 0.014702522195875645, - 0.020766200497746468, - 0.009192758239805698, - 0.005703347735106945 - ], - [ - 0.017897022888064384, - 0.013280633836984634, - 0.006755237001925707, - 0.012744844891130924, - 0.008020960725843906, - 0.007722244597971439, - 0.017341373488307, - 0.0074546560645103455, - 0.007832515984773636, - 0.00825214572250843, - 0.013642766512930393, - 0.012807483784854412 - ], - [ - 0.004923742264509201, - 0.007951060310006142, - 0.007947920821607113, - 0.004564082249999046, - 0.010363400913774967, - 0.009582078084349632, - 0.0102877551689744, - 0.00832072552293539, - 0.0025700009427964687, - 0.012810997664928436, - 0.008063871413469315, - 0.006558285094797611 - ] - ] + "z": { + "bdata": "VDZROzCKBz4QEaQ7WfEuN67WqzrAFTQ/paiLO8D6LDlwLSI73LtgOkzCmjxvZA48neaYOcCDEDkhFcU6hLOwOyWV8zrTfBQ8AmZLPDbWEjsUSag7xhGrOym/tTt1t/M8caMiO2GzGjzyNLU5y96HOcrGqjtErBU6m1vHOgXezjtCl8w6pwfdORAOITliGKk79nQtPyAnOjuCj00672uaO2YwzzuTYcs7wIMPOpRVxjkbGQQ7UEwCPGNpajs8+So75fcMPGZQ5jtl//g7259hO5Wrhzz38MM70pKfO7zeRzoKpcY78PX2O7L7RDxtiZQrYLIbPAtIDDxoZog7SS1FO5Bg4Dunzo47l077OxYHHTxYp/k7JtKNOwKWKT0X8B88ngSSO0xhETx1p/I7VPrLO1fV7zvNSLg71SCOPAcwtTs8YD87wIANPN7jyzvv8wg8G3pyO+9CKD186Sk9MpgYPKkLFTxkwg88TijvOzL5vjyP4xA8Id8hPJ8qkDzE6cU7WzsTPGrX6jxdrBE8nLclPI+JHzwna/k7/vUfPJfvqDvZ+xc8x+ZCPOqWKjzJufU7dffHPDJoOzwCGKQ867mnO2YXdzxzFko8nix9PFPxDDzX4nA81B2qPEKdFjz14ro7xpySPAyXWTznWt07q89QPEVqAzz3Cv07dg+OPJ1G9DviUwA8/jMHPOaFXzxh1lE8ZFehOwVFAjzqNwI8SI6VOzjLKTwG/hw89o0oPKVTCDyibSg7IeVRPGseBDwL59Y7", + "dtype": "f4", + "shape": "12, 12" + } } ], "layout": { "coloraxis": { - "cmid": 0, + "cmid": 0.0, "colorscale": [ [ - 0, + 0.0, "rgb(103,0,31)" ], [ @@ -15749,7 +13204,7 @@ "rgb(33,102,172)" ], [ - 1, + 1.0, 
"rgb(5,48,97)" ] ] @@ -15830,7 +13285,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -15866,7 +13321,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -15890,7 +13345,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -15926,64 +13381,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "heatmap" } ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], "histogram": [ { "marker": { @@ -16004,7 +13408,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -16040,7 +13444,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -16055,7 +13459,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -16091,7 +13495,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -16184,6 +13588,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -16236,7 +13651,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -16272,7 +13687,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -16363,7 +13778,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -16399,13 +13814,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -16441,7 +13856,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -16577,8 +13992,8 @@ "anchor": "y", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "scaleanchor": "y", "title": { @@ -16590,8 +14005,8 @@ "autorange": "reversed", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], 
"title": { "text": "Layer" @@ -16613,188 +14028,23 @@ { "coloraxis": "coloraxis", "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", - "name": "0", - "type": "heatmap", - "xaxis": "x", - "yaxis": "y", - "z": [ - [ - 0.004035575315356255, - 3.85937346436549e-05, - 0.003946058917790651, - 1.7428524756724073e-07, - 5.9896130551351234e-05, - 4.0836803236743435e-05, - 0.0035017586778849363, - 0.00024610417312942445, - 0.0031679815147072077, - 0.0030104012694209814, - 0.002093541668727994, - 0.008525434881448746 - ], - [ - 0.000526473973877728, - 0.00015670718858018517, - 0.001507942914031446, - 0.005595325026661158, - 0.0018401180859655142, - 0.0038875630125403404, - 0.005349153187125921, - 0.004649169277399778, - 0.005880181211978197, - 0.007283917628228664, - 0.005552186165004969, - 0.00012677280756179243 - ], - [ - 0.0022015420254319906, - 0.008784863166511059, - 0.002159146359190345, - 0.0010447809472680092, - 0.005142326466739178, - 0.002251626690849662, - 0.0008376616751775146, - 0.006352409720420837, - 0.002618127502501011, - 0.0010309136705473065, - 0.00015219187480397522, - 0.005351166240870953 - ], - [ - 0.007752244360744953, - 0.0030915802344679832, - 0.001362923881970346, - 0.004341960418969393, - 0.011233060620725155, - 0.006535551976412535, - 0.000906877510715276, - 0.0006078600417822599, - 0.002819513902068138, - 0.005254077725112438, - 0.004195652436465025, - 0.00255418848246336 - ], - [ - 0.007342735771089792, - 0.004788339603692293, - 0.007458819076418877, - 0.0033073313534259796, - 0.007871866226196289, - 0.004219769034534693, - 0.004172054585069418, - 0.0005154653917998075, - 0.008124975487589836, - 0.0068268910981714725, - 0.008085492067039013, - 3.761376626831847e-11 - ], - [ - 0.4337766170501709, - 0.9306095838546753, - 0.006382268853485584, - 0.0034730439074337482, - 0.005500996019691229, - 0.9255973696708679, - 0.00538142304867506, - 0.007857315242290497, - 0.00863779615610838, - 0.01576443389058113, - 0.012188379652798176, - 0.008265726268291473 - ], - [ - 0.002507298020645976, - 0.008432027883827686, - 0.008623305708169937, - 0.007653353735804558, - 
0.01105806790292263, - 0.005525435321033001, - 0.017205175012350082, - 0.004794349893927574, - 0.0040976013988256454, - 0.9257788062095642, - 0.020375633612275124, - 0.006313954945653677 - ], - [ - 0.005555536597967148, - 0.18942977488040924, - 0.8509925007820129, - 0.008273146115243435, - 0.008239664137363434, - 0.00864996388554573, - 0.02832852303981781, - 0.08996275067329407, - 0.006617339327931404, - 0.009413909167051315, - 0.9037814736366272, - 0.03037159889936447 - ], - [ - 0.00735454261302948, - 0.3791317641735077, - 0.005602709017693996, - 0.025401461869478226, - 0.008504674769937992, - 0.00623108958825469, - 0.11892436444759369, - 0.005114651285111904, - 0.013350939378142357, - 0.01576736941933632, - 0.025843923911452293, - 0.008429747074842453 - ], - [ - 0.2398916333913803, - 0.14378757774829865, - 0.09330663084983826, - 0.005819779820740223, - 0.07744801044464111, - 0.01644793339073658, - 0.4442836344242096, - 0.011141352355480194, - 0.03619001433253288, - 0.472646564245224, - 0.00803996529430151, - 0.030953049659729004 - ], - [ - 0.3606555163860321, - 0.48201146721839905, - 0.022851115092635155, - 0.1264195442199707, - 0.04125598818063736, - 0.0072374604642391205, - 0.2877156138420105, - 0.3897320628166199, - 0.030060900375247, - 0.006112942937761545, - 0.1655488908290863, - 0.22245149314403534 - ], - [ - 0.007408542558550835, - 0.033737149089574814, - 0.02041277289390564, - 0.002755412133410573, - 0.02518630214035511, - 0.07808877527713776, - 0.033082809299230576, - 0.046440087258815765, - 0.0032543439883738756, - 0.2744256258010864, - 0.3800230026245117, - 0.009483495727181435 - ] - ] + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": { + "bdata": 
"4DyEO4zfITj3TYE75SE7NEI5ezg6SCs4w31lOwkIgTngnU87MEpFO8gzCTtDrgs8TAMKOhlRJDlHpsU6HFm3OwQw8TqExn478UevOxJYmDuLrsA7663uOxPvtTtg7gQ5XkgQO2LuDzypfw07jvGIOv+AqDtBkBM7gZZbOskn0Ds+lSs7ux+HOv2VHznQWK87xwb+O5CcSjsPpLI6EEaOO+gKODxSKNY7+LhtOoZYHzoiyTg7fCqsO/Z7iTvEZCc7BZvwO5jnnDsUafQ7175YO/n4ADxgRYo7jLWIO28gBzq7HgU8YLTfOwl5BDwjbiUuChjePmg8bj99ItE7i5xjO+lBtDv482w/YVewO/C7ADyZhQ08ZySBPM2xRzz+bAc8UVEkO5cmCjy9SA083sj6O78sNTyoDrU7yvGMPMsZnTurRYY70/9sP6/qpjz25M47cwu2O+H5QT6i2lk/EowHPKH/BjyXuA08BRHoPD8+uD2B1tg70TwaPDpeZz/Jzfg8b/7wO4Qdwj4il7c7KBfQPBFXCzwKLsw7qo7zPZWYpztCvlo8piqBPDq30zwoHQo8GaZ1PgI9Ez6XF789B7S+OxKdnj3avYY8GXnjPlGKNjz8OxQ9uf7xPh26Azwjkf08zae4PjLK9j5HMrs8I3QBPhT8KD1LKO07b0+TPviKxz5HQvY8AU/IO52FKT5EymM+VsPyO/YvCj2cOKc8eZQ0O4RTzjz27J893oEHPfE3Pj27RlU7fYGMPmySwj6qYBs8", + "dtype": "f4", + "shape": "12, 12" + } } ], "layout": { "coloraxis": { - "cmid": 0, + "cmid": 0.0, "colorscale": [ [ - 0, + 0.0, "rgb(103,0,31)" ], [ @@ -16834,7 +14084,7 @@ "rgb(33,102,172)" ], [ - 1, + 1.0, "rgb(5,48,97)" ] ] @@ -16915,7 +14165,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -16951,7 +14201,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -16975,7 +14225,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -17011,64 +14261,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "heatmap" } ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], "histogram": [ { "marker": { @@ -17089,7 +14288,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -17125,7 +14324,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ 
-17140,7 +14339,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -17176,7 +14375,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -17269,6 +14468,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -17321,7 +14531,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -17357,7 +14567,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -17448,7 +14658,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -17484,13 +14694,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -17526,7 +14736,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -17662,8 +14872,8 @@ "anchor": "y", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "scaleanchor": "y", "title": { @@ -17675,8 +14885,8 @@ "autorange": "reversed", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "Layer" @@ -17750,13 +14960,26 @@ { "cell_type": "code", "execution_count": 38, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:20.813008Z", + "iopub.status.busy": "2026-03-02T20:01:20.812954Z", + "iopub.status.idle": "2026-03-02T20:01:20.997787Z", + "shell.execute_reply": "2026-03-02T20:01:20.997576Z" + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Top Name Mover to ablate: L9H9\n", + "Top Name Mover to ablate: L9H9\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "Original logit diff: 3.55\n", "Post ablation logit diff: 2.92\n", "Direct Logit Attribution of top name mover head: 2.99\n", @@ -17772,6 +14995,7 @@ "\n", "\n", "def ablate_top_head_hook(z: Float[torch.Tensor, \"batch pos head_index d_head\"], hook):\n", + " z = z.clone()\n", " z[:, -1, top_name_mover_head, :] = 0\n", " return z\n", "\n", @@ -17780,6 +15004,18 @@ 
"model.blocks[top_name_mover_layer].attn.hook_z.add_hook(ablate_top_head_hook)\n", "# Runs the model, temporarily adds caching hooks and then removes *all* hooks after running, including the ablation hook.\n", "ablated_logits, ablated_cache = model.run_with_cache(tokens)\n", + "# TransformerBridge workarounds: fix device, pos_embed batch dim, and hook_result shape\n", + "for key in list(ablated_cache.cache_dict.keys()):\n", + " if isinstance(ablated_cache.cache_dict[key], torch.Tensor):\n", + " ablated_cache.cache_dict[key] = ablated_cache.cache_dict[key].to(device)\n", + "if \"hook_pos_embed\" in ablated_cache.cache_dict:\n", + " pe = ablated_cache.cache_dict[\"hook_pos_embed\"]\n", + " if pe.shape[0] == 1 and tokens.shape[0] > 1:\n", + " ablated_cache.cache_dict[\"hook_pos_embed\"] = pe.expand(tokens.shape[0], -1, -1)\n", + "for layer in range(model.cfg.n_layers):\n", + " key = f\"blocks.{layer}.attn.hook_result\"\n", + " if key in ablated_cache.cache_dict and ablated_cache.cache_dict[key].ndim == 3:\n", + " del ablated_cache.cache_dict[key]\n", "print(f\"Original logit diff: {original_average_logit_diff:.2f}\")\n", "print(\n", " f\"Post ablation logit diff: {logits_to_ave_logit_diff(ablated_logits, answer_tokens).item():.2f}\"\n", @@ -17789,7 +15025,7 @@ ")\n", "print(\n", " f\"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item():.2f}\"\n", - ")" + ")\n" ] }, { @@ -17805,7 +15041,14 @@ { "cell_type": "code", "execution_count": 39, - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-02T20:01:20.998722Z", + "iopub.status.busy": "2026-03-02T20:01:20.998662Z", + "iopub.status.idle": "2026-03-02T20:01:21.035423Z", + "shell.execute_reply": "2026-03-02T20:01:21.035215Z" + } + }, "outputs": [ { "name": "stdout", @@ -17828,184 +15071,19 @@ "type": "heatmap", "xaxis": "x", "yaxis": "y", - "z": [ - [ - -0.002156503964215517, - -0.0004650682385545224, - 
0.00024167183437384665, - 0.0002806585980579257, - -0.0004162999684922397, - -0.0004892416181974113, - -0.002620948012918234, - -0.002935677068307996, - 0.00042561208829283714, - 0.0005418329383246601, - 0.00023754138965159655, - -7.48957390896976e-05 - ], - [ - -0.000658505829051137, - 0.0004060641804244369, - -0.0009330413886345923, - 0.0008937822422012687, - -0.0009785268921405077, - -0.000533820129930973, - -0.0027988189831376076, - -0.004214101936668158, - 0.002578593324869871, - 0.0024506838526576757, - 0.0005351756699383259, - 0.0012349633034318686 - ], - [ - 0.0009405204327777028, - -0.0011168691562488675, - -0.0011541967978700995, - -0.0015697095077484846, - -0.0005699327448382974, - 0.001451514894142747, - 0.002439911477267742, - 0.003158293664455414, - 0.000923738582059741, - -0.003578126197680831, - -0.0010650777257978916, - -0.0003558753523975611 - ], - [ - -0.0005624951445497572, - -1.1960582924075425e-05, - 0.0011531109921634197, - 0.0007360265008173883, - 0.0016493839211761951, - 0.0008800819050520658, - -0.0006905529880896211, - -0.003031972097232938, - 0.0008080147090367973, - 0.00010368914809077978, - -0.0005807994166389108, - -0.0011067037703469396 - ], - [ - -0.0026375530287623405, - 0.0002691895351745188, - -0.0016417437000200152, - -0.003406986128538847, - 0.0017449699807912111, - 0.00046454701805487275, - -0.0007899806369096041, - 0.0018328562146052718, - -0.00086324627045542, - -0.0003978293389081955, - 0.0007879206677898765, - -0.00012048585631418973 - ], - [ - 0.0008688560919836164, - 0.0009473530226387084, - -0.0022812988609075546, - -0.0011803123634308577, - 0.0002407809515716508, - -0.0004318578285165131, - -0.0003728170122485608, - -0.000738416681997478, - 0.0008113418589346111, - -0.00040444196201860905, - -0.007074396125972271, - 0.003946478478610516 - ], - [ - -0.014917617663741112, - -0.0022801742888987064, - 0.0022679336834698915, - -8.302251808345318e-05, - -0.004980948753654957, - 0.0027670026756823063, - 0.006266288459300995, 
- -0.003485947148874402, - -0.0013348984066396952, - -0.0017918883822858334, - -0.0012231896398589015, - 0.00040514359716326 - ], - [ - -0.0002460568503011018, - -0.005790225230157375, - -0.0004975841729901731, - 0.142182856798172, - -0.0014961492270231247, - -0.019006317481398582, - 0.003133433870971203, - -0.001858205534517765, - -0.011305196210741997, - 0.1922595500946045, - -0.0011892566690221429, - -0.0010282933944836259 - ], - [ - -0.0038003993686288595, - -0.0008570950012654066, - -0.013956742361187935, - 0.00828910805284977, - 0.004315475933253765, - -0.009073829278349876, - -0.08315148949623108, - 0.0034569751005619764, - -0.01805492490530014, - 0.002178061753511429, - 0.29780513048171997, - 0.02409379370510578 - ], - [ - 0.08904723823070526, - -0.0007931794971227646, - 0.07247699797153473, - 0.015016308054327965, - -0.02120928093791008, - 0.05205465108156204, - 1.4411165714263916, - 0.04743674397468567, - -0.03229031339287758, - 0, - 0.0019993737805634737, - -0.00807223655283451 - ], - [ - 0.8600788116455078, - 0.3260062038898468, - 0.16344408690929413, - 0.07133537530899048, - -0.00444837287068367, - 0.000681330740917474, - 0.36613449454307556, - -0.7105098962783813, - -0.002031375654041767, - -0.032143525779247284, - 1.2294330596923828, - 0.0018453558441251516 - ], - [ - 0.016877274960279465, - -0.001730365096591413, - -0.5010868310928345, - 0.02749764919281006, - -0.0059662917628884315, - -0.004944110754877329, - -0.08855228126049042, - 0.006622308399528265, - 0.044124361127614975, - -0.02726735547184944, - -1.134916067123413, - 0.02287953346967697 - ] - ] + "z": { + "bdata": 
"DlQNu8LU87ldaH05LyaTOSZC2rkVQQC6B8Qru7JkQLtXI985YAkOOngYeTk0EZ24lKEsuiTm1DnglnS650xqOptBgLoC7wu6UWw3u4oWirto/Sg7jpsgO3xKDDqy3aE6aIx2OghjkrqxSZe6/r/NuvRqFbqlPr46+eYfO277TjvGJHI6eH5qu5qXi7oWlbq5OXQTumDLR7eoJ5c6E/JAOngv2DqVtGY6pAc1umWxRrsU0VM6EE3ZOB49GLoQEJG67dwsu6IxjTm8MNe6EUhfuz+45DoHivM5WBBPurY48DrwRGK6bpnQuTSMTjo4nvy44sJjOoVXeDq2fRW7XrWausCUfDnpaeK5bnnDueKUQbpms1Q6kAnUuSzP57skUYE7aWl0vL5tFbsAoRQ7yv6tuBo4o7saVzU7bVbNO8J2ZLva9q66jNzquuhSoLqQlNQ50A6Buda7vbvybwK6m5gRPrUdxLpUs5u8GFlNOx6P87raOTm88d9EPuXgm7qvxYa6MhB5u5CwYLq7qmS89s4HPBZnjTviqRS8akuqvb6MYjsn6JO8prsOOwN6mD76YMU8VF62PbbvT7qfbpQ9qAd2PGC/rbwZN1U9nna4P0BNQj0/QwS9AAAAAIQHAzsqQQS8VC5cP2bqpj4AXic+bxiSPZDDkbuOnDI6Jna7PhbkNb+qIAW7XKgDvSZenT9g3vE6R0KKPFjo4rpGRwC/JUPhPOSAw7uTAqK75lq1vVkA2TvIuzQ9K1/fvAJFkb83brs8", + "dtype": "f4", + "shape": "12, 12" + } } ], "layout": { "coloraxis": { - "cmid": 0, + "cmid": 0.0, "colorscale": [ [ - 0, + 0.0, "rgb(103,0,31)" ], [ @@ -18045,7 +15123,7 @@ "rgb(33,102,172)" ], [ - 1, + 1.0, "rgb(5,48,97)" ] ] @@ -18129,7 +15207,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -18165,7 +15243,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -18189,7 +15267,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -18225,64 +15303,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "heatmap" } ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], "histogram": [ { "marker": { @@ -18303,7 +15330,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -18339,7 +15366,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ 
-18354,7 +15381,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -18390,7 +15417,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -18483,6 +15510,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -18535,7 +15573,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -18571,7 +15609,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -18662,7 +15700,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -18698,13 +15736,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -18740,7 +15778,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -18873,8 +15911,8 @@ "anchor": "y", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "scaleanchor": "y", "title": { @@ -18886,8 +15924,8 @@ "autorange": "reversed", "constrain": "domain", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "title": { "text": "Layer" @@ -19064,299 +16102,15 @@ "orientation": "v", "showlegend": false, "type": "scatter", - "x": [ - -0.002156503964215517, - -0.0004650682385545224, - 0.00024167183437384665, - 0.0002806585980579257, - -0.0004162999684922397, - -0.0004892416181974113, - -0.002620948012918234, - -0.002935677068307996, - 0.00042561208829283714, - 0.0005418329383246601, - 0.00023754138965159655, - -7.48957390896976e-05, - -0.000658505829051137, - 0.0004060641804244369, - -0.0009330413886345923, - 0.0008937822422012687, - -0.0009785268921405077, - -0.000533820129930973, - -0.0027988189831376076, - -0.004214101936668158, - 0.002578593324869871, - 0.0024506838526576757, - 0.0005351756699383259, - 0.0012349633034318686, - 0.0009405204327777028, - -0.0011168691562488675, - -0.0011541967978700995, - -0.0015697095077484846, - -0.0005699327448382974, - 0.001451514894142747, - 0.002439911477267742, - 0.003158293664455414, - 0.000923738582059741, - -0.003578126197680831, - -0.0010650777257978916, - 
-0.0003558753523975611, - -0.0005624951445497572, - -1.1960582924075425e-05, - 0.0011531109921634197, - 0.0007360265008173883, - 0.0016493839211761951, - 0.0008800819050520658, - -0.0006905529880896211, - -0.003031972097232938, - 0.0008080147090367973, - 0.00010368914809077978, - -0.0005807994166389108, - -0.0011067037703469396, - -0.0026375530287623405, - 0.0002691895351745188, - -0.0016417437000200152, - -0.003406986128538847, - 0.0017449699807912111, - 0.00046454701805487275, - -0.0007899806369096041, - 0.0018328562146052718, - -0.00086324627045542, - -0.0003978293389081955, - 0.0007879206677898765, - -0.00012048585631418973, - 0.0008688560919836164, - 0.0009473530226387084, - -0.0022812988609075546, - -0.0011803123634308577, - 0.0002407809515716508, - -0.0004318578285165131, - -0.0003728170122485608, - -0.000738416681997478, - 0.0008113418589346111, - -0.00040444196201860905, - -0.007074396125972271, - 0.003946478478610516, - -0.014917617663741112, - -0.0022801742888987064, - 0.0022679336834698915, - -8.302251808345318e-05, - -0.004980948753654957, - 0.0027670026756823063, - 0.006266288459300995, - -0.003485947148874402, - -0.0013348984066396952, - -0.0017918883822858334, - -0.0012231896398589015, - 0.00040514359716326, - -0.0002460568503011018, - -0.005790225230157375, - -0.0004975841729901731, - 0.142182856798172, - -0.0014961492270231247, - -0.019006317481398582, - 0.003133433870971203, - -0.001858205534517765, - -0.011305196210741997, - 0.1922595500946045, - -0.0011892566690221429, - -0.0010282933944836259, - -0.0038003993686288595, - -0.0008570950012654066, - -0.013956742361187935, - 0.00828910805284977, - 0.004315475933253765, - -0.009073829278349876, - -0.08315148949623108, - 0.0034569751005619764, - -0.01805492490530014, - 0.002178061753511429, - 0.29780513048171997, - 0.02409379370510578, - 0.08904723823070526, - -0.0007931794971227646, - 0.07247699797153473, - 0.015016308054327965, - -0.02120928093791008, - 0.05205465108156204, - 1.4411165714263916, - 
0.04743674397468567, - -0.03229031339287758, - 0, - 0.0019993737805634737, - -0.00807223655283451, - 0.8600788116455078, - 0.3260062038898468, - 0.16344408690929413, - 0.07133537530899048, - -0.00444837287068367, - 0.000681330740917474, - 0.36613449454307556, - -0.7105098962783813, - -0.002031375654041767, - -0.032143525779247284, - 1.2294330596923828, - 0.0018453558441251516, - 0.016877274960279465, - -0.001730365096591413, - -0.5010868310928345, - 0.02749764919281006, - -0.0059662917628884315, - -0.004944110754877329, - -0.08855228126049042, - 0.006622308399528265, - 0.044124361127614975, - -0.02726735547184944, - -1.134916067123413, - 0.02287953346967697 - ], + "x": { + "bdata": "DlQNu8LU87ldaH05LyaTOSZC2rkVQQC6B8Qru7JkQLtXI985YAkOOngYeTk0EZ24lKEsuiTm1DnglnS650xqOptBgLoC7wu6UWw3u4oWirto/Sg7jpsgO3xKDDqy3aE6aIx2OghjkrqxSZe6/r/NuvRqFbqlPr46+eYfO277TjvGJHI6eH5qu5qXi7oWlbq5OXQTumDLR7eoJ5c6E/JAOngv2DqVtGY6pAc1umWxRrsU0VM6EE3ZOB49GLoQEJG67dwsu6IxjTm8MNe6EUhfuz+45DoHivM5WBBPurY48DrwRGK6bpnQuTSMTjo4nvy44sJjOoVXeDq2fRW7XrWausCUfDnpaeK5bnnDueKUQbpms1Q6kAnUuSzP57skUYE7aWl0vL5tFbsAoRQ7yv6tuBo4o7saVzU7bVbNO8J2ZLva9q66jNzquuhSoLqQlNQ50A6Buda7vbvybwK6m5gRPrUdxLpUs5u8GFlNOx6P87raOTm88d9EPuXgm7qvxYa6MhB5u5CwYLq7qmS89s4HPBZnjTviqRS8akuqvb6MYjsn6JO8prsOOwN6mD76YMU8VF62PbbvT7qfbpQ9qAd2PGC/rbwZN1U9nna4P0BNQj0/QwS9AAAAAIQHAzsqQQS8VC5cP2bqpj4AXic+bxiSPZDDkbuOnDI6Jna7PhbkNb+qIAW7XKgDvSZenT9g3vE6R0KKPFjo4rpGRwC/JUPhPOSAw7uTAqK75lq1vVkA2TvIuzQ9K1/fvAJFkb83brs8", + "dtype": "f4" + }, "xaxis": "x", - "y": [ - -0.0020563392899930477, - -0.0005101899732835591, - 0.0004685786843765527, - 0.00012512074317783117, - -0.0006028738571330905, - -0.0002429460291750729, - -0.0023189077619463205, - -0.002758360467851162, - 0.000564602785743773, - 0.0009697531932033598, - -0.0002504526637494564, - 4.737317794933915e-06, - -0.0010070882271975279, - 0.00039470894262194633, - -0.00154874159488827, - 0.0014034928753972054, - -0.0012653048615902662, - -0.0011358022456988692, - -0.00281596090644598, - 
-0.0029645217582583427, - 0.0029190476052463055, - 0.0025743592996150255, - 0.00036239007022231817, - 0.0017548729665577412, - 0.0005569400964304805, - -0.001126631861552596, - -0.0017353934235870838, - -0.0014514457434415817, - -0.00028735760133713484, - 0.0017211002996191382, - 0.0026658899150788784, - 0.00311466702260077, - 0.0005667927907779813, - -0.003666515462100506, - -0.0018847601022571325, - 7.039372576400638e-06, - -0.0007264417363330722, - 0.00011364505917299539, - 0.0014301587361842394, - 0.0007490540738217533, - 0.0020184689201414585, - 0.0007436950691044331, - -0.00046178390039131045, - -0.0039057559333741665, - 0.0011406694538891315, - -4.022853681817651e-05, - -0.0013293239753693342, - -0.0017636751290410757, - -0.0028280913829803467, - 0.00033634810824878514, - -0.0014248639345169067, - -0.003777273464947939, - 0.0015998880844563246, - 0.0002989505883306265, - -0.000804675742983818, - 0.002038792008534074, - -0.0015593919670209289, - -0.0006436670082621276, - 0.0011168173514306545, - -0.00035012533771805465, - 0.0011338205076754093, - 0.0011259170714765787, - -0.002516670385375619, - -0.0014790185960009694, - 0.0003878737334161997, - -6.408110493794084e-05, - -0.0005096744280308485, - -0.0008840755908749998, - 0.0006398351397365332, - -0.0010097370250150561, - -0.006759158335626125, - 0.0033667823299765587, - -0.01514742337167263, - -0.0021350777242332697, - 0.002593174111098051, - -0.00042678468162193894, - -0.005558924749493599, - 0.0026658528950065374, - 0.006411008536815643, - -0.003826778382062912, - -0.0003843410813715309, - -0.0016430341638624668, - -0.0013344454346224666, - -9.20506427064538e-05, - -9.476230479776859e-05, - -0.0057889921590685844, - -0.0006383581785485148, - 0.13493388891220093, - -0.001768707763403654, - -0.018917907029390335, - 0.003873429261147976, - -0.0021450775675475597, - -0.010327338241040707, - 0.18325845897197723, - -0.0007747983909212053, - -0.00104526337236166, - -0.003833949100226164, - -0.0008046097937040031, 
- -0.012673400342464447, - 0.00804573018103838, - 0.003604492638260126, - -0.009398287162184715, - -0.08272082358598709, - 0.003555194940418005, - -0.018404025584459305, - 0.0017587244510650635, - 0.2896133363246918, - 0.022854052484035492, - 0.08595258742570877, - -0.0006932877004146576, - 0.06817055493593216, - 0.013111240230500698, - -0.021098043769598007, - 0.05112447217106819, - 1.3844914436340332, - 0.045836858451366425, - -0.03830280900001526, - 2.985445976257324, - 0.0019662054255604744, - -0.008030137047171593, - 0.5608693957328796, - 0.17083050310611725, - -0.03361757844686508, - 0.05821544677019119, - -0.0024530249647796154, - 0.0018771197646856308, - 0.28827205300331116, - -1.8986485004425049, - -0.0015286931302398443, - -0.035129792988300323, - 0.4802178740501404, - -0.0009115453576669097, - 0.016075748950242996, - -0.03986122086644173, - -0.3879126012325287, - 0.011123123578727245, - -0.005477819126099348, - -0.0025129620917141438, - -0.08056175708770752, - 0.007518616039305925, - 0.0430111438035965, - -0.040082238614559174, - -0.9702364802360535, - 0.011862239800393581 - ], + "y": { + "bdata": "isMGu+K+BbrkqvU5ijQDOaIJHrr0wn65qfgXu8jFNLttARQ66DZ+OqRNg7kAGZ82SwGEupDyzjm8/sq6avW3OtfXpbrN3pS61Is4u8BIQrs0TT87eLYoO6D9vTn2AuY6J/8ROpaqk7o/d+O6zD++upivlrmylOE6L7YuO3wfTDvFkhQ6aElwu6wH97oAC+w2YG4+usBt7ji2d7s6+ltEOgVIBDsY9EI6Ih7yuW71f7togpU6oAcpuPQ6rrosLOe6qlk5u4tmsDluw7q6Z4x3uxa00To+uJw5nepSupSbBTtHYcy6T74ouvdhkjq6jLe5EJyUOnWTkzqO6iS7VNzBushoyzmwXIa46JwFuhTEZ7r6vSc69liEutN63bvRpFw7Si14vIvrC7uW8Sk7fbrfuX4otrsati47bBTSOxDNerucfsm50VnXuvrnrrrwZcC4xfHGuHuxvbsLVye6USwKPh7X57rl+Zq8NNh9O3aUDLtmNCm8WKg7PiEcS7oa/4i6zkJ7u0vuUrr6o0+8/NEDPLY1bDvG+hm8lmmpvZ78aDtCxJa8xYDmOkhIlD4SObs81AewPQzANbrOnIs9MdFWPBDWrLy4Z1E9GzexP6C/Oz2+4xy9rhE/QAnbADuMkAO8TpUPPzLuLj5msgm9oXNuPerBILvoCfY6f5iTPvIG878XXsi6EOQPvWPf9T5w+m66cLGDPH9GI71vnMa+XD82PDZ/s7s9sSS7av2kvUBf9jt4LDA9Fy0kvbVheL/UWkI8", + "dtype": "f4" + }, "yaxis": "y" } ], @@ -19440,7 +16194,7 @@ }, "colorscale": [ [ - 0, + 0.0, 
"#0d0887" ], [ @@ -19476,7 +16230,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -19500,7 +16254,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -19536,64 +16290,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "type": "heatmap" } ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], "histogram": [ { "marker": { @@ -19614,7 +16317,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -19650,7 +16353,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -19665,7 +16368,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -19701,7 +16404,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -19794,6 +16497,17 @@ "type": "scattergl" } ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], "scattermapbox": [ { "marker": { @@ -19846,7 +16560,7 @@ }, "colorscale": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -19882,7 +16596,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], @@ -19973,7 +16687,7 @@ ], "sequential": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -20009,13 +16723,13 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ], "sequentialminus": [ [ - 0, + 0.0, "#0d0887" ], [ @@ -20051,7 +16765,7 @@ "#fdca26" ], [ - 1, + 1.0, "#f0f921" ] ] @@ -20186,8 +16900,8 @@ "xaxis": { "anchor": "y", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "range": [ -3, @@ -20200,8 +16914,8 @@ "yaxis": { "anchor": "x", "domain": [ - 0, - 1 + 0.0, + 1.0 ], "range": [ -3, @@ -20252,7 +16966,14 @@ { "cell_type": "code", "execution_count": 40, - "metadata": {}, + "metadata": { + 
"execution": { + "iopub.execute_input": "2026-03-02T20:01:21.036306Z", + "iopub.status.busy": "2026-03-02T20:01:21.036245Z", + "iopub.status.idle": "2026-03-02T20:01:21.093321Z", + "shell.execute_reply": "2026-03-02T20:01:21.093130Z" + } + }, "outputs": [ { "name": "stdout", @@ -20312,7 +17033,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "transformer-lens", "language": "python", "name": "python3" }, @@ -20326,14 +17047,377 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.12.12" }, - "vscode": { - "interpreter": { - "hash": "eb812820b5094695c8a581672e17220e30dd2c15d704c018326e3cc2e1a566f1" + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": { + "09c32ea4455548fdbb35d27edb7d40c3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_26ad8ee867a547f1a8de29ca0b85e081", + "IPY_MODEL_beba0111e57442d49fc3810a432dbdbc", + "IPY_MODEL_157689103b9547fa8c329d195a2e9ca0" + ], + "layout": "IPY_MODEL_7822d363af08462485ca18444f081e6e", + "tabbable": null, + "tooltip": null + } + }, + "157689103b9547fa8c329d195a2e9ca0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + 
"layout": "IPY_MODEL_ef520bc3e2ef45fda85f0fca0e37888c", + "placeholder": "​", + "style": "IPY_MODEL_51d53e278d5a4494850d176414951b9c", + "tabbable": null, + "tooltip": null, + "value": " 148/148 [00:00<00:00, 5688.60it/s, Materializing param=transformer.wte.weight]" + } + }, + "26ad8ee867a547f1a8de29ca0b85e081": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_34f724ab14464382b47e6dea48e104b9", + "placeholder": "​", + "style": "IPY_MODEL_bf880351381e472e84d2aa89e248f162", + "tabbable": null, + "tooltip": null, + "value": "Loading weights: 100%" + } + }, + "34f724ab14464382b47e6dea48e104b9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + 
"max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "39b052002e60437ea5dab202b1345319": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "51d53e278d5a4494850d176414951b9c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "7822d363af08462485ca18444f081e6e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + 
"grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9f56afbdf0e74664947ff00ee6684704": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "beba0111e57442d49fc3810a432dbdbc": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + 
"state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_9f56afbdf0e74664947ff00ee6684704", + "max": 148.0, + "min": 0.0, + "orientation": "horizontal", + "style": "IPY_MODEL_39b052002e60437ea5dab202b1345319", + "tabbable": null, + "tooltip": null, + "value": 148.0 + } + }, + "bf880351381e472e84d2aa89e248f162": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "ef520bc3e2ef45fda85f0fca0e37888c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": 
null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + }, + "version_major": 2, + "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} From 136f958e74f11e05a58c4bac859815b076e4da50 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 3 Mar 2026 17:37:50 -0600 Subject: [PATCH 5/7] Removed inline bug fix code in favorite of systemic fixes --- demos/Exploratory_Analysis_Demo.ipynb | 66 ++++----------------------- 1 file changed, 8 insertions(+), 58 deletions(-) diff --git a/demos/Exploratory_Analysis_Demo.ipynb b/demos/Exploratory_Analysis_Demo.ipynb index 85325d057..bd6fe1fe3 100644 --- a/demos/Exploratory_Analysis_Demo.ipynb +++ b/demos/Exploratory_Analysis_Demo.ipynb @@ -382,18 +382,7 @@ "tokens = model.to_tokens(prompts, prepend_bos=True)\n", "\n", "# Run the model and cache all activations\n", - "original_logits, cache = model.run_with_cache(tokens)\n", - "\n", - "# TransformerBridge workarounds:\n", - "# Bug 5: run_with_cache moves tensors to CPU; move back to model device\n", - "for key in list(cache.cache_dict.keys()):\n", - " if isinstance(cache.cache_dict[key], torch.Tensor):\n", - " cache.cache_dict[key] = cache.cache_dict[key].to(device)\n", - "# Bug 3: pos_embed cached as [1,seq,d] instead of [batch,seq,d]\n", - "if \"hook_pos_embed\" in cache.cache_dict:\n", - " pe = cache.cache_dict[\"hook_pos_embed\"]\n", - " if pe.shape[0] == 1 and tokens.shape[0] > 1:\n", - " cache.cache_dict[\"hook_pos_embed\"] = pe.expand(tokens.shape[0], -1, -1)\n" + "original_logits, cache = model.run_with_cache(tokens)" ] }, { @@ -994,35 +983,13 @@ "metadata": {}, "outputs": [], "source": [ 
- "def _cache_lookup(cache, hook_name, expected_ndim=None):\n", - " \"\"\"Look up cache value, handling bridge hook.name aliasing.\n", - "\n", - " TransformerBridge hook.name may differ from the cache key (e.g. hook.name is\n", - " 'blocks.0.attn.hook_result' but cache stores 'blocks.0.hook_attn_out').\n", - " Additionally, compute_head_results may overwrite hook_result with a 4D tensor,\n", - " so we fall back to the block-level alias when the shape doesn't match.\n", - " \"\"\"\n", - " try:\n", - " val = cache[hook_name]\n", - " if expected_ndim is not None and val.ndim != expected_ndim:\n", - " raise KeyError(f\"Shape mismatch: expected {expected_ndim}D, got {val.ndim}D\")\n", - " return val\n", - " except KeyError:\n", - " # Try the block-level alias: blocks.X.attn.hook_result -> blocks.X.hook_attn_out\n", - " parts = hook_name.split(\".\")\n", - " if len(parts) >= 4 and parts[2] == \"attn\":\n", - " alt_key = f\"{parts[0]}.{parts[1]}.hook_attn_out\"\n", - " return cache[alt_key]\n", - " raise\n", - "\n", - "\n", "def patch_residual_component(\n", " corrupted_residual_component: Float[torch.Tensor, \"batch pos d_model\"],\n", " hook,\n", " pos,\n", " clean_cache,\n", "):\n", - " clean_value = _cache_lookup(clean_cache, hook.name, expected_ndim=3)[:, pos, :].clone()\n", + " clean_value = clean_cache[hook.name][:, pos, :].clone()\n", " corrupted_residual_component = corrupted_residual_component.clone()\n", " corrupted_residual_component[:, pos : pos + 1, :] = clean_value.unsqueeze(1)\n", " return corrupted_residual_component\n", @@ -1214,7 +1181,7 @@ " head_index,\n", " clean_cache,\n", "):\n", - " clean_value = _cache_lookup(clean_cache, hook.name)[\n", + " clean_value = clean_cache[hook.name][\n", " :, :, head_index, :\n", " ].clone()\n", " corrupted_head_vector = corrupted_head_vector.clone()\n", @@ -1379,7 +1346,7 @@ " head_index,\n", " clean_cache,\n", "):\n", - " clean_value = _cache_lookup(clean_cache, hook.name)[\n", + " clean_value = 
clean_cache[hook.name][\n", " :, head_index, :, :\n", " ].clone()\n", " corrupted_head_pattern = corrupted_head_pattern.clone()\n", @@ -1807,18 +1774,6 @@ "model.blocks[top_name_mover_layer].attn.hook_z.add_hook(ablate_top_head_hook)\n", "# Runs the model, temporarily adds caching hooks and then removes *all* hooks after running, including the ablation hook.\n", "ablated_logits, ablated_cache = model.run_with_cache(tokens)\n", - "# TransformerBridge workarounds: fix device, pos_embed batch dim, and hook_result shape\n", - "for key in list(ablated_cache.cache_dict.keys()):\n", - " if isinstance(ablated_cache.cache_dict[key], torch.Tensor):\n", - " ablated_cache.cache_dict[key] = ablated_cache.cache_dict[key].to(device)\n", - "if \"hook_pos_embed\" in ablated_cache.cache_dict:\n", - " pe = ablated_cache.cache_dict[\"hook_pos_embed\"]\n", - " if pe.shape[0] == 1 and tokens.shape[0] > 1:\n", - " ablated_cache.cache_dict[\"hook_pos_embed\"] = pe.expand(tokens.shape[0], -1, -1)\n", - "for layer in range(model.cfg.n_layers):\n", - " key = f\"blocks.{layer}.attn.hook_result\"\n", - " if key in ablated_cache.cache_dict and ablated_cache.cache_dict[key].ndim == 3:\n", - " del ablated_cache.cache_dict[key]\n", "print(f\"Original logit diff: {original_average_logit_diff:.2f}\")\n", "print(\n", " f\"Post ablation logit diff: {logits_to_ave_logit_diff(ablated_logits, answer_tokens).item():.2f}\"\n", @@ -1828,7 +1783,7 @@ ")\n", "print(\n", " f\"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item():.2f}\"\n", - ")\n" + ")" ] }, { @@ -1916,7 +1871,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "transformer-lens", "language": "python", "name": "python3" }, @@ -1930,14 +1885,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" - }, - "vscode": { - "interpreter": { - "hash": 
"eb812820b5094695c8a581672e17220e30dd2c15d704c018326e3cc2e1a566f1" - } + "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} From 5a0ea0c0fff26d61db9bb4feb0c1036c08e4b7ed Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 3 Mar 2026 19:27:39 -0600 Subject: [PATCH 6/7] Additional bug resolution --- demos/Exploratory_Analysis_Demo.ipynb | 9 +-------- transformer_lens/ActivationCache.py | 22 +++++++++++++++++++--- transformer_lens/model_bridge/bridge.py | 19 +++++++++++++------ 3 files changed, 33 insertions(+), 17 deletions(-) diff --git a/demos/Exploratory_Analysis_Demo.ipynb b/demos/Exploratory_Analysis_Demo.ipynb index bd6fe1fe3..c2c86cea8 100644 --- a/demos/Exploratory_Analysis_Demo.ipynb +++ b/demos/Exploratory_Analysis_Demo.ipynb @@ -715,13 +715,6 @@ "metadata": {}, "outputs": [], "source": [ - "# hook_result shape: Bridge captures [batch,pos,d_model] not [batch,pos,n_heads,d_head]\n", - "# Remove so compute_head_results can recompute from z + W_O\n", - "for layer in range(model.cfg.n_layers):\n", - " key = f\"blocks.{layer}.attn.hook_result\"\n", - " if key in cache.cache_dict and cache.cache_dict[key].ndim == 3:\n", - " del cache.cache_dict[key]\n", - "\n", "per_head_residual, labels = cache.stack_head_results(\n", " layer=-1, pos_slice=-1, return_labels=True\n", ")\n", @@ -1890,4 +1883,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index 55cdbf34a..93bc856b1 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -677,9 +677,25 @@ def compute_head_results( Intended use is to enable use_attn_results when running and caching the model, but this can be useful if you forget. 
""" - if "blocks.0.attn.hook_result" in self.cache_dict: - logging.warning("Tried to compute head results when they were already cached") - return + # If valid 4D per-head results already exist (from forward pass with + # use_attn_result=True, or from a prior compute_head_results() call), + # return early to preserve idempotency. + # + # TransformerBridge may populate hook_result with a 3D combined-output + # tensor (from the hook_result → hook_out alias). We detect these + # wrong-shape entries by checking ndim and remove them before + # recomputing the correct 4D per-head results from z and W_O. + first_key = "blocks.0.attn.hook_result" + if first_key in self.cache_dict: + val = self.cache_dict[first_key] + if isinstance(val, torch.Tensor) and val.ndim >= 4: + logging.warning("Tried to compute head results when they were already cached") + return + # Stale 3D entries exist — remove them before recomputing + for layer in range(self.model.cfg.n_layers): + key = f"blocks.{layer}.attn.hook_result" + if key in self.cache_dict: + del self.cache_dict[key] for layer in range(self.model.cfg.n_layers): # Note that we haven't enabled set item on this object so we need to edit the underlying # cache_dict directly. 
diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 6e738bcd1..6d7337bd8 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -206,8 +206,6 @@ def _register_aliases(self) -> None: for part in single_target.split("."): target_obj = getattr(target_obj, part) object.__setattr__(self, alias_name, target_obj) - if isinstance(target_obj, HookPoint): - target_obj.name = alias_name break except AttributeError: continue @@ -216,8 +214,6 @@ def _register_aliases(self) -> None: for part in target_path.split("."): target_obj = getattr(target_obj, part) object.__setattr__(self, alias_name, target_obj) - if isinstance(target_obj, HookPoint): - target_obj.name = alias_name except AttributeError: pass @@ -405,6 +401,11 @@ def _add_aliases_to_hooks(self, hooks: Dict[str, HookPoint]) -> None: def _scan_existing_hooks(self, module: nn.Module, prefix: str = "") -> None: """Scan existing modules for hooks and add them to registry.""" visited = set() + # Track which HookPoint objects have already been named so that + # alias entries (from get_hooks() in compatibility mode) do not + # overwrite the canonical name. get_hooks() returns canonical + # entries first, so the first name assigned is always canonical. 
+ named_hook_ids: set = set() def scan_module(mod: nn.Module, path: str = "") -> None: obj_id = id(mod) @@ -417,7 +418,10 @@ def scan_module(mod: nn.Module, path: str = "") -> None: hooks_dict = cast(Dict[str, HookPoint], component_hooks) for hook_name, hook in hooks_dict.items(): full_name = f"{path}.{hook_name}" if path else hook_name - hook.name = full_name + hook_id = id(hook) + if hook_id not in named_hook_ids: + hook.name = full_name + named_hook_ids.add(hook_id) self._hook_registry[full_name] = hook for attr_name in dir(mod): if attr_name.startswith("_"): @@ -448,7 +452,10 @@ def scan_module(mod: nn.Module, path: str = "") -> None: continue name = f"{path}.{attr_name}" if path else attr_name if isinstance(attr, HookPoint): - attr.name = name + hook_id = id(attr) + if hook_id not in named_hook_ids: + attr.name = name + named_hook_ids.add(hook_id) self._hook_registry[name] = attr for child_name, child_module in mod.named_children(): if ( From bda13c0247768d887e2b3e4f9050b662608892f8 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 3 Mar 2026 19:52:00 -0600 Subject: [PATCH 7/7] More bug fixes --- transformer_lens/ActivationCache.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index 93bc856b1..cf93fb8ec 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -750,11 +750,11 @@ def stack_head_results( # Default to the residual stream immediately pre unembed layer = self.model.cfg.n_layers - if "blocks.0.attn.hook_result" not in self.cache_dict: - print( - "Tried to stack head results when they weren't cached. Computing head results now" - ) - self.compute_head_results() + # Always call compute_head_results() – it handles idempotency + # (returns early for valid 4D data) and also cleans up any stale 3D + # entries that TransformerBridge's hook_result alias may have placed + # in the cache. 
+ self.compute_head_results() components: Any = [] labels = []