diff --git a/demos/Patchscopes_Generation_Demo.ipynb b/demos/Patchscopes_Generation_Demo.ipynb index 2a9109154..8f06af4cc 100644 --- a/demos/Patchscopes_Generation_Demo.ipynb +++ b/demos/Patchscopes_Generation_Demo.ipynb @@ -30,48 +30,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", - "import os\n", - "\n", - "DEBUG_MODE = False\n", - "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", - "try:\n", - " import google.colab\n", - "\n", - " IN_COLAB = True\n", - " print(\"Running as a Colab notebook\")\n", - "except:\n", - " IN_COLAB = False\n", - " print(\"Running as a Jupyter notebook - intended for development only!\")\n", - " from IPython import get_ipython\n", - "\n", - " ipython = get_ipython()\n", - " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", - " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", - " ipython.run_line_magic(\"autoreload\", \"2\")\n", - "\n", - "if IN_COLAB or IN_GITHUB:\n", - " %pip install transformer_lens\n", - " %pip install torchtyping\n", - " # Install my janky personal plotting utils\n", - " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", - " # Install another version of node that makes PySvelte work way faster\n", - " %pip install circuitsvis\n", - " # Needed for PySvelte to work, v3 came out and broke things...\n", - " %pip install typeguard==2.13.3\n", - "\n", - "import torch\n", - "from typing import List, Callable, Tuple, Union\n", - "from functools import partial\n", - "from jaxtyping import Float\n", - "from transformer_lens import HookedTransformer\n", - "from transformer_lens.ActivationCache import ActivationCache\n", - "import transformer_lens.utils as utils\n", - "from transformer_lens.hook_points import (\n", - " HookPoint,\n", - ") # Hooking utilities" - ] + "source": "# NBVAL_IGNORE_OUTPUT\n# Janky code to do different setup when run in a Colab notebook vs VSCode\nimport os\n\nDEBUG_MODE = False\nIN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\ntry:\n import google.colab\n\n IN_COLAB = True\n print(\"Running as a Colab notebook\")\nexcept:\n IN_COLAB = False\n\nif not IN_GITHUB and not IN_COLAB:\n print(\"Running as a Jupyter notebook - intended for development only!\")\n from IPython import get_ipython\n\n ipython = get_ipython()\n # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n ipython.run_line_magic(\"load_ext\", \"autoreload\")\n ipython.run_line_magic(\"autoreload\", \"2\")\n\nif IN_COLAB or IN_GITHUB:\n %pip install transformer_lens\n %pip install torchtyping\n # Install my janky personal plotting utils\n %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n # Install another version of node that makes PySvelte work way faster\n %pip install circuitsvis\n # Needed for PySvelte to work, v3 came out and broke things...\n %pip install typeguard==2.13.3\n\nimport torch\nfrom typing import List, Callable, Tuple, Union\nfrom functools import partial\nfrom jaxtyping import Float\nfrom transformer_lens.model_bridge import TransformerBridge\nfrom transformer_lens.ActivationCache import ActivationCache\nimport transformer_lens.utils as utils\nfrom transformer_lens.hook_points import (\n HookPoint,\n) # Hooking utilities" }, { "cell_type": "markdown", @@ -148,78 +107,10 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded pretrained model gpt2-small into HookedTransformer\n" - ] - }, - { - "data": { - "text/plain": [ - "HookedTransformer(\n", - " (embed): Embed()\n", - " (hook_embed): HookPoint()\n", - " (pos_embed): PosEmbed()\n", - " (hook_pos_embed): HookPoint()\n", - " (blocks): ModuleList(\n", - " (0-11): 12 x TransformerBlock(\n", - " (ln1): LayerNormPre(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (ln2): LayerNormPre(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (attn): Attention(\n", - " (hook_k): HookPoint()\n", - " (hook_q): HookPoint()\n", - " (hook_v): HookPoint()\n", - " (hook_z): HookPoint()\n", - " (hook_attn_scores): HookPoint()\n", - " (hook_pattern): HookPoint()\n", - " (hook_result): HookPoint()\n", - " )\n", - " (mlp): MLP(\n", - " (hook_pre): HookPoint()\n", - " (hook_post): HookPoint()\n", - " )\n", - " (hook_attn_in): HookPoint()\n", - " (hook_q_input): HookPoint()\n", - " (hook_k_input): HookPoint()\n", - " (hook_v_input): HookPoint()\n", - " (hook_mlp_in): HookPoint()\n", - " (hook_attn_out): HookPoint()\n", - " (hook_mlp_out): HookPoint()\n", - " (hook_resid_pre): HookPoint()\n", - " (hook_resid_mid): HookPoint()\n", - " (hook_resid_post): HookPoint()\n", - " )\n", - " )\n", - " (ln_final): LayerNormPre(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (unembed): Unembed()\n", - ")" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# NBVAL_IGNORE_OUTPUT\n", - "# I'm using an M2 macbook air, so I use CPU for better support\n", - "model = HookedTransformer.from_pretrained(\"gpt2-small\", device=\"cpu\")\n", - "model.eval()" - ] + "outputs": [], + "source": "# NBVAL_IGNORE_OUTPUT\n# I'm using an M2 macbook air, so I use CPU for better support\nmodel = TransformerBridge.boot_transformers(\"gpt2\", device=\"cpu\")\nmodel.enable_compatibility_mode()\nmodel.eval()" }, { "cell_type": "markdown", @@ -263,17 +154,17 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "def get_source_representation(prompts: List[str], layer_id: int, model: HookedTransformer, pos_id: Union[int, List[int]]=None) -> torch.Tensor:\n", + "def get_source_representation(prompts: List[str], layer_id: int, model: TransformerBridge, pos_id: Union[int, List[int]]=None) -> torch.Tensor:\n", " \"\"\"Get source hidden representation represented by (S, i, M, l)\n", " \n", " Args:\n", " - prompts (List[str]): a list of source prompts\n", " - layer_id (int): the layer id of the model\n", - " - model (HookedTransformer): the source model\n", + " - model (TransformerBridge): the source model\n", " - pos_id (Union[int, List[int]]): the position id(s) of the model, if None, return all positions\n", "\n", " Returns:\n", @@ -325,19 +216,19 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# recall the target representation (T,i*,f,M*,l*), and we also need the hidden representation from our source model (S, i, M, l)\n", - "def feed_source_representation(source_rep: torch.Tensor, prompt: List[str], f: Callable, model: HookedTransformer, layer_id: int, pos_id: Union[int, List[int]]=None) -> ActivationCache:\n", + "def feed_source_representation(source_rep: torch.Tensor, prompt: List[str], f: Callable, model: TransformerBridge, layer_id: int, pos_id: Union[int, List[int]]=None) -> ActivationCache:\n", " \"\"\"Feed the source hidden representation to the target model\n", " \n", " Args:\n", " - source_rep (torch.Tensor): the source hidden representation\n", " - prompt (List[str]): the target prompt\n", " - f (Callable): the mapping function\n", - " - model (HookedTransformer): the target model\n", + " - model (TransformerBridge): the target model\n", " - layer_id (int): the layer id of the target model\n", " - pos_id (Union[int, List[int]]): the position id(s) of the target model, if None, return all positions\n", " \"\"\"\n", @@ -417,11 +308,11 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "def generate_with_patching(model: HookedTransformer, prompts: List[str], target_f: Callable, max_new_tokens: int = 50):\n", + "def generate_with_patching(model: TransformerBridge, prompts: List[str], target_f: Callable, max_new_tokens: int = 50):\n", " temp_prompts = prompts\n", " input_tokens = model.to_tokens(temp_prompts)\n", " for _ in range(max_new_tokens):\n", @@ -3494,13 +3385,6 @@ " print(f\"Generation by patching layer {target_layer_id}:\\n{gen}\\n{'='*30}\\n\")" ] }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": "" - }, { "cell_type": "markdown", "metadata": {}, @@ -3780,4 +3664,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file