# Tier-2 leaf libs pinned exactly. Tier-1 (torch/CUDA/python) inherited from the# pinned Kaggle image — do NOT pip-install torch here or you'll break the GPU.#%pip install -q "sae-lens==6.44.2" "transformer-lens==2.18.0" "sae-dashboard==0.7.2" "numpy==1.26.4"::: {#cell-0 .cell _uuid=‘8f2839f25d086af736a60e9eeb907d3b93b6e0e5’ _cell_guid=‘b1076dfc-b9ad-4769-8c92-a6c4dae69d19’ trusted=‘true’ quarto-private-1=‘{“key”:“execution”,“value”:{“iopub.status.busy”:“2026-06-04T18:27:22.209240Z”,“iopub.execute_input”:“2026-06-04T18:27:22.209621Z”,“iopub.status.idle”:“2026-06-04T18:27:22.217645Z”,“shell.execute_reply.started”:“2026-06-04T18:27:22.209583Z”,“shell.execute_reply”:“2026-06-04T18:27:22.216738Z”}}’ execution_count=4}
# This Python 3 environment comes with many helpful analytics libraries installed# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python# For example, here's several helpful packages to loadimport numpy as np # linear algebraimport pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)# Input data files are available in the read-only "../input/" directory# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directoryimport osfor dirname, _, filenames in os.walk('/kaggle/input'): for filename in filenames: print(os.path.join(dirname, filename))# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session# Use the kagglehub client library to attach Kaggle resources like competitions, datasets, and models to your session# Learn more about kagglehub: https://github.com/Kaggle/kagglehub/blob/main/README.mdimport kagglehub# kagglehub.dataset_download('<owner>/<dataset-slug>')/kaggle/input/datasets/nikolazhuk/interp-emotions-prompts/prompts/joy.py
/kaggle/input/datasets/nikolazhuk/interp-emotions-prompts/prompts/neutral.py
:::
In [5]:
In [6]:
import osos.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"os.environ["HF_HOME"] = "/tmp/hf" # ephemeral scratch, not /contentfrom kaggle_secrets import UserSecretsClientos.environ["HF_TOKEN"] = UserSecretsClient().get_secret("HF_TOKEN")In [7]:
import torchtorch.set_grad_enabled(False)device = "cuda" if torch.cuda.is_available() else "cpu"# Fail fast if the leaf install disturbed tier-1:print("torch", torch.__version__, "| cuda", torch.version.cuda, "| device", device)assert device == "cuda", "GPU not visible — check Accelerator setting / install didn't swap torch"torch 2.10.0+cu128 | cuda 12.8 | device cuda
In [8]:
from sae_lens import SAE, HookedSAETransformer# This avoids staging model in CPU RAM during loadmodel = HookedSAETransformer.from_pretrained( "google/gemma-2-2b", dtype=torch.float32, device=device, fold_ln=False, center_writing_weights=False, center_unembed=False, move_to_device=True,)model.eval()print(f"n_layers={model.cfg.n_layers}, d_model={model.cfg.d_model}")2026-06-04 18:27:30.159643: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1780597650.338934 214 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1780597650.390269 214 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1780597650.816956 214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1780597650.816999 214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1780597650.817002 214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1780597650.817004 214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
`torch_dtype` is deprecated! Use `dtype` instead!
Loaded pretrained model google/gemma-2-2b into HookedTransformer
n_layers=26, d_model=2304
In [9]:
sae, cfg_dict, sparsity = SAE.from_pretrained( release="gemma-scope-2b-pt-res-canonical", sae_id="layer_20/width_16k/canonical", device=device)sae = sae.to(torch.bfloat16)allocated = torch.cuda.memory_allocated() / 1e9print(f"GPU allocated: {allocated:.1f}GB / 15GB")print(f"SAE d_sae: {sae.cfg.d_sae}, d_in: {sae.cfg.d_in}")GPU allocated: 15.2GB / 15GB
SAE d_sae: 16384, d_in: 2304
/tmp/ipykernel_214/2686642472.py:1: DeprecationWarning: Unpacking SAE objects is deprecated. SAE.from_pretrained() now returns only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() to get the config dict and sparsity as well.
sae, cfg_dict, sparsity = SAE.from_pretrained(
In [10]:
#!pip freeze > /kaggle/working/requirements-lock.txtSOURCE: https://colab.research.google.com/drive/1TD5GPwvR1BwrNT2LSVNSVXVBDqIOCXJA?authuser=1#scrollTo=F0W3TC-89MUQThe “sae” object is an instance of the SAE (Sparse Autoencoder class). There are many different SAE architectures which may have different weights or activation functions. In order to simplify working with SAEs, SAE Lens handles most of this complexity for you.Let’s look at the SAE config and understand each of the parameters:1. architecture: Specifies the type of SAE architecture being used, in this case, the standard architecture (encoder and decoder with hidden activations, as opposed to a gated SAE).2. d_in: Defines the input dimension of the SAE, which is [768] in this configuration.3. d_sae: Sets the dimension of the SAE’s hidden layer, which is [24576] here. This represents the number of possible feature activations.4. activation_fn_str: Specifies the activation function used in the SAE, which is ReLU in this case. TopK is another option that we will not cover here.5. apply_b_dec_to_input: Determines whether to apply the decoder bias to the input, set to True here.6. finetuning_scaling_factor: Indicates whether to use a scaling factor to weight initialization and the forward pass. This is not usually used and was introduced to support a solution for shrinkage.7. context_size: Defines the size of the context window, which is [128] tokens in this case. In turns out SAEs trained on small activations from small prompts often don’t perform well on longer prompts.8. model_name: Specifies the name of the model being used, which is ‘gemma-2-2b’ here. This is a valid model name in TransformerLens.9. hook_name: Indicates the specific hook in the model where the SAE is applied.10. hook_head_index: Defines which attention head to hook into; not relevant here since we are looking at a residual stream SAE.11. prepend_bos: Determines whether to prepend the beginning-of-sequence token, set to True.12. dataset_path: Specifies the path to the dataset used for training or evaluation. (Can be local or a huggingface dataset.)13. dataset_trust_remote_code: Indicates whether to trust remote code (from HuggingFace) when loading the dataset, set to True.14. normalize_activations: Specifies how to normalize activations, set to ‘none’ in this config.15. dtype: Defines the data type for tensor operations, set to 32-bit floating point.16. device: Specifies the computational device to use.17. sae_lens_training_version: Indicates the version of SAE Lens used for training, set to None here.18. activation_fn_kwargs: Allows for additional keyword arguments for the activation function. This would be used if e.g. the activation_fn_str was set to topk, so that k could be specified.
In [11]:
print(sae.cfg.__dict__){'d_in': 2304, 'd_sae': 16384, 'dtype': 'bfloat16', 'device': 'cuda', 'apply_b_dec_to_input': False, 'normalize_activations': 'none', 'reshape_activations': 'none', 'metadata': SAEMetadata({'sae_lens_version': '6.44.2', 'sae_lens_training_version': None, 'model_name': 'gemma-2-2b', 'hook_name': 'blocks.20.hook_resid_post', 'hook_head_index': None, 'prepend_bos': True, 'dataset_path': 'monology/pile-uncopyrighted', 'context_size': 1024, 'neuronpedia_id': 'gemma-2-2b/20-gemmascope-res-16k'})}
In [12]:
def get_feature_acts(prompt: str, layer: int = 20): tokens = model.to_tokens(prompt) _, cache = model.run_with_cache( tokens, names_filter=f"blocks.{layer}.hook_resid_post", return_type=None, ) resid = cache[f"blocks.{layer}.hook_resid_post"].squeeze( 0).to(device).float() feature_acts = sae.encode(resid) # fp32 in → fp32 out print("max act:", feature_acts.max().item(), "| nonzero:", (feature_acts > 0).sum().item()) return feature_acts, model.to_str_tokens(prompt)acts, str_tokens = get_feature_acts("The quick brown fox jumps over the lazy dog.")l0_per_token = (acts > 0).sum(dim=1) # [n_tokens] countprint("tokens: ", str_tokens)print("L0/token: ", l0_per_token.tolist())print("last-token L0:", l0_per_token[-1].item(), "| last-token max:", acts[-1].max().item())max act: 2040.0 | nonzero: 7878
tokens: ['<bos>', 'The', ' quick', ' brown', ' fox', ' jumps', ' over', ' the', ' lazy', ' dog', '.']
L0/token: [7012, 29, 62, 88, 105, 110, 147, 101, 96, 68, 60]
last-token L0: 60 | last-token max: 47.5
In [13]:
# load promptsimport sys, os, globhit = glob.glob("/kaggle/input/**/joy.py", recursive=True)assert hit, "joy.py not found under /kaggle/input — attach the prompts dataset (Add Input)"d = os.path.dirname(hit[0])sys.path.insert(0, d)from joy import JOY_PROMPTSfrom neutral import NEUTRAL_PROMPTSprint("loaded from", d)print("joy", len(JOY_PROMPTS), "neutral",len(NEUTRAL_PROMPTS))loaded from /kaggle/input/datasets/nikolazhuk/interp-emotions-prompts/prompts
joy 20 neutral 20
In [14]:
# concept meansdef concept_mean(prompts, layer=20): vecs = [] for p in prompts: acts, _ = get_feature_acts(p, layer) vecs.append(acts[-1]) return torch.stack(vecs).mean(0)joy_mean = concept_mean(JOY_PROMPTS)neutral_mean = concept_mean(NEUTRAL_PROMPTS)print("joy L0", int((joy_mean > 0).sum()))print("neutral L0", int((neutral_mean > 0).sum()))max act: 2040.0 | nonzero: 7842
max act: 2040.0 | nonzero: 8167
max act: 2040.0 | nonzero: 8086
max act: 2040.0 | nonzero: 7672
max act: 2040.0 | nonzero: 7986
max act: 2040.0 | nonzero: 7957
max act: 2040.0 | nonzero: 8067
max act: 2040.0 | nonzero: 7833
max act: 2040.0 | nonzero: 7962
max act: 2040.0 | nonzero: 7880
max act: 2040.0 | nonzero: 8106
max act: 2040.0 | nonzero: 7897
max act: 2040.0 | nonzero: 7897
max act: 2040.0 | nonzero: 8014
max act: 2040.0 | nonzero: 7975
max act: 2040.0 | nonzero: 7843
max act: 2040.0 | nonzero: 7775
max act: 2040.0 | nonzero: 7942
max act: 2040.0 | nonzero: 7946
max act: 2040.0 | nonzero: 7797
max act: 2040.0 | nonzero: 7676
max act: 2040.0 | nonzero: 7702
max act: 2040.0 | nonzero: 7743
max act: 2040.0 | nonzero: 7497
max act: 2040.0 | nonzero: 7816
max act: 2040.0 | nonzero: 7636
max act: 2040.0 | nonzero: 7741
max act: 2040.0 | nonzero: 7618
max act: 2040.0 | nonzero: 7622
max act: 2040.0 | nonzero: 7739
max act: 2040.0 | nonzero: 7601
max act: 2040.0 | nonzero: 7556
max act: 2040.0 | nonzero: 7627
max act: 2040.0 | nonzero: 7623
max act: 2040.0 | nonzero: 7653
max act: 2040.0 | nonzero: 7642
max act: 2040.0 | nonzero: 7764
max act: 2040.0 | nonzero: 7629
max act: 2040.0 | nonzero: 7650
max act: 2040.0 | nonzero: 7560
joy L0 421
neutral L0 526
In [15]:
import pandas as pdNP_URL = "https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/"def diff_table(concept_vec, baseline_vec, concept, baseline, k=20): diff = concept_vec - baseline_vec top = torch.topk(diff, k) rows = [] for fid in top.indices.tolist(): rows.append({ "feature_id": fid, "diff_score": diff[fid].item(), concept + "_mean": concept_vec[fid].item(), baseline + "_mean": baseline_vec[fid].item(), "neuronpedia": NP_URL + str(fid), }) return pd.DataFrame(rows)diff_joy_vs_neutral = diff_table(joy_mean, neutral_mean, "joy", "neutral")diff_joy_vs_neutral| feature_id | diff_score | joy_mean | neutral_mean | neuronpedia | |
|---|---|---|---|---|---|
| 0 | 1673 | 19.25000 | 21.25000 | 1.968750 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 1 | 15383 | 16.12500 | 30.75000 | 14.625000 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 2 | 369 | 15.81250 | 29.87500 | 14.062500 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 3 | 4418 | 14.87500 | 21.25000 | 6.343750 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 4 | 58 | 13.25000 | 14.06250 | 0.839844 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 5 | 4456 | 13.06250 | 13.43750 | 0.394531 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 6 | 12341 | 12.56250 | 15.75000 | 3.203125 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 7 | 12545 | 12.06250 | 25.87500 | 13.812500 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 8 | 3877 | 10.31250 | 10.31250 | 0.000000 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 9 | 1858 | 10.00000 | 74.50000 | 64.500000 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 10 | 10057 | 9.68750 | 12.18750 | 2.515625 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 11 | 15919 | 9.68750 | 10.06250 | 0.347656 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 12 | 2914 | 9.50000 | 29.37500 | 19.875000 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 13 | 13554 | 9.37500 | 10.75000 | 1.343750 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 14 | 5890 | 8.31250 | 18.87500 | 10.562500 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 15 | 3734 | 8.18750 | 8.68750 | 0.515625 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 16 | 8292 | 7.43750 | 11.25000 | 3.796875 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 17 | 15714 | 7.37500 | 9.93750 | 2.546875 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 18 | 1621 | 7.15625 | 7.15625 | 0.000000 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 19 | 10624 | 6.90625 | 6.90625 | 0.000000 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
In [16]:
from IPython.display import HTMLdef link(fid): url = NP_URL + str(fid) return '<a href="' + url + '" target="_blank">' + str(fid) + '</a>'def show_links(df): d = df.copy() d["neuronpedia"] = d["feature_id"].apply(link) return HTML(d.to_html(escape=False))show_links(diff_joy_vs_neutral)| feature_id | diff_score | joy_mean | neutral_mean | neuronpedia | |
|---|---|---|---|---|---|
| 0 | 1673 | 19.25000 | 21.25000 | 1.968750 | 1673 |
| 1 | 15383 | 16.12500 | 30.75000 | 14.625000 | 15383 |
| 2 | 369 | 15.81250 | 29.87500 | 14.062500 | 369 |
| 3 | 4418 | 14.87500 | 21.25000 | 6.343750 | 4418 |
| 4 | 58 | 13.25000 | 14.06250 | 0.839844 | 58 |
| 5 | 4456 | 13.06250 | 13.43750 | 0.394531 | 4456 |
| 6 | 12341 | 12.56250 | 15.75000 | 3.203125 | 12341 |
| 7 | 12545 | 12.06250 | 25.87500 | 13.812500 | 12545 |
| 8 | 3877 | 10.31250 | 10.31250 | 0.000000 | 3877 |
| 9 | 1858 | 10.00000 | 74.50000 | 64.500000 | 1858 |
| 10 | 10057 | 9.68750 | 12.18750 | 2.515625 | 10057 |
| 11 | 15919 | 9.68750 | 10.06250 | 0.347656 | 15919 |
| 12 | 2914 | 9.50000 | 29.37500 | 19.875000 | 2914 |
| 13 | 13554 | 9.37500 | 10.75000 | 1.343750 | 13554 |
| 14 | 5890 | 8.31250 | 18.87500 | 10.562500 | 5890 |
| 15 | 3734 | 8.18750 | 8.68750 | 0.515625 | 3734 |
| 16 | 8292 | 7.43750 | 11.25000 | 3.796875 | 8292 |
| 17 | 15714 | 7.37500 | 9.93750 | 2.546875 | 15714 |
| 18 | 1621 | 7.15625 | 7.15625 | 0.000000 | 1621 |
| 19 | 10624 | 6.90625 | 6.90625 | 0.000000 | 10624 |
In [17]:
import numpy as np fid = int(diff_joy_vs_neutral.iloc[0]["feature_id"]) import numpy as np fid = int(diff_joy_vs_neutral.iloc[0]["feature_id"]) joy_hits = [] for p in JOY_PROMPTS: acts, _ = get_feature_acts(p) joy_hits.append(acts[-1, fid].item())max act: 2040.0 | nonzero: 7842
max act: 2040.0 | nonzero: 8167
max act: 2040.0 | nonzero: 8086
max act: 2040.0 | nonzero: 7672
max act: 2040.0 | nonzero: 7986
max act: 2040.0 | nonzero: 7957
max act: 2040.0 | nonzero: 8067
max act: 2040.0 | nonzero: 7833
max act: 2040.0 | nonzero: 7962
max act: 2040.0 | nonzero: 7880
max act: 2040.0 | nonzero: 8106
max act: 2040.0 | nonzero: 7897
max act: 2040.0 | nonzero: 7897
max act: 2040.0 | nonzero: 8014
max act: 2040.0 | nonzero: 7975
max act: 2040.0 | nonzero: 7843
max act: 2040.0 | nonzero: 7775
max act: 2040.0 | nonzero: 7942
max act: 2040.0 | nonzero: 7946
max act: 2040.0 | nonzero: 7797
In [18]:
# check #1 feature prompt-by prompt:import numpy as npfid = int(diff_joy_vs_neutral.iloc[0]["feature_id"])joy_hits = [get_feature_acts(p)[0][-1, fid].item() for p in JOY_PROMPTS]neu_hits = [get_feature_acts(p)[0][-1, fid].item() for p in NEUTRAL_PROMPTS]print(f"feature {fid}")print(f" joy: mean={np.mean(joy_hits):6.2f} fires {sum(h>0 for h in joy_hits)}/{len(joy_hits)}")print(f" neutral: mean={np.mean(neu_hits):6.2f} fires {sum(h>0 for h in neu_hits)}/{len(neu_hits)}")print(f" https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{fid}")max act: 2040.0 | nonzero: 7842
max act: 2040.0 | nonzero: 8167
max act: 2040.0 | nonzero: 8086
max act: 2040.0 | nonzero: 7672
max act: 2040.0 | nonzero: 7986
max act: 2040.0 | nonzero: 7957
max act: 2040.0 | nonzero: 8067
max act: 2040.0 | nonzero: 7833
max act: 2040.0 | nonzero: 7962
max act: 2040.0 | nonzero: 7880
max act: 2040.0 | nonzero: 8106
max act: 2040.0 | nonzero: 7897
max act: 2040.0 | nonzero: 7897
max act: 2040.0 | nonzero: 8014
max act: 2040.0 | nonzero: 7975
max act: 2040.0 | nonzero: 7843
max act: 2040.0 | nonzero: 7775
max act: 2040.0 | nonzero: 7942
max act: 2040.0 | nonzero: 7946
max act: 2040.0 | nonzero: 7797
max act: 2040.0 | nonzero: 7676
max act: 2040.0 | nonzero: 7702
max act: 2040.0 | nonzero: 7743
max act: 2040.0 | nonzero: 7497
max act: 2040.0 | nonzero: 7816
max act: 2040.0 | nonzero: 7636
max act: 2040.0 | nonzero: 7741
max act: 2040.0 | nonzero: 7618
max act: 2040.0 | nonzero: 7622
max act: 2040.0 | nonzero: 7739
max act: 2040.0 | nonzero: 7601
max act: 2040.0 | nonzero: 7556
max act: 2040.0 | nonzero: 7627
max act: 2040.0 | nonzero: 7623
max act: 2040.0 | nonzero: 7653
max act: 2040.0 | nonzero: 7642
max act: 2040.0 | nonzero: 7764
max act: 2040.0 | nonzero: 7629
max act: 2040.0 | nonzero: 7650
max act: 2040.0 | nonzero: 7560
feature 1673
joy: mean= 21.21 fires 19/20
neutral: mean= 1.97 fires 3/20
https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/1673
Visualisation
In [20]:
#%pip install -q circuitsvisimport circuitsvis as cvfrom IPython.display import display# Pick your top joy featurefid = int(diff_joy_vs_neutral.iloc[0]["feature_id"])print(f"Visualizing feature {fid}")def viz_feature_on_prompt(prompt: str, fid: int, layer: int = 20): """Color each token by its activation on a single SAE feature.""" acts, str_tokens = get_feature_acts(prompt, layer=layer) # acts: (n_tokens, d_sae). We want column `fid` -> (n_tokens,) values = acts[:, fid].cpu().tolist() return cv.tokens.colored_tokens(tokens=str_tokens, values=values)# Show on 5 joy promptsfor p in JOY_PROMPTS[:5]: display(viz_feature_on_prompt(p, fid))# And 5 neutral prompts — should be visually mutedprint("\n--- Neutral controls ---")for p in NEUTRAL_PROMPTS[:5]: display(viz_feature_on_prompt(p, fid))━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 24.9 MB/s eta 0:00:0000:0100:01 Note: you may need to restart the kernel to use updated packages. Visualizing feature 1673 max act: 2040.0 | nonzero: 7842
max act: 2040.0 | nonzero: 8167
max act: 2040.0 | nonzero: 8086
max act: 2040.0 | nonzero: 7672
max act: 2040.0 | nonzero: 7986
--- Neutral controls ---
max act: 2040.0 | nonzero: 7676
max act: 2040.0 | nonzero: 7702
max act: 2040.0 | nonzero: 7743
max act: 2040.0 | nonzero: 7497
max act: 2040.0 | nonzero: 7816
In [24]:
# comparing top3 featuresdef viz_topk_features_on_prompt(prompt: str, fids: list, layer: int = 20): """One prompt, multiple features, colored side-by-side.""" acts, str_tokens = get_feature_acts(prompt, layer=layer) values = torch.stack([acts[:, fid] for fid in fids], dim=1) # [n_tokens, k] labels = [f"feat {fid}" for fid in fids] return cv.tokens.colored_tokens_multi( tokens=str_tokens, values=values.cpu().float(), labels=labels, )top3 = diff_joy_vs_neutral.iloc[:3]["feature_id"].astype(int).tolist()display(viz_topk_features_on_prompt(JOY_PROMPTS[0], top3))max act: 2040.0 | nonzero: 7842
In [4]: