torch 2.10.0+cu128 | cuda 12.8 | device cuda
Reproducing emotion-concept features in an open-weight model with a sparse autoencoder (SAE).
Based on the recent paper “Emotion Concepts and their Function in a Large Language Model”.
Loaded pretrained model google/gemma-2-2b into HookedTransformer
n_layers=26, d_model=2304
GPU allocated: 15.2GB / 15GB
SAE d_sae: 16384, d_in: 2304
SOURCE: https://colab.research.google.com/drive/1TD5GPwvR1BwrNT2LSVNSVXVBDqIOCXJA?authuser=1#scrollTo=F0W3TC-89MUQ
The “sae” object is an instance of the SAE (sparse autoencoder) class from SAE Lens. There are many different SAE architectures, which may have different weights or activation functions; SAE Lens handles most of this complexity for you. Let’s look at the SAE config and understand each of its parameters (values refer to the config printed below; fields not shown there are described generically):
1. architecture: the type of SAE. Gemma Scope SAEs use the JumpReLU architecture, in which each feature has a learned threshold below which its activation is set to zero (as opposed to the standard ReLU or gated SAEs).
2. d_in: the input dimension of the SAE, 2304 here (equal to d_model of gemma-2-2b).
3. d_sae: the width of the SAE’s hidden layer, 16384 here. This is the number of possible features.
4. activation_fn_str: the activation function used in the SAE (JumpReLU for Gemma Scope, see item 1). TopK is another option that we will not cover here.
5. apply_b_dec_to_input: whether to subtract the decoder bias from the input before encoding; False here.
6. finetuning_scaling_factor: whether to apply a scaling factor in weight initialization and the forward pass. It is rarely used and was introduced to address feature shrinkage.
7. context_size: the context window used when training the SAE, 1024 tokens here. It turns out that SAEs trained on activations from short prompts often don’t perform well on longer prompts.
8. model_name: the name of the model, ‘gemma-2-2b’ here. This is a valid model name in TransformerLens.
9. hook_name: the hook point in the model where the SAE is applied, ‘blocks.20.hook_resid_post’ here (the residual stream after layer 20).
10. hook_head_index: which attention head to hook into; None here, since this is a residual-stream SAE.
11. prepend_bos: whether to prepend the beginning-of-sequence token; True here.
12. dataset_path: the path to the dataset used for training or evaluation (local or a Hugging Face dataset); ‘monology/pile-uncopyrighted’ here.
13. dataset_trust_remote_code: whether to trust remote code from Hugging Face when loading the dataset.
14. normalize_activations: how to normalize activations; ‘none’ here.
15. dtype: the data type for tensor operations; ‘bfloat16’ here.
16. device: the compute device; ‘cuda’ here.
17. sae_lens_training_version: the version of SAE Lens used for training; None here.
18. activation_fn_kwargs: additional keyword arguments for the activation function, e.g. k when the activation function is TopK.
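To make the roles of d_in, d_sae, the per-feature threshold and apply_b_dec_to_input concrete, here is a toy JumpReLU SAE forward pass (the architecture Gemma Scope SAEs use) in plain Python. It is a minimal sketch, not SAE Lens's implementation: the class name `ToyJumpReLUSAE` and all weights are hypothetical, and the dimensions (d_in = 2, d_sae = 3) stand in for 2304 and 16384.

```python
from dataclasses import dataclass
from typing import List

Vector = List[float]
Matrix = List[List[float]]


@dataclass
class ToyJumpReLUSAE:
    """Hypothetical toy JumpReLU SAE on plain lists (not the SAE Lens class)."""

    W_enc: Matrix  # shape d_in x d_sae
    b_enc: Vector  # shape d_sae
    W_dec: Matrix  # shape d_sae x d_in
    b_dec: Vector  # shape d_in
    threshold: Vector  # shape d_sae, one learned threshold per feature
    apply_b_dec_to_input: bool = False  # False in the Gemma Scope config

    def encode(self, x: Vector) -> Vector:
        # Optionally subtract the decoder bias from the input before encoding.
        if self.apply_b_dec_to_input:
            x = [xi - bi for xi, bi in zip(x, self.b_dec)]
        pre = [
            sum(x[i] * self.W_enc[i][j] for i in range(len(x))) + self.b_enc[j]
            for j in range(len(self.b_enc))
        ]
        # JumpReLU: keep a pre-activation only if it exceeds that feature's threshold.
        return [p if p > t else 0.0 for p, t in zip(pre, self.threshold)]

    def decode(self, f: Vector) -> Vector:
        # Reconstruction = activation-weighted sum of decoder rows + decoder bias.
        return [
            sum(f[j] * self.W_dec[j][i] for j in range(len(f))) + self.b_dec[i]
            for i in range(len(self.b_dec))
        ]


# Usage with d_in = 2 and d_sae = 3.
sae = ToyJumpReLUSAE(
    W_enc=[[1.0, 0.0, 0.5], [0.0, 1.0, 0.5]],
    b_enc=[0.0, 0.0, 0.0],
    W_dec=[[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]],
    b_dec=[0.0, 0.0],
    threshold=[0.5, 0.5, 2.0],
)
feats = sae.encode([1.0, 0.2])  # pre-activations ~[1.0, 0.2, 0.6], only feature 0 passes
recon = sae.decode(feats)
l0 = sum(1 for v in feats if v > 0)  # number of active features (L0)
```

Feature 2 has a pre-activation of about 0.6 but stays at zero because its threshold is 2.0; that per-feature cutoff is what distinguishes JumpReLU from a plain ReLU.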
{'d_in': 2304, 'd_sae': 16384, 'dtype': 'bfloat16', 'device': 'cuda', 'apply_b_dec_to_input': False, 'normalize_activations': 'none', 'reshape_activations': 'none', 'metadata': SAEMetadata({'sae_lens_version': '6.44.2', 'sae_lens_training_version': None, 'model_name': 'gemma-2-2b', 'hook_name': 'blocks.20.hook_resid_post', 'hook_head_index': None, 'prepend_bos': True, 'dataset_path': 'monology/pile-uncopyrighted', 'context_size': 1024, 'neuronpedia_id': 'gemma-2-2b/20-gemmascope-res-16k'})}
check
max act: 2040.0 | nonzero: 7878
tokens: ['<bos>', 'The', ' quick', ' brown', ' fox', ' jumps', ' over', ' the', ' lazy', ' dog', '.']
L0/token: [7012, 29, 62, 88, 105, 110, 147, 101, 96, 68, 60]
last-token L0: 60 | last-token max: 47.5
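In the L0/token row, `<bos>` activates 7012 features while every other position activates between 29 and 147. Under causal attention the `<bos>` position sees no other tokens, so its activations are identical for every prompt; this would explain why the later per-prompt lines all report the same "max act: 2040.0" (an inference from the log, not something it states). A minimal sketch of summary statistics that can skip position 0, assuming activations are available as a tokens x features list and that "nonzero" means features active at any kept position; the helper name is hypothetical.

```python
from typing import Dict, List


def summarize_feature_acts(acts: List[List[float]], skip_bos: bool = True) -> Dict[str, object]:
    """Summarise SAE feature activations; acts[t][j] is feature j at token position t.

    Hypothetical helper: "nonzero_features" counts features active at any kept position.
    """
    rows = acts[1:] if skip_bos else acts  # position 0 is <bos> when prepend_bos=True
    n_features = len(acts[0]) if acts else 0
    return {
        # L0 = number of features with a positive activation at each position.
        "l0_per_token": [sum(1 for a in row if a > 0) for row in rows],
        "max_act": max((a for row in rows for a in row), default=0.0),
        "nonzero_features": sum(
            1 for j in range(n_features) if any(row[j] > 0 for row in rows)
        ),
    }


# Toy activations: position 0 plays the role of <bos> with one huge activation.
toy = [[2040.0, 5.0, 3.0], [0.0, 1.5, 0.0], [0.0, 0.0, 2.5]]
with_bos = summarize_feature_acts(toy, skip_bos=False)
without_bos = summarize_feature_acts(toy)
```

With `<bos>` included, the maximum is the `<bos>` value regardless of prompt content; excluding it exposes the content-dependent activations.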
loaded from /kaggle/input/datasets/nikolazhuk/interp-emotions-prompts/prompts
joy 20 neutral 20
max act: 2040.0 | nonzero: 7842
max act: 2040.0 | nonzero: 8167
max act: 2040.0 | nonzero: 8086
max act: 2040.0 | nonzero: 7672
max act: 2040.0 | nonzero: 7986
max act: 2040.0 | nonzero: 7957
max act: 2040.0 | nonzero: 8067
max act: 2040.0 | nonzero: 7833
max act: 2040.0 | nonzero: 7962
max act: 2040.0 | nonzero: 7880
max act: 2040.0 | nonzero: 8106
max act: 2040.0 | nonzero: 7897
max act: 2040.0 | nonzero: 7897
max act: 2040.0 | nonzero: 8014
max act: 2040.0 | nonzero: 7975
max act: 2040.0 | nonzero: 7843
max act: 2040.0 | nonzero: 7775
max act: 2040.0 | nonzero: 7942
max act: 2040.0 | nonzero: 7946
max act: 2040.0 | nonzero: 7797
max act: 2040.0 | nonzero: 7676
max act: 2040.0 | nonzero: 7702
max act: 2040.0 | nonzero: 7743
max act: 2040.0 | nonzero: 7497
max act: 2040.0 | nonzero: 7816
max act: 2040.0 | nonzero: 7636
max act: 2040.0 | nonzero: 7741
max act: 2040.0 | nonzero: 7618
max act: 2040.0 | nonzero: 7622
max act: 2040.0 | nonzero: 7739
max act: 2040.0 | nonzero: 7601
max act: 2040.0 | nonzero: 7556
max act: 2040.0 | nonzero: 7627
max act: 2040.0 | nonzero: 7623
max act: 2040.0 | nonzero: 7653
max act: 2040.0 | nonzero: 7642
max act: 2040.0 | nonzero: 7764
max act: 2040.0 | nonzero: 7629
max act: 2040.0 | nonzero: 7650
max act: 2040.0 | nonzero: 7560
joy L0 421
neutral L0 526
| | feature_id | diff_score | joy_mean | neutral_mean | neuronpedia |
|---|---|---|---|---|---|
| 0 | 1673 | 19.25000 | 21.25000 | 1.968750 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 1 | 15383 | 16.12500 | 30.75000 | 14.625000 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 2 | 369 | 15.81250 | 29.87500 | 14.062500 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 3 | 4418 | 14.87500 | 21.25000 | 6.343750 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 4 | 58 | 13.25000 | 14.06250 | 0.839844 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 5 | 4456 | 13.06250 | 13.43750 | 0.394531 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 6 | 12341 | 12.56250 | 15.75000 | 3.203125 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 7 | 12545 | 12.06250 | 25.87500 | 13.812500 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 8 | 3877 | 10.31250 | 10.31250 | 0.000000 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 9 | 1858 | 10.00000 | 74.50000 | 64.500000 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 10 | 10057 | 9.68750 | 12.18750 | 2.515625 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 11 | 15919 | 9.68750 | 10.06250 | 0.347656 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 12 | 2914 | 9.50000 | 29.37500 | 19.875000 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 13 | 13554 | 9.37500 | 10.75000 | 1.343750 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 14 | 5890 | 8.31250 | 18.87500 | 10.562500 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 15 | 3734 | 8.18750 | 8.68750 | 0.515625 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 16 | 8292 | 7.43750 | 11.25000 | 3.796875 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 17 | 15714 | 7.37500 | 9.93750 | 2.546875 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 18 | 1621 | 7.15625 | 7.15625 | 0.000000 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| 19 | 10624 | 6.90625 | 6.90625 | 0.000000 | https://neuronpedia.org/gemma-2-2b/20-gemmasco... |
| | feature_id | diff_score | joy_mean | neutral_mean | neuronpedia |
|---|---|---|---|---|---|
| 0 | 1673 | 19.25000 | 21.25000 | 1.968750 | 1673 |
| 1 | 15383 | 16.12500 | 30.75000 | 14.625000 | 15383 |
| 2 | 369 | 15.81250 | 29.87500 | 14.062500 | 369 |
| 3 | 4418 | 14.87500 | 21.25000 | 6.343750 | 4418 |
| 4 | 58 | 13.25000 | 14.06250 | 0.839844 | 58 |
| 5 | 4456 | 13.06250 | 13.43750 | 0.394531 | 4456 |
| 6 | 12341 | 12.56250 | 15.75000 | 3.203125 | 12341 |
| 7 | 12545 | 12.06250 | 25.87500 | 13.812500 | 12545 |
| 8 | 3877 | 10.31250 | 10.31250 | 0.000000 | 3877 |
| 9 | 1858 | 10.00000 | 74.50000 | 64.500000 | 1858 |
| 10 | 10057 | 9.68750 | 12.18750 | 2.515625 | 10057 |
| 11 | 15919 | 9.68750 | 10.06250 | 0.347656 | 15919 |
| 12 | 2914 | 9.50000 | 29.37500 | 19.875000 | 2914 |
| 13 | 13554 | 9.37500 | 10.75000 | 1.343750 | 13554 |
| 14 | 5890 | 8.31250 | 18.87500 | 10.562500 | 5890 |
| 15 | 3734 | 8.18750 | 8.68750 | 0.515625 | 3734 |
| 16 | 8292 | 7.43750 | 11.25000 | 3.796875 | 8292 |
| 17 | 15714 | 7.37500 | 9.93750 | 2.546875 | 15714 |
| 18 | 1621 | 7.15625 | 7.15625 | 0.000000 | 1621 |
| 19 | 10624 | 6.90625 | 6.90625 | 0.000000 | 10624 |
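The diff_score column is consistent with joy_mean minus neutral_mean: for feature 1673, 21.25 - 1.96875 = 19.28125, which rounds to the printed 19.25 in bfloat16. A minimal sketch of that ranking in plain Python, next to a hypothetical alternative that divides by the neutral mean (the `eps=1.0` smoothing constant is an arbitrary choice, not from the notebook). The input means are copied from the table for four features.

```python
from typing import Dict, List, Tuple

Ranking = List[Tuple[int, float]]


def rank_by_diff(joy: Dict[int, float], neutral: Dict[int, float], top_k: int = 20) -> Ranking:
    """Rank features by mean activation on joy prompts minus mean on neutral prompts."""
    scores = {fid: joy[fid] - neutral.get(fid, 0.0) for fid in joy}
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_k]


def rank_by_selectivity(
    joy: Dict[int, float], neutral: Dict[int, float], top_k: int = 20, eps: float = 1.0
) -> Ranking:
    """Hypothetical alternative: the same difference divided by (neutral mean + eps)."""
    scores = {
        fid: (joy[fid] - neutral.get(fid, 0.0)) / (neutral.get(fid, 0.0) + eps)
        for fid in joy
    }
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_k]


# Means copied from the table above for four features.
joy_mean = {1673: 21.25, 15383: 30.75, 3877: 10.3125, 1858: 74.5}
neutral_mean = {1673: 1.96875, 15383: 14.625, 3877: 0.0, 1858: 64.5}
by_diff = [fid for fid, _ in rank_by_diff(joy_mean, neutral_mean)]
by_sel = [fid for fid, _ in rank_by_selectivity(joy_mean, neutral_mean)]
```

An absolute difference rewards features such as 1858, which fires strongly on both prompt sets (74.5 vs 64.5). Normalising by the neutral mean moves it to the bottom and promotes features like 3877 that are silent on neutral prompts.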
feature 1673
joy: mean= 21.21 fires 19/20
neutral: mean= 1.97 fires 3/20
https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/1673
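The summary for feature 1673 reports a mean activation and a firing count for each prompt set; firing on 19/20 joy prompts versus 3/20 neutral prompts is a more robust signal than the means alone. A minimal sketch of those statistics, assuming each prompt has already been reduced to one scalar activation (for instance the feature's value at the final token; the notebook's actual reduction is not shown) and that "fires" means strictly positive. Both helper names are hypothetical.

```python
from typing import List, Tuple


def firing_stats(per_prompt: List[float], eps: float = 0.0) -> Tuple[float, int, int]:
    """Return (mean activation, prompts where it exceeds eps, total prompts)."""
    mean = sum(per_prompt) / len(per_prompt)
    fires = sum(1 for a in per_prompt if a > eps)
    return mean, fires, len(per_prompt)


def format_stats(label: str, per_prompt: List[float]) -> str:
    mean, fires, n = firing_stats(per_prompt)
    # Mirrors the log layout "joy: mean= 21.21 fires 19/20".
    return f"{label}: mean={mean:6.2f} fires {fires}/{n}"


line = format_stats("joy", [0.0, 3.0, 5.0, 4.0])  # toy values, not notebook data
```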
Visualisation
Visualizing feature 1673
max act: 2040.0 | nonzero: 7842
max act: 2040.0 | nonzero: 8167
max act: 2040.0 | nonzero: 8086
max act: 2040.0 | nonzero: 7672
max act: 2040.0 | nonzero: 7986
--- Neutral controls ---
max act: 2040.0 | nonzero: 7676
max act: 2040.0 | nonzero: 7702
max act: 2040.0 | nonzero: 7743
max act: 2040.0 | nonzero: 7497
max act: 2040.0 | nonzero: 7816
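The visualization cell installs a package before plotting (its name is not visible in the log). As a dependency-free alternative, here is a hypothetical text renderer that marks each token where a single feature is active, in the form `[token|activation]`. The function name is an assumption, and the tokens and activations in the usage lines are made up for illustration.

```python
from typing import List


def render_token_acts(tokens: List[str], acts: List[float], threshold: float = 0.0) -> str:
    """Inline text view: wrap each token whose activation exceeds threshold as [token|act]."""
    parts = []
    for tok, act in zip(tokens, acts):
        parts.append(f"[{tok}|{act:.1f}]" if act > threshold else tok)
    return "".join(parts)


# Illustrative values only, not taken from the notebook.
tokens = ["<bos>", "I", " feel", " so", " happy"]
acts = [0.0, 0.0, 1.5, 0.0, 12.5]
print(render_token_acts(tokens, acts))
```

Raising `threshold` hides weak activations, which helps when a feature fires faintly on many tokens.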