From bcfcd83e05a9a90d14f415fb0af8a3a35ac8fe18 Mon Sep 17 00:00:00 2001 From: degenfabian Date: Mon, 18 Aug 2025 20:17:57 +0200 Subject: [PATCH 1/3] updated loading in patchscopes generation demo to use transformer bridge --- demos/Patchscopes_Generation_Demo.ipynb | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/demos/Patchscopes_Generation_Demo.ipynb b/demos/Patchscopes_Generation_Demo.ipynb index 49c4655d4..b249d112e 100644 --- a/demos/Patchscopes_Generation_Demo.ipynb +++ b/demos/Patchscopes_Generation_Demo.ipynb @@ -65,7 +65,7 @@ "from typing import List, Callable, Tuple, Union\n", "from functools import partial\n", "from jaxtyping import Float\n", - "from transformer_lens import HookedTransformer\n", + "from transformer_lens.model_bridge import TransformerBridge\n", "from transformer_lens.ActivationCache import ActivationCache\n", "import transformer_lens.utils as utils\n", "from transformer_lens.hook_points import (\n", @@ -148,7 +148,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -217,7 +217,8 @@ "source": [ "# NBVAL_IGNORE_OUTPUT\n", "# I'm using an M2 macbook air, so I use CPU for better support\n", - "model = HookedTransformer.from_pretrained(\"gpt2-small\", device=\"cpu\")\n", + "model = TransformerBridge.boot_transformers(\"gpt2\", device=\"cpu\")\n", + "model.enable_compatibility_mode()\n", "model.eval()" ] }, @@ -263,17 +264,17 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "def get_source_representation(prompts: List[str], layer_id: int, model: HookedTransformer, pos_id: Union[int, List[int]]=None) -> torch.Tensor:\n", + "def get_source_representation(prompts: List[str], layer_id: int, model: TransformerBridge, pos_id: Union[int, List[int]]=None) -> torch.Tensor:\n", " \"\"\"Get source hidden representation represented by (S, i, M, l)\n", " \n", " Args:\n", " - prompts (List[str]): a list of source prompts\n", " - layer_id (int): the layer id of the model\n", - " - model (HookedTransformer): the source model\n", + " - model (TransformerBridge): the source model\n", " - pos_id (Union[int, List[int]]): the position id(s) of the model, if None, return all positions\n", "\n", " Returns:\n", @@ -325,19 +326,19 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# recall the target representation (T,i*,f,M*,l*), and we also need the hidden representation from our source model (S, i, M, l)\n", - "def feed_source_representation(source_rep: torch.Tensor, prompt: List[str], f: Callable, model: HookedTransformer, layer_id: int, pos_id: Union[int, List[int]]=None) -> ActivationCache:\n", + "def feed_source_representation(source_rep: torch.Tensor, prompt: List[str], f: Callable, model: TransformerBridge, layer_id: int, pos_id: Union[int, List[int]]=None) -> ActivationCache:\n", " \"\"\"Feed the source hidden representation to the target model\n", " \n", " Args:\n", " - source_rep (torch.Tensor): the source hidden representation\n", " - prompt (List[str]): the target prompt\n", " - f (Callable): the mapping function\n", - " - model (HookedTransformer): the target model\n", + " - model (TransformerBridge): the target model\n", " - layer_id (int): the layer id of the target model\n", " - pos_id (Union[int, List[int]]): the position id(s) of the target model, if None, return all positions\n", " \"\"\"\n", @@ -417,11 +418,11 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "def generate_with_patching(model: HookedTransformer, prompts: List[str], target_f: Callable, max_new_tokens: int = 50):\n", + "def generate_with_patching(model: TransformerBridge, prompts: List[str], target_f: Callable, max_new_tokens: int = 50):\n", " temp_prompts = prompts\n", " input_tokens = model.to_tokens(temp_prompts)\n", " for _ in range(max_new_tokens):\n", From a2332efbb614d7b09263f6a329a136c9e3827b88 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Fri, 27 Feb 2026 15:07:35 -0600 Subject: [PATCH 2/3] Migrate Patchscopes Generation Demo to TransformerBridge - Replace HookedTransformer with TransformerBridge.boot_transformers() - Fix deprecated ipython.magic() to ipython.run_line_magic() - Clear stale outputs from unrun cells All 20 cells pass locally. --- demos/Patchscopes_Generation_Demo.ipynb | 7449 +++++++++++------------ 1 file changed, 3690 insertions(+), 3759 deletions(-) diff --git a/demos/Patchscopes_Generation_Demo.ipynb b/demos/Patchscopes_Generation_Demo.ipynb index 2a9109154..fbfb70ff3 100644 --- a/demos/Patchscopes_Generation_Demo.ipynb +++ b/demos/Patchscopes_Generation_Demo.ipynb @@ -1,3783 +1,3714 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - " \"Open\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Patchscopes & Generation with Patching\n", - "\n", - "This notebook contains a demo for Patchscopes (https://arxiv.org/pdf/2401.06102) and demonstrates how to generate multiple tokens with patching. Since there're also some applications in [Patchscopes](##Patchscopes-pipeline) that require generating multiple tokens with patching, I think it's suitable to put both of them in the same notebook. Additionally, generation with patching can be well-described using Patchscopes. Therefore, I simply implement it with the Patchscopes pipeline (see [here](##Generation-with-patching))." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup (Ignore)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", - "import os\n", - "\n", - "DEBUG_MODE = False\n", - "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", - "try:\n", - " import google.colab\n", - "\n", - " IN_COLAB = True\n", - " print(\"Running as a Colab notebook\")\n", - "except:\n", - " IN_COLAB = False\n", - " print(\"Running as a Jupyter notebook - intended for development only!\")\n", - " from IPython import get_ipython\n", - "\n", - " ipython = get_ipython()\n", - " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", - " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", - " ipython.run_line_magic(\"autoreload\", \"2\")\n", - "\n", - "if IN_COLAB or IN_GITHUB:\n", - " %pip install transformer_lens\n", - " %pip install torchtyping\n", - " # Install my janky personal plotting utils\n", - " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", - " # Install another version of node that makes PySvelte work way faster\n", - " %pip install circuitsvis\n", - " # Needed for PySvelte to work, v3 came out and broke things...\n", - " %pip install typeguard==2.13.3\n", - "\n", - "import torch\n", - "from typing import List, Callable, Tuple, Union\n", - "from functools import partial\n", - "from jaxtyping import Float\n", - "from transformer_lens import HookedTransformer\n", - "from transformer_lens.ActivationCache import ActivationCache\n", - "import transformer_lens.utils as utils\n", - "from transformer_lens.hook_points import (\n", - " HookPoint,\n", - ") # Hooking utilities" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Helper Funcs\n", - "\n", - "A helper function to plot logit lens" - ] - }, - { - "cell_type": "code", - "execution_count": 116, - "metadata": {}, - "outputs": [], - "source": [ - "import plotly.graph_objects as go\n", - "import numpy as np\n", - "\n", - "# Parameters\n", - "num_layers = 5\n", - "seq_len = 10\n", - "\n", - "# Create a matrix of tokens for demonstration\n", - "tokens = np.array([[\"token_{}_{}\".format(i, j) for j in range(seq_len)] for i in range(num_layers)])[::-1]\n", - "values = np.random.rand(num_layers, seq_len)\n", - "orig_tokens = ['Token {}'.format(i) for i in range(seq_len)]\n", - "\n", - "def draw_logit_lens(num_layers, seq_len, orig_tokens, tokens, values):\n", - " # Create the heatmap\n", - " fig = go.Figure(data=go.Heatmap(\n", - " z=values,\n", - " x=orig_tokens,\n", - " y=['Layer {}'.format(i) for i in range(num_layers)][::-1],\n", - " colorscale='Blues',\n", - " showscale=True,\n", - " colorbar=dict(title='Value')\n", - " ))\n", - "\n", - " # Add text annotations\n", - " annotations = []\n", - " for i in range(num_layers):\n", - " for j in range(seq_len):\n", - " annotations.append(\n", - " dict(\n", - " x=j, y=i,\n", - " text=tokens[i, j],\n", - " showarrow=False,\n", - " font=dict(color='white')\n", - " )\n", - " )\n", - "\n", - " fig.update_layout(\n", - " annotations=annotations,\n", - " xaxis=dict(side='top'),\n", - " yaxis=dict(autorange='reversed'),\n", - " margin=dict(l=50, r=50, t=100, b=50),\n", - " width=1000,\n", - " height=600,\n", - " plot_bgcolor='white'\n", - " )\n", - "\n", - " # Show the plot\n", - " fig.show()\n", - "# draw_logit_lens(num_layers, seq_len, orig_tokens, tokens, values)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Model Preparation" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded pretrained model gpt2-small into HookedTransformer\n" - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \"Open\n", + "" + ] }, { - "data": { - "text/plain": [ - "HookedTransformer(\n", - " (embed): Embed()\n", - " (hook_embed): HookPoint()\n", - " (pos_embed): PosEmbed()\n", - " (hook_pos_embed): HookPoint()\n", - " (blocks): ModuleList(\n", - " (0-11): 12 x TransformerBlock(\n", - " (ln1): LayerNormPre(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (ln2): LayerNormPre(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (attn): Attention(\n", - " (hook_k): HookPoint()\n", - " (hook_q): HookPoint()\n", - " (hook_v): HookPoint()\n", - " (hook_z): HookPoint()\n", - " (hook_attn_scores): HookPoint()\n", - " (hook_pattern): HookPoint()\n", - " (hook_result): HookPoint()\n", - " )\n", - " (mlp): MLP(\n", - " (hook_pre): HookPoint()\n", - " (hook_post): HookPoint()\n", - " )\n", - " (hook_attn_in): HookPoint()\n", - " (hook_q_input): HookPoint()\n", - " (hook_k_input): HookPoint()\n", - " (hook_v_input): HookPoint()\n", - " (hook_mlp_in): HookPoint()\n", - " (hook_attn_out): HookPoint()\n", - " (hook_mlp_out): HookPoint()\n", - " (hook_resid_pre): HookPoint()\n", - " (hook_resid_mid): HookPoint()\n", - " (hook_resid_post): HookPoint()\n", - " )\n", - " )\n", - " (ln_final): LayerNormPre(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (unembed): Unembed()\n", - ")" + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Patchscopes & Generation with Patching\n", + "\n", + "This notebook contains a demo for Patchscopes (https://arxiv.org/pdf/2401.06102) and demonstrates how to generate multiple tokens with patching. Since there're also some applications in [Patchscopes](##Patchscopes-pipeline) that require generating multiple tokens with patching, I think it's suitable to put both of them in the same notebook. Additionally, generation with patching can be well-described using Patchscopes. Therefore, I simply implement it with the Patchscopes pipeline (see [here](##Generation-with-patching))." ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# NBVAL_IGNORE_OUTPUT\n", - "# I'm using an M2 macbook air, so I use CPU for better support\n", - "model = HookedTransformer.from_pretrained(\"gpt2-small\", device=\"cpu\")\n", - "model.eval()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Patchscopes Definition\n", - "\n", - "Here we first wirte down the formal definition decribed in the paper https://arxiv.org/pdf/2401.06102.\n", - "\n", - "The representations are:\n", - "\n", - "source: (S, i, M, l), where S is the source prompt, i is the source position, M is the source model, and l is the source layer.\n", - "\n", - "target: (T,i*,f,M*,l*), where T is the target prompt, i* is the target position, M* is the target model, l* is the target layer, and f is the mapping function that takes the original hidden states as input and output the target hidden states\n", - "\n", - "By defulat, S = T, i = i*, M = M*, l = l*, f = identity function" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Patchscopes Pipeline\n", - "\n", - "### Get hidden representation from the source model\n", - "\n", - "1. We first need to extract the source hidden states from model M at position i of layer l with prompt S. In TransformerLens, we can do this using run_with_cache.\n", - "2. Then, we map the source representation with a function f, and feed the hidden representation to the target position using a hook. Specifically, we focus on residual stream (resid_post), whereas you can manipulate more fine-grainedly with TransformerLens\n" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "prompts = [\"Patchscopes is a nice tool to inspect hidden representation of language model\"]\n", - "input_tokens = model.to_tokens(prompts)\n", - "clean_logits, clean_cache = model.run_with_cache(input_tokens)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [], - "source": [ - "def get_source_representation(prompts: List[str], layer_id: int, model: HookedTransformer, pos_id: Union[int, List[int]]=None) -> torch.Tensor:\n", - " \"\"\"Get source hidden representation represented by (S, i, M, l)\n", - " \n", - " Args:\n", - " - prompts (List[str]): a list of source prompts\n", - " - layer_id (int): the layer id of the model\n", - " - model (HookedTransformer): the source model\n", - " - pos_id (Union[int, List[int]]): the position id(s) of the model, if None, return all positions\n", - "\n", - " Returns:\n", - " - source_rep (torch.Tensor): the source hidden representation\n", - " \"\"\"\n", - " input_tokens = model.to_tokens(prompts)\n", - " _, cache = model.run_with_cache(input_tokens)\n", - " layer_name = \"blocks.{id}.hook_resid_post\"\n", - " layer_name = layer_name.format(id=layer_id)\n", - " if pos_id is None:\n", - " return cache[layer_name][:, :, :]\n", - " else:\n", - " return cache[layer_name][:, pos_id, :]" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [], - "source": [ - "source_rep = get_source_representation(\n", - " prompts=[\"Patchscopes is a nice tool to inspect hidden representation of language model\"],\n", - " layer_id=2,\n", - " model=model,\n", - " pos_id=5\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Feed the representation to the target position\n", - "\n", - "First we need to map the representation using mapping function f, and then feed the target representation to the target position represented by (T,i*,f,M*,l*)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "# here we use an identity function for demonstration purposes\n", - "def identity_function(source_rep: torch.Tensor) -> torch.Tensor:\n", - " return source_rep" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [], - "source": [ - "# recall the target representation (T,i*,f,M*,l*), and we also need the hidden representation from our source model (S, i, M, l)\n", - "def feed_source_representation(source_rep: torch.Tensor, prompt: List[str], f: Callable, model: HookedTransformer, layer_id: int, pos_id: Union[int, List[int]]=None) -> ActivationCache:\n", - " \"\"\"Feed the source hidden representation to the target model\n", - " \n", - " Args:\n", - " - source_rep (torch.Tensor): the source hidden representation\n", - " - prompt (List[str]): the target prompt\n", - " - f (Callable): the mapping function\n", - " - model (HookedTransformer): the target model\n", - " - layer_id (int): the layer id of the target model\n", - " - pos_id (Union[int, List[int]]): the position id(s) of the target model, if None, return all positions\n", - " \"\"\"\n", - " mapped_rep = f(source_rep)\n", - " # similar to what we did for activation patching, we need to define a function to patch the hidden representation\n", - " def resid_ablation_hook(\n", - " value: Float[torch.Tensor, \"batch pos d_resid\"],\n", - " hook: HookPoint\n", - " ) -> Float[torch.Tensor, \"batch pos d_resid\"]:\n", - " # print(f\"Shape of the value tensor: {value.shape}\")\n", - " # print(f\"Shape of the hidden representation at the target position: {value[:, pos_id, :].shape}\")\n", - " value[:, pos_id, :] = mapped_rep\n", - " return value\n", - " \n", - " input_tokens = model.to_tokens(prompt)\n", - "\n", - " logits = model.run_with_hooks(\n", - " input_tokens,\n", - " return_type=\"logits\",\n", - " fwd_hooks=[(\n", - " utils.get_act_name(\"resid_post\", layer_id),\n", - " resid_ablation_hook\n", - " )]\n", - " )\n", - " \n", - " return logits" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [], - "source": [ - "patched_logits = feed_source_representation(\n", - " source_rep=source_rep,\n", - " prompt=prompts,\n", - " pos_id=3,\n", - " f=identity_function,\n", - " model=model,\n", - " layer_id=2\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "(tensor([[ 3.5811, 3.5322, 2.6463, ..., -4.3504, -1.7939, 3.3541]],\n", - " grad_fn=),\n", - " tensor([[ 3.2431, 3.2708, 1.9591, ..., -4.2666, -2.2141, 3.4965]],\n", - " grad_fn=))" + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup (Ignore)" ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# NBVAL_IGNORE_OUTPUT\n", - "clean_logits[:, 5], patched_logits[:, 5]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Generation with Patching\n", - "\n", - "In the last step, we've implemented the basic version of Patchscopes where we can only run one single forward pass. Let's now unlock the power by allowing it to generate multiple tokens!" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [], - "source": [ - "def generate_with_patching(model: HookedTransformer, prompts: List[str], target_f: Callable, max_new_tokens: int = 50):\n", - " temp_prompts = prompts\n", - " input_tokens = model.to_tokens(temp_prompts)\n", - " for _ in range(max_new_tokens):\n", - " logits = target_f(\n", - " prompt=temp_prompts,\n", - " )\n", - " next_tok = torch.argmax(logits[:, -1, :])\n", - " input_tokens = torch.cat((input_tokens, next_tok.view(input_tokens.size(0), 1)), dim=1)\n", - " temp_prompts = model.to_string(input_tokens)\n", - "\n", - " return model.to_string(input_tokens)[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "<|endoftext|>Patchscopes is a nice tool to inspect hidden representation of language model file bit file\n" - ] - } - ], - "source": [ - "prompts = [\"Patchscopes is a nice tool to inspect hidden representation of language model\"]\n", - "input_tokens = model.to_tokens(prompts)\n", - "target_f = partial(\n", - " feed_source_representation,\n", - " source_rep=source_rep,\n", - " pos_id=-1,\n", - " f=identity_function,\n", - " model=model,\n", - " layer_id=2\n", - ")\n", - "gen = generate_with_patching(model, prompts, target_f, max_new_tokens=3)\n", - "print(gen)" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", + "\n", + "DEBUG_MODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "try:\n", + " import google.colab\n", + "\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + "except:\n", + " IN_COLAB = False\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + " from IPython import get_ipython\n", + "\n", + " ipython = get_ipython()\n", + " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", + " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + " ipython.run_line_magic(\"autoreload\", \"2\")\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", + " %pip install transformer_lens\n", + " %pip install torchtyping\n", + " # Install my janky personal plotting utils\n", + " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", + " # Install another version of node that makes PySvelte work way faster\n", + " %pip install circuitsvis\n", + " # Needed for PySvelte to work, v3 came out and broke things...\n", + " %pip install typeguard==2.13.3\n", + "\n", + "import torch\n", + "from typing import List, Callable, Tuple, Union\n", + "from functools import partial\n", + "from jaxtyping import Float\n", + "from transformer_lens.model_bridge import TransformerBridge\n", + "from transformer_lens.ActivationCache import ActivationCache\n", + "import transformer_lens.utils as utils\n", + "from transformer_lens.hook_points import (\n", + " HookPoint,\n", + ") # Hooking utilities" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Patchscopes is a nice tool to inspect hidden representation of language model.\n", - "\n", - "It is a simple tool to inspect hidden representation of language model.\n", - "\n", - "It is a simple tool to inspect hidden representation of language model.\n", - "\n", - "It is a simple tool to inspect hidden representation of language model.\n", - "\n", - "It is\n" - ] - } - ], - "source": [ - "# Original generation\n", - "print(model.generate(prompts[0], verbose=False, max_new_tokens=50, do_sample=False))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Application Examples" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Logit Lens\n", - "\n", - "For Logit Lens, the configuration is l* ← L*. Here, L* is the last layer." - ] - }, - { - "cell_type": "code", - "execution_count": 104, - "metadata": {}, - "outputs": [], - "source": [ - "token_list = []\n", - "value_list = []\n", - "\n", - "def identity_function(source_rep: torch.Tensor) -> torch.Tensor:\n", - " return source_rep\n", - "\n", - "for source_layer_id in range(12):\n", - " # Prepare source representation\n", - " source_rep = get_source_representation(\n", - " prompts=[\"Patchscopes is a nice tool to inspect hidden representation of language model\"],\n", - " layer_id=source_layer_id,\n", - " model=model,\n", - " pos_id=None\n", - " )\n", - "\n", - " logits = feed_source_representation(\n", - " source_rep=source_rep,\n", - " prompt=[\"Patchscopes is a nice tool to inspect hidden representation of language model\"],\n", - " f=identity_function,\n", - " model=model,\n", - " layer_id=11\n", - " )\n", - " token_list.append([model.to_string(token_id.item()) for token_id in logits.argmax(dim=-1).squeeze()])\n", - " value_list.append([value for value in torch.max(logits.softmax(dim=-1), dim=-1)[0].detach().squeeze().numpy()])" - ] - }, - { - "cell_type": "code", - "execution_count": 109, - "metadata": {}, - "outputs": [], - "source": [ - "token_list = np.array(token_list[::-1])\n", - "value_list = np.array(value_list[::-1])" - ] - }, - { - "cell_type": "code", - "execution_count": 110, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Helper Funcs\n", + "\n", + "A helper function to plot logit lens" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.graph_objects as go\n", + "import numpy as np\n", + "\n", + "# Parameters\n", + "num_layers = 5\n", + "seq_len = 10\n", + "\n", + "# Create a matrix of tokens for demonstration\n", + "tokens = np.array([[\"token_{}_{}\".format(i, j) for j in range(seq_len)] for i in range(num_layers)])[::-1]\n", + "values = np.random.rand(num_layers, seq_len)\n", + "orig_tokens = ['Token {}'.format(i) for i in range(seq_len)]\n", + "\n", + "def draw_logit_lens(num_layers, seq_len, orig_tokens, tokens, values):\n", + " # Create the heatmap\n", + " fig = go.Figure(data=go.Heatmap(\n", + " z=values,\n", + " x=orig_tokens,\n", + " y=['Layer {}'.format(i) for i in range(num_layers)][::-1],\n", + " colorscale='Blues',\n", + " showscale=True,\n", + " colorbar=dict(title='Value')\n", + " ))\n", + "\n", + " # Add text annotations\n", + " annotations = []\n", + " for i in range(num_layers):\n", + " for j in range(seq_len):\n", + " annotations.append(\n", + " dict(\n", + " x=j, y=i,\n", + " text=tokens[i, j],\n", + " showarrow=False,\n", + " font=dict(color='white')\n", + " )\n", + " )\n", + "\n", + " fig.update_layout(\n", + " annotations=annotations,\n", + " xaxis=dict(side='top'),\n", + " yaxis=dict(autorange='reversed'),\n", + " margin=dict(l=50, r=50, t=100, b=50),\n", + " width=1000,\n", + " height=600,\n", + " plot_bgcolor='white'\n", + " )\n", + "\n", + " # Show the plot\n", + " fig.show()\n", + "# draw_logit_lens(num_layers, seq_len, orig_tokens, tokens, values)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Preparation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "# I'm using an M2 macbook air, so I use CPU for better support\n", + "model = TransformerBridge.boot_transformers(\"gpt2\", device=\"cpu\")\n", + "model.enable_compatibility_mode()\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Patchscopes Definition\n", + "\n", + "Here we first wirte down the formal definition decribed in the paper https://arxiv.org/pdf/2401.06102.\n", + "\n", + "The representations are:\n", + "\n", + "source: (S, i, M, l), where S is the source prompt, i is the source position, M is the source model, and l is the source layer.\n", + "\n", + "target: (T,i*,f,M*,l*), where T is the target prompt, i* is the target position, M* is the target model, l* is the target layer, and f is the mapping function that takes the original hidden states as input and output the target hidden states\n", + "\n", + "By defulat, S = T, i = i*, M = M*, l = l*, f = identity function" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Patchscopes Pipeline\n", + "\n", + "### Get hidden representation from the source model\n", + "\n", + "1. We first need to extract the source hidden states from model M at position i of layer l with prompt S. In TransformerLens, we can do this using run_with_cache.\n", + "2. Then, we map the source representation with a function f, and feed the hidden representation to the target position using a hook. Specifically, we focus on residual stream (resid_post), whereas you can manipulate more fine-grainedly with TransformerLens\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "prompts = [\"Patchscopes is a nice tool to inspect hidden representation of language model\"]\n", + "input_tokens = model.to_tokens(prompts)\n", + "clean_logits, clean_cache = model.run_with_cache(input_tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_source_representation(prompts: List[str], layer_id: int, model: TransformerBridge, pos_id: Union[int, List[int]]=None) -> torch.Tensor:\n", + " \"\"\"Get source hidden representation represented by (S, i, M, l)\n", + " \n", + " Args:\n", + " - prompts (List[str]): a list of source prompts\n", + " - layer_id (int): the layer id of the model\n", + " - model (TransformerBridge): the source model\n", + " - pos_id (Union[int, List[int]]): the position id(s) of the model, if None, return all positions\n", + "\n", + " Returns:\n", + " - source_rep (torch.Tensor): the source hidden representation\n", + " \"\"\"\n", + " input_tokens = model.to_tokens(prompts)\n", + " _, cache = model.run_with_cache(input_tokens)\n", + " layer_name = \"blocks.{id}.hook_resid_post\"\n", + " layer_name = layer_name.format(id=layer_id)\n", + " if pos_id is None:\n", + " return cache[layer_name][:, :, :]\n", + " else:\n", + " return cache[layer_name][:, pos_id, :]" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "source_rep = get_source_representation(\n", + " prompts=[\"Patchscopes is a nice tool to inspect hidden representation of language model\"],\n", + " layer_id=2,\n", + " model=model,\n", + " pos_id=5\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Feed the representation to the target position\n", + "\n", + "First we need to map the representation using mapping function f, and then feed the target representation to the target position represented by (T,i*,f,M*,l*)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "# here we use an identity function for demonstration purposes\n", + "def identity_function(source_rep: torch.Tensor) -> torch.Tensor:\n", + " return source_rep" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# recall the target representation (T,i*,f,M*,l*), and we also need the hidden representation from our source model (S, i, M, l)\n", + "def feed_source_representation(source_rep: torch.Tensor, prompt: List[str], f: Callable, model: TransformerBridge, layer_id: int, pos_id: Union[int, List[int]]=None) -> ActivationCache:\n", + " \"\"\"Feed the source hidden representation to the target model\n", + " \n", + " Args:\n", + " - source_rep (torch.Tensor): the source hidden representation\n", + " - prompt (List[str]): the target prompt\n", + " - f (Callable): the mapping function\n", + " - model (TransformerBridge): the target model\n", + " - layer_id (int): the layer id of the target model\n", + " - pos_id (Union[int, List[int]]): the position id(s) of the target model, if None, return all positions\n", + " \"\"\"\n", + " mapped_rep = f(source_rep)\n", + " # similar to what we did for activation patching, we need to define a function to patch the hidden representation\n", + " def resid_ablation_hook(\n", + " value: Float[torch.Tensor, \"batch pos d_resid\"],\n", + " hook: HookPoint\n", + " ) -> Float[torch.Tensor, \"batch pos d_resid\"]:\n", + " # print(f\"Shape of the value tensor: {value.shape}\")\n", + " # print(f\"Shape of the hidden representation at the target position: {value[:, pos_id, :].shape}\")\n", + " value[:, pos_id, :] = mapped_rep\n", + " return value\n", + " \n", + " input_tokens = model.to_tokens(prompt)\n", + "\n", + " logits = model.run_with_hooks(\n", + " input_tokens,\n", + " return_type=\"logits\",\n", + " fwd_hooks=[(\n", + " utils.get_act_name(\"resid_post\", layer_id),\n", + " resid_ablation_hook\n", + " )]\n", + " )\n", + " \n", + " return logits" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "patched_logits = feed_source_representation(\n", + " source_rep=source_rep,\n", + " prompt=prompts,\n", + " pos_id=3,\n", + " f=identity_function,\n", + " model=model,\n", + " layer_id=2\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([[ 3.5811, 3.5322, 2.6463, ..., -4.3504, -1.7939, 3.3541]],\n", + " grad_fn=),\n", + " tensor([[ 3.2431, 3.2708, 1.9591, ..., -4.2666, -2.2141, 3.4965]],\n", + " grad_fn=))" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "clean_logits[:, 5], patched_logits[:, 5]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generation with Patching\n", + "\n", + "In the last step, we've implemented the basic version of Patchscopes where we can only run one single forward pass. Let's now unlock the power by allowing it to generate multiple tokens!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_with_patching(model: TransformerBridge, prompts: List[str], target_f: Callable, max_new_tokens: int = 50):\n", + " temp_prompts = prompts\n", + " input_tokens = model.to_tokens(temp_prompts)\n", + " for _ in range(max_new_tokens):\n", + " logits = target_f(\n", + " prompt=temp_prompts,\n", + " )\n", + " next_tok = torch.argmax(logits[:, -1, :])\n", + " input_tokens = torch.cat((input_tokens, next_tok.view(input_tokens.size(0), 1)), dim=1)\n", + " temp_prompts = model.to_string(input_tokens)\n", + "\n", + " return model.to_string(input_tokens)[0]" + ] + }, { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ { - "colorbar": { - "title": { - "text": "Value" - } - }, - "colorscale": [ - [ - 0, - "rgb(247,251,255)" - ], - [ - 0.125, - "rgb(222,235,247)" - ], - [ - 0.25, - "rgb(198,219,239)" - ], - [ - 0.375, - "rgb(158,202,225)" - ], - [ - 0.5, - "rgb(107,174,214)" - ], - [ - 0.625, - "rgb(66,146,198)" - ], - [ - 0.75, - "rgb(33,113,181)" - ], - [ - 0.875, - "rgb(8,81,156)" - ], - [ - 1, - "rgb(8,48,107)" + "name": "stdout", + "output_type": "stream", + "text": [ + "<|endoftext|>Patchscopes is a nice tool to inspect hidden representation of language model file bit file\n" ] - ], - "showscale": true, - "type": "heatmap", - "x": [ - "<|endoftext|>", - "Patch", - "sc", - "opes", - " is", - " a", - " nice", - " tool", - " to", - " inspect", - " hidden", - " representation", - " of", - " language", - " model" - ], - "y": [ - "Layer 11", - "Layer 10", - "Layer 9", - "Layer 8", - "Layer 7", - "Layer 6", - "Layer 5", - "Layer 4", - "Layer 3", - "Layer 2", - "Layer 1", - "Layer 0" - ], - "z": [ - [ - 0.34442219138145447, - 0.9871702790260315, - 0.3734475076198578, - 0.9830440878868103, - 0.4042338728904724, - 0.09035539627075195, - 0.8022230863571167, - 0.5206465125083923, - 0.14175501465797424, - 0.9898471236228943, - 0.9606538414955139, - 0.9691148996353149, - 0.662227988243103, - 0.9815096855163574, - 0.9055094718933105 - ], - [ - 0.08009976148605347, - 0.99101722240448, - 0.45667293667793274, - 0.40307697653770447, - 0.49327367544174194, - 0.08549172431230545, - 0.7428992390632629, - 0.8611035943031311, - 0.1983162760734558, - 0.9246276021003723, - 0.8956946730613708, - 0.8638046383857727, - 0.8365117311477661, - 0.9618501663208008, - 0.9175702333450317 - ], - [ - 0.02691030502319336, - 0.9732530117034912, - 0.19330987334251404, - 0.381843239068985, - 0.33808818459510803, - 0.07934993505477905, - 0.3974476158618927, - 0.7191767692565918, - 0.24212224781513214, - 0.7858667373657227, - 0.866357684135437, - 0.6622256636619568, - 0.8740373849868774, - 0.947133481502533, - 0.8450764417648315 - ], - [ - 0.027061497792601585, - 0.9609430432319641, - 0.2772334814071655, - 0.20079827308654785, - 0.2932577431201935, - 0.1255684345960617, - 0.32114332914352417, - 0.6489707827568054, - 0.2919656038284302, - 0.18173590302467346, - 0.635391891002655, - 0.5701303482055664, - 0.8785448670387268, - 0.8575655221939087, - 0.6919821500778198 - ], - [ - 0.026887305080890656, - 0.9309146404266357, - 0.44758421182632446, - 0.24046003818511963, - 0.28474941849708557, - 0.20104897022247314, - 0.5028793811798096, - 0.48273345828056335, - 0.2584459185600281, - 0.36538586020469666, - 0.20586784183979034, - 0.3072110712528229, - 0.9045845866203308, - 0.5042338371276855, - 0.4879302978515625 - ], - [ - 0.0265483595430851, - 0.9315882921218872, - 0.41395631432533264, - 0.2468952238559723, - 0.35624295473098755, - 0.21814416348934174, - 0.6175792813301086, - 0.7821283340454102, - 0.28484007716178894, - 0.3186572194099426, - 0.16824035346508026, - 0.5927833914756775, - 0.8808191418647766, - 0.5171196460723877, - 0.2029583901166916 - ], - [ - 0.026423994451761246, - 0.898944079875946, - 0.32038140296936035, - 0.44839850068092346, - 0.2796024978160858, - 0.20586445927619934, - 0.6313580274581909, - 0.87591552734375, - 0.18971839547157288, - 0.3038368225097656, - 0.36893585324287415, - 0.5965255498886108, - 0.7505314946174622, - 0.5989011526107788, - 0.10610682517290115 - ], - [ - 0.026437079533934593, - 0.6845366358757019, - 0.3912840485572815, - 0.37950050830841064, - 0.5224342346191406, - 0.2038283497095108, - 0.3475077748298645, - 0.647609293460846, - 0.11305152624845505, - 0.4017726182937622, - 0.4405157268047333, - 0.533568799495697, - 0.5206188559532166, - 0.2670389711856842, - 0.08740855008363724 - ], - [ - 0.026673221960663795, - 0.36045604944229126, - 0.27727553248405457, - 0.4515568017959595, - 0.5681671500205994, - 0.36901071667671204, - 0.5300043821334839, - 0.494934618473053, - 0.3656132221221924, - 0.40456005930900574, - 0.2656775712966919, - 0.2756248712539673, - 0.517121434211731, - 0.3028433322906494, - 0.09847757965326309 - ], - [ - 0.026949577033519745, - 0.3112040162086487, - 0.22643150389194489, - 0.7095355987548828, - 0.5966493487358093, - 0.4613777995109558, - 0.8436885476112366, - 0.4194002151489258, - 0.22365105152130127, - 0.4558623731136322, - 0.32150164246559143, - 0.4018287658691406, - 0.8275868892669678, - 0.3780366778373718, - 0.19973652064800262 - ], - [ - 0.027445374056696892, - 0.3283821940422058, - 0.5192154049873352, - 0.1790430098772049, - 0.6429017782211304, - 0.3577035665512085, - 0.6037949919700623, - 0.5884966254234314, - 0.18566730618476868, - 0.3142710030078888, - 0.15301460027694702, - 0.3585647940635681, - 0.4576294720172882, - 0.1486930102109909, - 0.13506801426410675 - ], - [ - 0.062298569828271866, - 0.24093002080917358, - 0.16585318744182587, - 0.16210544109344482, - 0.449150949716568, - 0.042253680527210236, - 0.11057071387767792, - 0.3447357416152954, - 0.08157400786876678, - 0.13642098009586334, - 0.07241284847259521, - 0.25115686655044556, - 0.084745854139328, - 0.0951341837644577, - 0.1267273873090744 + } + ], + "source": [ + "prompts = [\"Patchscopes is a nice tool to inspect hidden representation of language model\"]\n", + "input_tokens = model.to_tokens(prompts)\n", + "target_f = partial(\n", + " feed_source_representation,\n", + " source_rep=source_rep,\n", + " pos_id=-1,\n", + " f=identity_function,\n", + " model=model,\n", + " layer_id=2\n", + ")\n", + "gen = generate_with_patching(model, prompts, target_f, max_new_tokens=3)\n", + "print(gen)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Patchscopes is a nice tool to inspect hidden representation of language model.\n", + "\n", + "It is a simple tool to inspect hidden representation of language model.\n", + "\n", + "It is a simple tool to inspect hidden representation of language model.\n", + "\n", + "It is a simple tool to inspect hidden representation of language model.\n", + "\n", + "It is\n" ] - ] } - ], - "layout": { - "annotations": [ - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "\n", - "x": 0, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " Patch", - "x": 1, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "rawl", - "x": 2, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "opes", - "x": 3, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " not", - "x": 4, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " new", - "x": 5, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " nice", - "x": 6, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "tips", - "x": 7, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " get", - "x": 8, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " inspect", - "x": 9, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " hidden", - "x": 10, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " representation", - "x": 11, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " the", - "x": 12, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " language", - "x": 13, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " model", - "x": 14, - "y": 0 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "\n", - "x": 0, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " Patch", - "x": 1, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "urry", - "x": 2, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " Operator", - "x": 3, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " not", - "x": 4, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " new", - "x": 5, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " nice", - "x": 6, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "tips", - "x": 7, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " get", - "x": 8, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " inspect", - "x": 9, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " hidden", - "x": 10, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " representation", - "x": 11, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " the", - "x": 12, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " language", - "x": 13, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " model", - "x": 14, - "y": 1 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ",", - "x": 0, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " Patch", - "x": 1, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "rawl", - "x": 2, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " Operator", - "x": 3, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " not", - "x": 4, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " new", - "x": 5, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " nice", - "x": 6, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "tips", - "x": 7, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " keep", - "x": 8, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " inspect", - "x": 9, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " hidden", - "x": 10, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " representation", - "x": 11, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " the", - "x": 12, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " language", - "x": 13, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " model", - "x": 14, - "y": 2 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ",", - "x": 0, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " Patch", - "x": 1, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "atch", - "x": 2, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ":", - "x": 3, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " currently", - "x": 4, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " very", - "x": 5, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " thing", - "x": 6, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "tips", - "x": 7, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " keep", - "x": 8, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " the", - "x": 9, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " hidden", - "x": 10, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " representation", - "x": 11, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " the", - "x": 12, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " language", - "x": 13, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " model", - "x": 14, - "y": 3 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ",", - "x": 0, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " Patch", - "x": 1, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "atch", - "x": 2, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ":", - "x": 3, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " currently", - "x": 4, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " unique", - "x": 5, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " little", - "x": 6, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " tool", - "x": 7, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " keep", - "x": 8, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " your", - "x": 9, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " hidden", - "x": 10, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " of", - "x": 11, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " the", - "x": 12, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " language", - "x": 13, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " model", - "x": 14, - "y": 4 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ",", - "x": 0, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " Patch", - "x": 1, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "reens", - "x": 2, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ":", - "x": 3, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " currently", - "x": 4, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " unique", - "x": 5, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " little", - "x": 6, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " tool", - "x": 7, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " keep", - "x": 8, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " your", - "x": 9, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " gem", - "x": 10, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " of", - "x": 11, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " the", - "x": 12, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " language", - "x": 13, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " model", - "x": 14, - "y": 5 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ",", - "x": 0, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " Patch", - "x": 1, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "ree", - "x": 2, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ":", - "x": 3, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " also", - "x": 4, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " unique", - "x": 5, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " little", - "x": 6, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " tool", - "x": 7, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " keep", - "x": 8, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " the", - "x": 9, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " gems", - "x": 10, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " of", - "x": 11, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " the", - "x": 12, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " language", - "x": 13, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " (", - "x": 14, - "y": 6 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ",", - "x": 0, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " Patch", - "x": 1, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "ream", - "x": 2, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ":", - "x": 3, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " currently", - "x": 4, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " powerful", - "x": 5, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " little", - "x": 6, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " tool", - "x": 7, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " keep", - "x": 8, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " your", - "x": 9, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " gems", - "x": 10, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " of", - "x": 11, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " the", - "x": 12, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " language", - "x": 13, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " (", - "x": 14, - "y": 7 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ",", - "x": 0, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "work", - "x": 1, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "ream", - "x": 2, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "¶", - "x": 3, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " currently", - "x": 4, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " powerful", - "x": 5, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " tool", - "x": 6, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "kit", - "x": 7, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " help", - "x": 8, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " your", - "x": 9, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " objects", - "x": 10, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " of", - "x": 11, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " objects", - "x": 12, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " objects", - "x": 13, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " data", - "x": 14, - "y": 8 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ",", - "x": 0, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "work", - "x": 1, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "ream", - "x": 2, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "¶", - "x": 3, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " a", - "x": 4, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " tool", - "x": 5, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " tool", - "x": 6, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " tool", - "x": 7, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " help", - "x": 8, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " your", - "x": 9, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " objects", - "x": 10, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " of", - "x": 11, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " objects", - "x": 12, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " strings", - "x": 13, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " variables", - "x": 14, - "y": 9 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ",", - "x": 0, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " Notes", - "x": 1, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "rew", - "x": 2, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "¶", - "x": 3, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " a", - "x": 4, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " tool", - "x": 5, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " tool", - "x": 6, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " for", - "x": 7, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " help", - "x": 8, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " your", - "x": 9, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " items", - "x": 10, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " of", - "x": 11, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " objects", - "x": 12, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " objects", - "x": 13, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " objects", - "x": 14, - "y": 10 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "\n", - "x": 0, - "y": 11 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " Notes", - "x": 1, - "y": 11 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "rew", - "x": 2, - "y": 11 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": "\n", - "x": 3, - "y": 11 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " a", - "x": 4, - "y": 11 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " tool", - "x": 5, - "y": 11 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " tool", - "x": 6, - "y": 11 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " for", - "x": 7, - "y": 11 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " help", - "x": 8, - "y": 11 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " the", - "x": 9, - "y": 11 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " files", - "x": 10, - "y": 11 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " of", - "x": 11, - "y": 11 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " the", - "x": 12, - "y": 11 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": " features", - "x": 13, - "y": 11 - }, - { - "font": { - "color": "white" - }, - "showarrow": false, - "text": ".", - "x": 14, - "y": 11 - } - ], - "height": 600, - "margin": { - "b": 50, - "l": 50, - "r": 50, - "t": 100 - }, - "plot_bgcolor": "white", - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } + ], + "source": [ + "# Original generation\n", + "print(model.generate(prompts[0], verbose=False, max_new_tokens=50, do_sample=False))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Application Examples" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Logit Lens\n", + "\n", + "For Logit Lens, the configuration is l* ← L*. Here, L* is the last layer." + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "metadata": {}, + "outputs": [], + "source": [ + "token_list = []\n", + "value_list = []\n", + "\n", + "def identity_function(source_rep: torch.Tensor) -> torch.Tensor:\n", + " return source_rep\n", + "\n", + "for source_layer_id in range(12):\n", + " # Prepare source representation\n", + " source_rep = get_source_representation(\n", + " prompts=[\"Patchscopes is a nice tool to inspect hidden representation of language model\"],\n", + " layer_id=source_layer_id,\n", + " model=model,\n", + " pos_id=None\n", + " )\n", + "\n", + " logits = feed_source_representation(\n", + " source_rep=source_rep,\n", + " prompt=[\"Patchscopes is a nice tool to inspect hidden representation of language model\"],\n", + " f=identity_function,\n", + " model=model,\n", + " layer_id=11\n", + " )\n", + " token_list.append([model.to_string(token_id.item()) for token_id in logits.argmax(dim=-1).squeeze()])\n", + " value_list.append([value for value in torch.max(logits.softmax(dim=-1), dim=-1)[0].detach().squeeze().numpy()])" + ] + }, + { + "cell_type": "code", + "execution_count": 109, + "metadata": {}, + "outputs": [], + "source": [ + "token_list = np.array(token_list[::-1])\n", + "value_list = np.array(value_list[::-1])" + ] + }, + { + "cell_type": "code", + "execution_count": 110, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "colorbar": { + "title": { + "text": "Value" + } + }, + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "showscale": true, + "type": "heatmap", + "x": [ + "<|endoftext|>", + "Patch", + "sc", + "opes", + " is", + " a", + " nice", + " tool", + " to", + " inspect", + " hidden", + " representation", + " of", + " language", + " model" + ], + "y": [ + "Layer 11", + "Layer 10", + "Layer 9", + "Layer 8", + "Layer 7", + "Layer 6", + "Layer 5", + "Layer 4", + "Layer 3", + "Layer 2", + "Layer 1", + "Layer 0" + ], + "z": [ + [ + 0.34442219138145447, + 0.9871702790260315, + 0.3734475076198578, + 0.9830440878868103, + 0.4042338728904724, + 0.09035539627075195, + 0.8022230863571167, + 0.5206465125083923, + 0.14175501465797424, + 0.9898471236228943, + 0.9606538414955139, + 0.9691148996353149, + 0.662227988243103, + 0.9815096855163574, + 0.9055094718933105 + ], + [ + 0.08009976148605347, + 0.99101722240448, + 0.45667293667793274, + 0.40307697653770447, + 0.49327367544174194, + 0.08549172431230545, + 0.7428992390632629, + 0.8611035943031311, + 0.1983162760734558, + 0.9246276021003723, + 0.8956946730613708, + 0.8638046383857727, + 0.8365117311477661, + 0.9618501663208008, + 0.9175702333450317 + ], + [ + 0.02691030502319336, + 0.9732530117034912, + 0.19330987334251404, + 0.381843239068985, + 0.33808818459510803, + 0.07934993505477905, + 0.3974476158618927, + 0.7191767692565918, + 0.24212224781513214, + 0.7858667373657227, + 0.866357684135437, + 0.6622256636619568, + 0.8740373849868774, + 0.947133481502533, + 0.8450764417648315 + ], + [ + 0.027061497792601585, + 0.9609430432319641, + 0.2772334814071655, + 0.20079827308654785, + 0.2932577431201935, + 0.1255684345960617, + 0.32114332914352417, + 0.6489707827568054, + 0.2919656038284302, + 0.18173590302467346, + 0.635391891002655, + 0.5701303482055664, + 0.8785448670387268, + 0.8575655221939087, + 0.6919821500778198 + ], + [ + 0.026887305080890656, + 0.9309146404266357, + 0.44758421182632446, + 0.24046003818511963, + 0.28474941849708557, + 0.20104897022247314, + 0.5028793811798096, + 0.48273345828056335, + 0.2584459185600281, + 0.36538586020469666, + 0.20586784183979034, + 0.3072110712528229, + 0.9045845866203308, + 0.5042338371276855, + 0.4879302978515625 + ], + [ + 0.0265483595430851, + 0.9315882921218872, + 0.41395631432533264, + 0.2468952238559723, + 0.35624295473098755, + 0.21814416348934174, + 0.6175792813301086, + 0.7821283340454102, + 0.28484007716178894, + 0.3186572194099426, + 0.16824035346508026, + 0.5927833914756775, + 0.8808191418647766, + 0.5171196460723877, + 0.2029583901166916 + ], + [ + 0.026423994451761246, + 0.898944079875946, + 0.32038140296936035, + 0.44839850068092346, + 0.2796024978160858, + 0.20586445927619934, + 0.6313580274581909, + 0.87591552734375, + 0.18971839547157288, + 0.3038368225097656, + 0.36893585324287415, + 0.5965255498886108, + 0.7505314946174622, + 0.5989011526107788, + 0.10610682517290115 + ], + [ + 0.026437079533934593, + 0.6845366358757019, + 0.3912840485572815, + 0.37950050830841064, + 0.5224342346191406, + 0.2038283497095108, + 0.3475077748298645, + 0.647609293460846, + 0.11305152624845505, + 0.4017726182937622, + 0.4405157268047333, + 0.533568799495697, + 0.5206188559532166, + 0.2670389711856842, + 0.08740855008363724 + ], + [ + 0.026673221960663795, + 0.36045604944229126, + 0.27727553248405457, + 0.4515568017959595, + 0.5681671500205994, + 0.36901071667671204, + 0.5300043821334839, + 0.494934618473053, + 0.3656132221221924, + 0.40456005930900574, + 0.2656775712966919, + 0.2756248712539673, + 0.517121434211731, + 0.3028433322906494, + 0.09847757965326309 + ], + [ + 0.026949577033519745, + 0.3112040162086487, + 0.22643150389194489, + 0.7095355987548828, + 0.5966493487358093, + 0.4613777995109558, + 0.8436885476112366, + 0.4194002151489258, + 0.22365105152130127, + 0.4558623731136322, + 0.32150164246559143, + 0.4018287658691406, + 0.8275868892669678, + 0.3780366778373718, + 0.19973652064800262 + ], + [ + 0.027445374056696892, + 0.3283821940422058, + 0.5192154049873352, + 0.1790430098772049, + 0.6429017782211304, + 0.3577035665512085, + 0.6037949919700623, + 0.5884966254234314, + 0.18566730618476868, + 0.3142710030078888, + 0.15301460027694702, + 0.3585647940635681, + 0.4576294720172882, + 0.1486930102109909, + 0.13506801426410675 + ], + [ + 0.062298569828271866, + 0.24093002080917358, + 0.16585318744182587, + 0.16210544109344482, + 0.449150949716568, + 0.042253680527210236, + 0.11057071387767792, + 0.3447357416152954, + 0.08157400786876678, + 0.13642098009586334, + 0.07241284847259521, + 0.25115686655044556, + 0.084745854139328, + 0.0951341837644577, + 0.1267273873090744 + ] + ] + } + ], + "layout": { + "annotations": [ + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "\n", + "x": 0, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "rawl", + "x": 2, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "opes", + "x": 3, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " not", + "x": 4, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " new", + "x": 5, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " nice", + "x": 6, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "tips", + "x": 7, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " get", + "x": 8, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " inspect", + "x": 9, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " hidden", + "x": 10, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " representation", + "x": 11, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " model", + "x": 14, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "\n", + "x": 0, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "urry", + "x": 2, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Operator", + "x": 3, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " not", + "x": 4, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " new", + "x": 5, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " nice", + "x": 6, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "tips", + "x": 7, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " get", + "x": 8, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " inspect", + "x": 9, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " hidden", + "x": 10, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " representation", + "x": 11, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " model", + "x": 14, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "rawl", + "x": 2, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Operator", + "x": 3, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " not", + "x": 4, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " new", + "x": 5, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " nice", + "x": 6, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "tips", + "x": 7, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " keep", + "x": 8, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " inspect", + "x": 9, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " hidden", + "x": 10, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " representation", + "x": 11, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " model", + "x": 14, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "atch", + "x": 2, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ":", + "x": 3, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " currently", + "x": 4, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " very", + "x": 5, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " thing", + "x": 6, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "tips", + "x": 7, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " keep", + "x": 8, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 9, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " hidden", + "x": 10, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " representation", + "x": 11, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " model", + "x": 14, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "atch", + "x": 2, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ":", + "x": 3, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " currently", + "x": 4, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " unique", + "x": 5, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " little", + "x": 6, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 7, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " keep", + "x": 8, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " your", + "x": 9, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " hidden", + "x": 10, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " model", + "x": 14, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "reens", + "x": 2, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ":", + "x": 3, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " currently", + "x": 4, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " unique", + "x": 5, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " little", + "x": 6, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 7, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " keep", + "x": 8, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " your", + "x": 9, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " gem", + "x": 10, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " model", + "x": 14, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "ree", + "x": 2, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ":", + "x": 3, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " also", + "x": 4, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " unique", + "x": 5, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " little", + "x": 6, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 7, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " keep", + "x": 8, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 9, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " gems", + "x": 10, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " (", + "x": 14, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "ream", + "x": 2, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ":", + "x": 3, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " currently", + "x": 4, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " powerful", + "x": 5, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " little", + "x": 6, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 7, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " keep", + "x": 8, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " your", + "x": 9, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " gems", + "x": 10, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " (", + "x": 14, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "work", + "x": 1, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "ream", + "x": 2, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "¶", + "x": 3, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " currently", + "x": 4, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " powerful", + "x": 5, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 6, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "kit", + "x": 7, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " help", + "x": 8, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " your", + "x": 9, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 10, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 12, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 13, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " data", + "x": 14, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "work", + "x": 1, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "ream", + "x": 2, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "¶", + "x": 3, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " a", + "x": 4, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 5, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 6, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 7, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " help", + "x": 8, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " your", + "x": 9, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 10, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 12, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " strings", + "x": 13, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " variables", + "x": 14, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Notes", + "x": 1, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "rew", + "x": 2, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "¶", + "x": 3, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " a", + "x": 4, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 5, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 6, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " for", + "x": 7, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " help", + "x": 8, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " your", + "x": 9, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " items", + "x": 10, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 12, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 13, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 14, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "\n", + "x": 0, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Notes", + "x": 1, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "rew", + "x": 2, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "\n", + "x": 3, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " a", + "x": 4, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 5, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 6, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " for", + "x": 7, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " help", + "x": 8, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 9, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " files", + "x": 10, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " features", + "x": 13, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ".", + "x": 14, + "y": 11 + } + ], + "height": 600, + "margin": { + "b": 50, + "l": 50, + "r": 50, + "t": 100 + }, + "plot_bgcolor": "white", + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "width": 1000, + "xaxis": { + "side": "top" + }, + "yaxis": { + "autorange": "reversed" + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "num_layers = 12\n", + "seq_len = len(token_list[0])\n", + "orig_tokens = [model.to_string(token_id) for token_id in model.to_tokens([\"Patchscopes is a nice tool to inspect hidden representation of language model\"])[0]]\n", + "draw_logit_lens(num_layers, seq_len, orig_tokens, token_list, value_list)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Entity Description\n", + "\n", + "Entity description tries to answer \"how LLMs resolve entity mentions across multiple layers. Concretely, given a subject entity name, such as “the summer Olympics of 1996”, how does the model contextualize the input tokens of the entity and at which layer is it fully resolved?\"\n", + "\n", + "The configuration is l* ← l, i* ← m, and it requires generating multiple tokens. Here m refers to the last position (the position of x)" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "metadata": {}, + "outputs": [], + "source": [ + " # Prepare source representation\n", + "source_rep = get_source_representation(\n", + " prompts=[\"Diana, Princess of Wales\"],\n", + " layer_id=11,\n", + " model=model,\n", + " pos_id=-1\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 115, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generation by patching layer 0:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The \"The \"The \"The \"The \"The \"The \"The \"The\n", + "==============================\n", + "\n", + "Generation by patching layer 1:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The \"The \"The \"The \"The \"The \"The \"The \"The\n", + "==============================\n", + "\n", + "Generation by patching layer 2:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The\n", + "The\n", + "\n", + "\n", + "The\n", + "The\n", + "The\n", + "\n", + "\n", + "The\n", + "The\n", + "==============================\n", + "\n", + "Generation by patching layer 3:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The\n", + "\n", + "\n", + "The\n", + "\n", + "\n", + "The\n", + "\n", + "\n", + "The\n", + "\n", + "\n", + "The\n", + "==============================\n", + "\n", + "Generation by patching layer 4:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The United States\n", + "\n", + "\n", + "The United States\n", + "\n", + "\n", + "The United States\n", + "\n", + "\n", + "==============================\n", + "\n", + "Generation by patching layer 5:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The United States\n", + "\n", + "\n", + "The United States\n", + "\n", + "\n", + "The United States\n", + "\n", + "\n", + "==============================\n", + "\n", + "Generation by patching layer 6:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The United States is the world's most popular and the world's most beautiful.\n", + "\n", + "==============================\n", + "\n", + "Generation by patching layer 7:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The United States is the world's most popular and most beautiful country.\n", + "\n", + "\n", + "\n", + "==============================\n", + "\n", + "Generation by patching layer 8:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The United States is the world's largest exporter of the world's most expensive and\n", + "==============================\n", + "\n", + "Generation by patching layer 9:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The first time I saw the film, I was in the middle of a meeting with\n", + "==============================\n", + "\n", + "Generation by patching layer 10:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The world's most famous actor, actor and producer, Leonardo DiCaprio, has\n", + "==============================\n", + "\n", + "Generation by patching layer 11:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x, and the world's largest consumer electronics company, Samsung Electronics Co., Ltd.\n", + "\n", + "\n", + "The\n", + "==============================\n", + "\n" ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" + } + ], + "source": [ + "target_prompt = [\"Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\"]\n", + "# need to calcualte an absolute position, instead of a relative position\n", + "last_pos_id = len(model.to_tokens(target_prompt)[0]) - 1\n", + "# we need to define the function that takes the generation as input\n", + "for target_layer_id in range(12):\n", + " target_f = partial(\n", + " feed_source_representation,\n", + " source_rep=source_rep,\n", + " pos_id=last_pos_id,\n", + " f=identity_function,\n", + " model=model,\n", + " layer_id=target_layer_id\n", + " )\n", + " gen = generate_with_patching(model, target_prompt, target_f, max_new_tokens=20)\n", + " print(f\"Generation by patching layer {target_layer_id}:\\n{gen}\\n{'='*30}\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we can see, maybe the early layers of gpt2-small are doing something related to entity resolution, whereas the late layers are apparently not(?)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Zero-Shot Feature Extraction\n", + "\n", + "Zero-shot Feature Extraction \"Consider factual and com- monsense knowledge represented as triplets (σ,ρ,ω) of a subject (e.g., “United States”), a relation (e.g., “largest city of”), and an object (e.g.,\n", + "“New York City”). We investigate to what extent the object ω can be extracted from the last token representation of the subject σ in an arbitrary input context.\"\n", + "\n", + "The configuration is l∗ ← j′ ∈ [1,...,L∗], i∗ ← m, T ← relation verbalization followed by x" + ] + }, + { + "cell_type": "code", + "execution_count": 359, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Co-founder of company Apple, Steve Jobs, has said that Apple\\'s iPhone 6 and 6 Plus are \"the most important phones'" ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "width": 1000, - "xaxis": { - "side": "top" - }, - "yaxis": { - "autorange": "reversed" + "execution_count": 359, + "metadata": {}, + "output_type": "execute_result" } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "num_layers = 12\n", - "seq_len = len(token_list[0])\n", - "orig_tokens = [model.to_string(token_id) for token_id in model.to_tokens([\"Patchscopes is a nice tool to inspect hidden representation of language model\"])[0]]\n", - "draw_logit_lens(num_layers, seq_len, orig_tokens, token_list, value_list)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Entity Description\n", - "\n", - "Entity description tries to answer \"how LLMs resolve entity mentions across multiple layers. Concretely, given a subject entity name, such as “the summer Olympics of 1996”, how does the model contextualize the input tokens of the entity and at which layer is it fully resolved?\"\n", - "\n", - "The configuration is l* ← l, i* ← m, and it requires generating multiple tokens. Here m refers to the last position (the position of x)" - ] - }, - { - "cell_type": "code", - "execution_count": 111, - "metadata": {}, - "outputs": [], - "source": [ - " # Prepare source representation\n", - "source_rep = get_source_representation(\n", - " prompts=[\"Diana, Princess of Wales\"],\n", - " layer_id=11,\n", - " model=model,\n", - " pos_id=-1\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 115, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Generation by patching layer 0:\n", - "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", - "\n", - "\n", - "The \"The \"The \"The \"The \"The \"The \"The \"The\n", - "==============================\n", - "\n", - "Generation by patching layer 1:\n", - "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", - "\n", - "\n", - "The \"The \"The \"The \"The \"The \"The \"The \"The\n", - "==============================\n", - "\n", - "Generation by patching layer 2:\n", - "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", - "\n", - "\n", - "The\n", - "The\n", - "\n", - "\n", - "The\n", - "The\n", - "The\n", - "\n", - "\n", - "The\n", - "The\n", - "==============================\n", - "\n", - "Generation by patching layer 3:\n", - "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", - "\n", - "\n", - "The\n", - "\n", - "\n", - "The\n", - "\n", - "\n", - "The\n", - "\n", - "\n", - "The\n", - "\n", - "\n", - "The\n", - "==============================\n", - "\n", - "Generation by patching layer 4:\n", - "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", - "\n", - "\n", - "The United States\n", - "\n", - "\n", - "The United States\n", - "\n", - "\n", - "The United States\n", - "\n", - "\n", - "==============================\n", - "\n", - "Generation by patching layer 5:\n", - "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", - "\n", - "\n", - "The United States\n", - "\n", - "\n", - "The United States\n", - "\n", - "\n", - "The United States\n", - "\n", - "\n", - "==============================\n", - "\n", - "Generation by patching layer 6:\n", - "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", - "\n", - "\n", - "The United States is the world's most popular and the world's most beautiful.\n", - "\n", - "==============================\n", - "\n", - "Generation by patching layer 7:\n", - "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", - "\n", - "\n", - "The United States is the world's most popular and most beautiful country.\n", - "\n", - "\n", - "\n", - "==============================\n", - "\n", - "Generation by patching layer 8:\n", - "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", - "\n", - "\n", - "The United States is the world's largest exporter of the world's most expensive and\n", - "==============================\n", - "\n", - "Generation by patching layer 9:\n", - "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", - "\n", - "\n", - "The first time I saw the film, I was in the middle of a meeting with\n", - "==============================\n", - "\n", - "Generation by patching layer 10:\n", - "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", - "\n", - "\n", - "The world's most famous actor, actor and producer, Leonardo DiCaprio, has\n", - "==============================\n", - "\n", - "Generation by patching layer 11:\n", - "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x, and the world's largest consumer electronics company, Samsung Electronics Co., Ltd.\n", - "\n", - "\n", - "The\n", - "==============================\n", - "\n" - ] - } - ], - "source": [ - "target_prompt = [\"Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\"]\n", - "# need to calcualte an absolute position, instead of a relative position\n", - "last_pos_id = len(model.to_tokens(target_prompt)[0]) - 1\n", - "# we need to define the function that takes the generation as input\n", - "for target_layer_id in range(12):\n", - " target_f = partial(\n", - " feed_source_representation,\n", - " source_rep=source_rep,\n", - " pos_id=last_pos_id,\n", - " f=identity_function,\n", - " model=model,\n", - " layer_id=target_layer_id\n", - " )\n", - " gen = generate_with_patching(model, target_prompt, target_f, max_new_tokens=20)\n", - " print(f\"Generation by patching layer {target_layer_id}:\\n{gen}\\n{'='*30}\\n\")" - ] - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": "" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As we can see, maybe the early layers of gpt2-small are doing something related to entity resolution, whereas the late layers are apparently not(?)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Zero-Shot Feature Extraction\n", - "\n", - "Zero-shot Feature Extraction \"Consider factual and com- monsense knowledge represented as triplets (σ,ρ,ω) of a subject (e.g., “United States”), a relation (e.g., “largest city of”), and an object (e.g.,\n", - "“New York City”). We investigate to what extent the object ω can be extracted from the last token representation of the subject σ in an arbitrary input context.\"\n", - "\n", - "The configuration is l∗ ← j′ ∈ [1,...,L∗], i∗ ← m, T ← relation verbalization followed by x" - ] - }, - { - "cell_type": "code", - "execution_count": 359, - "metadata": {}, - "outputs": [ + ], + "source": [ + "# for a triplet (company Apple, co-founder of, Steve Jobs), we need to first make sure that the object is in the continuation\n", + "source_prompt = \"Co-founder of company Apple\"\n", + "model.generate(source_prompt, verbose=False, max_new_tokens=20, do_sample=False)" + ] + }, { - "data": { - "text/plain": [ - "'Co-founder of company Apple, Steve Jobs, has said that Apple\\'s iPhone 6 and 6 Plus are \"the most important phones'" + "cell_type": "code", + "execution_count": 366, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to hide\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x has been accused of being a \"fraud\" by the US government.\n", + "\n", + "\n", + "The former\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of xApple, who has been working on the iPhone since 2011, has been working on the iPhone since 2011\n", + "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", + "<|endoftext|>Co-founder of x, co-founder of Google, co-founder of Facebook, co-founder of Twitter, co\n", + "<|endoftext|>Co-founder of x, co-founder of x, co-founder of x, co-founder of x, co\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes mobile apps for the iPhone, iPad and iPod touch, says he's been\n", + "<|endoftext|>Co-founder of x, a company that makes a lot of things, has been accused of sexual harassment by a former employee\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of xApple, who has been working on the iPhone since 2011, has been working on the iPhone since 2011\n", + "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", + "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", + "<|endoftext|>Co-founder of x, co-founder of Google, co-founder of Facebook, co-founder of Twitter, co\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes software for the web, has been accused of using a \"secret\" code\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of x Apple, Steve Jobs, has been accused of being a \"fraud\" by a former employee who\n", + "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, are the first people\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes software for the iPhone, has been arrested in the US.\n", + "\n", + "\n", + "\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of xInsurance, a company that provides insurance for people with disabilities, said he's been in touch with\n", + "<|endoftext|>Co-founder of x, CEO Tim Cook, and co-founder of Facebook, Mark Zuckerberg, have been named to the\n", + "<|endoftext|>Co-founder of x, co-founder of the company Apple, and co-founder of the company Apple, and co\n", + "<|endoftext|>Co-founder of x, who has been a vocal critic of the company's recent decision to cut its workforce, has been\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to hide\n", + "<|endoftext|>Co-founder of x, a company that makes software for the iPhone, has been arrested in the US.\n", + "\n", + "\n", + "\n", + "<|endoftext|>Co-founder of x, a company that makes software for the web, has been arrested in the US.\n", + "\n", + "\n", + "\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of xInsurance, a company that provides insurance for people with disabilities, said he's been in touch with\n", + "<|endoftext|>Co-founder of x, CEO Tim Cook, and co-founder of Facebook x, Mark Zuckerberg, are among the most\n", + "<|endoftext|>Co-founder of x, who has been a vocal critic of the company's iPhone 6 and iPhone 6 Plus, has been\n", + "<|endoftext|>Co-founder of x, who has been a vocal critic of the company's recent decision to cut its workforce, has been\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes software for the iPhone, has been arrested in the US.\n", + "\n", + "\n", + "\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of x Watch, a company that helps people with disabilities, says he's been working on a new app for\n", + "<|endoftext|>Co-founder of x, CEO Tim Cook, and co-founder Steve Jobs.\n", + "\n", + "\n", + "The company's new CEO\n", + "<|endoftext|>Co-founder of x, who has been a vocal critic of the company's recent decision to cut its workforce by half,\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been named the world's most valuable person by Forbes.\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes software for the iPhone, has been arrested in the US.\n", + "\n", + "\n", + "\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of x Watch, a company that helps people with disabilities, says he's been working on a new app for\n", + "<|endoftext|>Co-founder of x, the company's new iPhone, is expected to be unveiled in the coming weeks.\n", + "\n", + "\n", + "\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been named the world's most valuable person by Forbes.\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been arrested in the US for allegedly selling a fake iPhone\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of xInsurance, a company that provides insurance for people with disabilities, said he's been in touch with\n", + "<|endoftext|>Co-founder of x, Apple x, Apple x, Apple x, Apple x, Apple x, Apple x, Apple\n", + "<|endoftext|>Co-founder of x, who is now the CEO of Apple, has been named the new CEO of the company.\n", + "\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company's new product, the\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been named the world's most valuable person by Forbes.\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the iPhone, has been arrested in the US.\n", + "\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of xbaum.com, who has been a member of the XBAY community for over a decade,\n", + "<|endoftext|>Co-founder of x, Apple x, Apple x, Apple x, Apple x, Apple x, Apple x, Apple\n", + "<|endoftext|>Co-founder of x, who is now the CEO of Apple, has been named the new CEO of the company.\n", + "\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, are both on the\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been named the world's most valuable person by Forbes.\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been arrested in the US for allegedly selling a fake iPhone\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x Inc. and co-founder of X.com, Mark Zuckerberg, has been accused of using his\n", + "<|endoftext|>Co-founder of x, Apple x, and Apple x.\n", + "Apple x, Apple x, and Apple x.\n", + "\n", + "<|endoftext|>Co-founder of x, a guest Jan 25th, 2016 1,929 Never a guest1,929Never\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, are both on the\n", + "<|endoftext|>Co-founder of x, a company that makes the iPhone, has been named the new CEO of Apple.\n", + "\n", + "\n", + "\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been arrested in the US for allegedly stealing $1 million\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x Watch, a company that helps people with disabilities, says he's been working on a new app for\n", + "<|endoftext|>Co-founder of x, the iPhone, the iPhone, the iPhone, the iPhone, the iPhone, the iPhone, the\n", + "<|endoftext|>Co-founder of x, the new, the new, the new, the new, the new, the new, the\n", + "<|endoftext|>Co-founder of x, the company, the company, the company, the company, the company, the company, the\n", + "<|endoftext|>Co-founder of x, the company, the company, the company, the company, the company, the company, the\n", + "<|endoftext|>Co-founder of x, the company has announced that the company has released the company has released the company has released the company\n", + "<|endoftext|>Co-founder of x, the company, the company, the company, the company, the company, the company, the\n", + "<|endoftext|>Co-founder of x, the company, said the company is now offering a new product, but the company has now announced\n", + "<|endoftext|>Co-founder of x, the company that has been working on the iPhone, said the company has been working on the iPhone\n", + "<|endoftext|>Co-founder of x, the company that created the iPhone, said the company is now working on a new product, but\n", + "<|endoftext|>Co-founder of x, a company that makes the iPhone, said that the company is working on a new product that will\n", + "<|endoftext|>Co-founder of x, a company that makes the world's most popular mobile phone, has been arrested in the US.\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n" + ] + } + ], + "source": [ + "# Still need an aboslute position\n", + "last_pos_id = len(model.to_tokens([\"Co-founder of x\"])[0]) - 1\n", + "target_prompt = [\"Co-founder of x\"]\n", + "\n", + "# Check all the combinations, you'll see that the model is able to generate \"Steve Jobs\" in several continuations\n", + "for source_layer_id in range(12):\n", + " # Prepare source representation, here we can use relative position\n", + " source_rep = get_source_representation(\n", + " prompts=[\"Co-founder of company Apple\"],\n", + " layer_id=source_layer_id,\n", + " model=model,\n", + " pos_id=-1\n", + " )\n", + " for target_layer_id in range(12):\n", + " target_f = partial(\n", + " feed_source_representation,\n", + " source_rep=source_rep,\n", + " prompt=target_prompt,\n", + " f=identity_function,\n", + " model=model,\n", + " pos_id=last_pos_id,\n", + " layer_id=target_layer_id\n", + " )\n", + " gen = generate_with_patching(model, target_prompt, target_f, max_new_tokens=20)\n", + " print(gen)" ] - }, - "execution_count": 359, - "metadata": {}, - "output_type": "execute_result" } - ], - "source": [ - "# for a triplet (company Apple, co-founder of, Steve Jobs), we need to first make sure that the object is in the continuation\n", - "source_prompt = \"Co-founder of company Apple\"\n", - "model.generate(source_prompt, verbose=False, max_new_tokens=20, do_sample=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 366, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", - "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to hide\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x has been accused of being a \"fraud\" by the US government.\n", - "\n", - "\n", - "The former\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", - "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", - "<|endoftext|>Co-founder of xApple, who has been working on the iPhone since 2011, has been working on the iPhone since 2011\n", - "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", - "<|endoftext|>Co-founder of x, co-founder of Google, co-founder of Facebook, co-founder of Twitter, co\n", - "<|endoftext|>Co-founder of x, co-founder of x, co-founder of x, co-founder of x, co\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes mobile apps for the iPhone, iPad and iPod touch, says he's been\n", - "<|endoftext|>Co-founder of x, a company that makes a lot of things, has been accused of sexual harassment by a former employee\n", - "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", - "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", - "<|endoftext|>Co-founder of xApple, who has been working on the iPhone since 2011, has been working on the iPhone since 2011\n", - "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", - "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", - "<|endoftext|>Co-founder of x, co-founder of Google, co-founder of Facebook, co-founder of Twitter, co\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes software for the web, has been accused of using a \"secret\" code\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", - "<|endoftext|>Co-founder of x Apple, Steve Jobs, has been accused of being a \"fraud\" by a former employee who\n", - "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", - "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, are the first people\n", - "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes software for the iPhone, has been arrested in the US.\n", - "\n", - "\n", - "\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", - "<|endoftext|>Co-founder of xInsurance, a company that provides insurance for people with disabilities, said he's been in touch with\n", - "<|endoftext|>Co-founder of x, CEO Tim Cook, and co-founder of Facebook, Mark Zuckerberg, have been named to the\n", - "<|endoftext|>Co-founder of x, co-founder of the company Apple, and co-founder of the company Apple, and co\n", - "<|endoftext|>Co-founder of x, who has been a vocal critic of the company's recent decision to cut its workforce, has been\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to hide\n", - "<|endoftext|>Co-founder of x, a company that makes software for the iPhone, has been arrested in the US.\n", - "\n", - "\n", - "\n", - "<|endoftext|>Co-founder of x, a company that makes software for the web, has been arrested in the US.\n", - "\n", - "\n", - "\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", - "<|endoftext|>Co-founder of xInsurance, a company that provides insurance for people with disabilities, said he's been in touch with\n", - "<|endoftext|>Co-founder of x, CEO Tim Cook, and co-founder of Facebook x, Mark Zuckerberg, are among the most\n", - "<|endoftext|>Co-founder of x, who has been a vocal critic of the company's iPhone 6 and iPhone 6 Plus, has been\n", - "<|endoftext|>Co-founder of x, who has been a vocal critic of the company's recent decision to cut its workforce, has been\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes software for the iPhone, has been arrested in the US.\n", - "\n", - "\n", - "\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", - "<|endoftext|>Co-founder of x Watch, a company that helps people with disabilities, says he's been working on a new app for\n", - "<|endoftext|>Co-founder of x, CEO Tim Cook, and co-founder Steve Jobs.\n", - "\n", - "\n", - "The company's new CEO\n", - "<|endoftext|>Co-founder of x, who has been a vocal critic of the company's recent decision to cut its workforce by half,\n", - "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been named the world's most valuable person by Forbes.\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes software for the iPhone, has been arrested in the US.\n", - "\n", - "\n", - "\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", - "<|endoftext|>Co-founder of x Watch, a company that helps people with disabilities, says he's been working on a new app for\n", - "<|endoftext|>Co-founder of x, the company's new iPhone, is expected to be unveiled in the coming weeks.\n", - "\n", - "\n", - "\n", - "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", - "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been named the world's most valuable person by Forbes.\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been arrested in the US for allegedly selling a fake iPhone\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", - "<|endoftext|>Co-founder of xInsurance, a company that provides insurance for people with disabilities, said he's been in touch with\n", - "<|endoftext|>Co-founder of x, Apple x, Apple x, Apple x, Apple x, Apple x, Apple x, Apple\n", - "<|endoftext|>Co-founder of x, who is now the CEO of Apple, has been named the new CEO of the company.\n", - "\n", - "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company's new product, the\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been named the world's most valuable person by Forbes.\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the iPhone, has been arrested in the US.\n", - "\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of xbaum.com, who has been a member of the XBAY community for over a decade,\n", - "<|endoftext|>Co-founder of x, Apple x, Apple x, Apple x, Apple x, Apple x, Apple x, Apple\n", - "<|endoftext|>Co-founder of x, who is now the CEO of Apple, has been named the new CEO of the company.\n", - "\n", - "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, are both on the\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been named the world's most valuable person by Forbes.\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been arrested in the US for allegedly selling a fake iPhone\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x Inc. and co-founder of X.com, Mark Zuckerberg, has been accused of using his\n", - "<|endoftext|>Co-founder of x, Apple x, and Apple x.\n", - "Apple x, Apple x, and Apple x.\n", - "\n", - "<|endoftext|>Co-founder of x, a guest Jan 25th, 2016 1,929 Never a guest1,929Never\n", - "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, are both on the\n", - "<|endoftext|>Co-founder of x, a company that makes the iPhone, has been named the new CEO of Apple.\n", - "\n", - "\n", - "\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", - "<|endoftext|>Co-founder of x, a company that makes smart phones, has been arrested in the US for allegedly stealing $1 million\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", - "<|endoftext|>Co-founder of x Watch, a company that helps people with disabilities, says he's been working on a new app for\n", - "<|endoftext|>Co-founder of x, the iPhone, the iPhone, the iPhone, the iPhone, the iPhone, the iPhone, the\n", - "<|endoftext|>Co-founder of x, the new, the new, the new, the new, the new, the new, the\n", - "<|endoftext|>Co-founder of x, the company, the company, the company, the company, the company, the company, the\n", - "<|endoftext|>Co-founder of x, the company, the company, the company, the company, the company, the company, the\n", - "<|endoftext|>Co-founder of x, the company has announced that the company has released the company has released the company has released the company\n", - "<|endoftext|>Co-founder of x, the company, the company, the company, the company, the company, the company, the\n", - "<|endoftext|>Co-founder of x, the company, said the company is now offering a new product, but the company has now announced\n", - "<|endoftext|>Co-founder of x, the company that has been working on the iPhone, said the company has been working on the iPhone\n", - "<|endoftext|>Co-founder of x, the company that created the iPhone, said the company is now working on a new product, but\n", - "<|endoftext|>Co-founder of x, a company that makes the iPhone, said that the company is working on a new product that will\n", - "<|endoftext|>Co-founder of x, a company that makes the world's most popular mobile phone, has been arrested in the US.\n", - "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n" - ] + ], + "metadata": { + "kernelspec": { + "display_name": "mechinterp", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" } - ], - "source": [ - "# Still need an aboslute position\n", - "last_pos_id = len(model.to_tokens([\"Co-founder of x\"])[0]) - 1\n", - "target_prompt = [\"Co-founder of x\"]\n", - "\n", - "# Check all the combinations, you'll see that the model is able to generate \"Steve Jobs\" in several continuations\n", - "for source_layer_id in range(12):\n", - " # Prepare source representation, here we can use relative position\n", - " source_rep = get_source_representation(\n", - " prompts=[\"Co-founder of company Apple\"],\n", - " layer_id=source_layer_id,\n", - " model=model,\n", - " pos_id=-1\n", - " )\n", - " for target_layer_id in range(12):\n", - " target_f = partial(\n", - " feed_source_representation,\n", - " source_rep=source_rep,\n", - " prompt=target_prompt,\n", - " f=identity_function,\n", - " model=model,\n", - " pos_id=last_pos_id,\n", - " layer_id=target_layer_id\n", - " )\n", - " gen = generate_with_patching(model, target_prompt, target_f, max_new_tokens=20)\n", - " print(gen)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mechinterp", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.19" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 2 } From 00c21411a8a63be24075f72e96f4294febf7a565 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Wed, 4 Mar 2026 08:32:35 -0600 Subject: [PATCH 3/3] Fixes to ensure functionality with v3.x --- demos/Patchscopes_Generation_Demo.ipynb | 118 +----------------------- 1 file changed, 4 insertions(+), 114 deletions(-) diff --git a/demos/Patchscopes_Generation_Demo.ipynb b/demos/Patchscopes_Generation_Demo.ipynb index b249d112e..8f06af4cc 100644 --- a/demos/Patchscopes_Generation_Demo.ipynb +++ b/demos/Patchscopes_Generation_Demo.ipynb @@ -30,48 +30,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", - "import os\n", - "\n", - "DEBUG_MODE = False\n", - "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", - "try:\n", - " import google.colab\n", - "\n", - " IN_COLAB = True\n", - " print(\"Running as a Colab notebook\")\n", - "except:\n", - " IN_COLAB = False\n", - " print(\"Running as a Jupyter notebook - intended for development only!\")\n", - " from IPython import get_ipython\n", - "\n", - " ipython = get_ipython()\n", - " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", - " ipython.magic(\"load_ext autoreload\")\n", - " ipython.magic(\"autoreload 2\")\n", - "\n", - "if IN_COLAB or IN_GITHUB:\n", - " %pip install transformer_lens\n", - " %pip install torchtyping\n", - " # Install my janky personal plotting utils\n", - " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", - " # Install another version of node that makes PySvelte work way faster\n", - " %pip install circuitsvis\n", - " # Needed for PySvelte to work, v3 came out and broke things...\n", - " %pip install typeguard==2.13.3\n", - "\n", - "import torch\n", - "from typing import List, Callable, Tuple, Union\n", - "from functools import partial\n", - "from jaxtyping import Float\n", - "from transformer_lens.model_bridge import TransformerBridge\n", - "from transformer_lens.ActivationCache import ActivationCache\n", - "import transformer_lens.utils as utils\n", - "from transformer_lens.hook_points import (\n", - " HookPoint,\n", - ") # Hooking utilities" - ] + "source": "# NBVAL_IGNORE_OUTPUT\n# Janky code to do different setup when run in a Colab notebook vs VSCode\nimport os\n\nDEBUG_MODE = False\nIN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\ntry:\n import google.colab\n\n IN_COLAB = True\n print(\"Running as a Colab notebook\")\nexcept:\n IN_COLAB = False\n\nif not IN_GITHUB and not IN_COLAB:\n print(\"Running as a Jupyter notebook - intended for development only!\")\n from IPython import get_ipython\n\n ipython = get_ipython()\n # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n ipython.run_line_magic(\"load_ext\", \"autoreload\")\n ipython.run_line_magic(\"autoreload\", \"2\")\n\nif IN_COLAB or IN_GITHUB:\n %pip install transformer_lens\n %pip install torchtyping\n # Install my janky personal plotting utils\n %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n # Install another version of node that makes PySvelte work way faster\n %pip install circuitsvis\n # Needed for PySvelte to work, v3 came out and broke things...\n %pip install typeguard==2.13.3\n\nimport torch\nfrom typing import List, Callable, Tuple, Union\nfrom functools import partial\nfrom jaxtyping import Float\nfrom transformer_lens.model_bridge import TransformerBridge\nfrom transformer_lens.ActivationCache import ActivationCache\nimport transformer_lens.utils as utils\nfrom transformer_lens.hook_points import (\n HookPoint,\n) # Hooking utilities" }, { "cell_type": "markdown", @@ -150,77 +109,8 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded pretrained model gpt2-small into HookedTransformer\n" - ] - }, - { - "data": { - "text/plain": [ - "HookedTransformer(\n", - " (embed): Embed()\n", - " (hook_embed): HookPoint()\n", - " (pos_embed): PosEmbed()\n", - " (hook_pos_embed): HookPoint()\n", - " (blocks): ModuleList(\n", - " (0-11): 12 x TransformerBlock(\n", - " (ln1): LayerNormPre(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (ln2): LayerNormPre(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (attn): Attention(\n", - " (hook_k): HookPoint()\n", - " (hook_q): HookPoint()\n", - " (hook_v): HookPoint()\n", - " (hook_z): HookPoint()\n", - " (hook_attn_scores): HookPoint()\n", - " (hook_pattern): HookPoint()\n", - " (hook_result): HookPoint()\n", - " )\n", - " (mlp): MLP(\n", - " (hook_pre): HookPoint()\n", - " (hook_post): HookPoint()\n", - " )\n", - " (hook_attn_in): HookPoint()\n", - " (hook_q_input): HookPoint()\n", - " (hook_k_input): HookPoint()\n", - " (hook_v_input): HookPoint()\n", - " (hook_mlp_in): HookPoint()\n", - " (hook_attn_out): HookPoint()\n", - " (hook_mlp_out): HookPoint()\n", - " (hook_resid_pre): HookPoint()\n", - " (hook_resid_mid): HookPoint()\n", - " (hook_resid_post): HookPoint()\n", - " )\n", - " )\n", - " (ln_final): LayerNormPre(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (unembed): Unembed()\n", - ")" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# NBVAL_IGNORE_OUTPUT\n", - "# I'm using an M2 macbook air, so I use CPU for better support\n", - "model = TransformerBridge.boot_transformers(\"gpt2\", device=\"cpu\")\n", - "model.enable_compatibility_mode()\n", - "model.eval()" - ] + "outputs": [], + "source": "# NBVAL_IGNORE_OUTPUT\n# I'm using an M2 macbook air, so I use CPU for better support\nmodel = TransformerBridge.boot_transformers(\"gpt2\", device=\"cpu\")\nmodel.enable_compatibility_mode()\nmodel.eval()" }, { "cell_type": "markdown", @@ -3774,4 +3664,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file