SAELens

Norms of decoder weights

#2
by sjgerstner - opened

Hi! I did some exploratory analysis on the transcoders (specifically the ones with the lowest average L0 from each layer). I investigated the norms of the decoder weights: for each latent, I take the corresponding row of the decoder weight matrix and compute the L2 norm of that vector. In code: torch.linalg.vector_norm(sae.W_dec, dim=1).
I noticed that these norms are always very close to 1, the max distance being 2.3842e-07.
The most likely explanation seems to be that you normalised the weights after training (and that the slight variations are due to float precision issues), but I couldn't find this mentioned in the paper. So, did you normalise the weights, or do I have to look for another explanation?
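For anyone who wants to reproduce the check, here is a minimal sketch of the measurement described above. It uses a random matrix with explicitly normalised rows as a stand-in for a loaded transcoder's W_dec (loading a real checkpoint is omitted), so the shapes and variable names here are illustrative, not from SAELens itself:

```python
import torch

# Stand-in for a trained transcoder decoder: rows normalised to unit L2 norm,
# mimicking the property observed in the actual checkpoints.
torch.manual_seed(0)
W_dec = torch.randn(16, 64)
W_dec = W_dec / torch.linalg.vector_norm(W_dec, dim=1, keepdim=True)

# The check from the question: one norm per latent (row), then the
# maximum absolute deviation from 1.
norms = torch.linalg.vector_norm(W_dec, dim=1)
max_dev = (norms - 1.0).abs().max().item()
```

On a real checkpoint, max_dev comes out around 1e-7, i.e. float32 rounding scale.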

Hi @sjgerstner , apologies for the delayed response.
Thanks for the careful analysis. Your observation is correct and your interpretation is very close. The key detail, however, is that this is not a post-training normalisation step.

For the Gemma Scope transcoders, the decoder weight vectors are constrained to have unit L2 norm as part of the training procedure itself, rather than being normalised afterward. As described in the Gemma Scope technical documentation and accompanying materials, the transcoders are trained as sparse autoencoders with a reconstruction objective (MSE) combined with a sparsity penalty implemented via JumpReLU activations. In this setup, leaving the decoder weights unconstrained would introduce a well-known scaling degeneracy: the model could artificially shrink the encoder weights, which reduces latent activations and therefore lowers the sparsity penalty, while proportionally inflating the decoder weights to preserve reconstruction quality.

To prevent this degeneracy, each decoder dictionary vector (each row of W_dec, corresponding to a latent feature) is constrained to lie on the unit L2 sphere during training. Concretely, after each optimiser update, the decoder weight vectors are projected back to unit norm. Because this unit-norm constraint is enforced throughout optimisation, the learned decoder weights naturally end up with norms extremely close to 1. The maximum deviation you observed (approximately 2.38e-7) is entirely consistent with expected floating-point precision effects (float32 or bfloat16 rounding during training, checkpointing, or loading), rather than evidence of a separate normalisation pass after training.
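As a sketch of what that projection step looks like in practice (the function name, shapes, and stand-in loss below are my own illustration, not the actual Gemma Scope training code):

```python
import torch

def project_decoder_to_unit_norm(W_dec: torch.Tensor) -> torch.Tensor:
    # Renormalise each row (one dictionary vector per latent) to unit L2 norm.
    norms = torch.linalg.vector_norm(W_dec, dim=1, keepdim=True)
    return W_dec / norms.clamp_min(1e-8)

# One illustrative training step with the constraint enforced.
torch.manual_seed(0)
W_dec = torch.nn.Parameter(torch.randn(64, 32))
opt = torch.optim.SGD([W_dec], lr=1e-3)

loss = (W_dec ** 2).sum()   # stand-in for the real MSE + sparsity loss
loss.backward()
opt.step()

# Project back onto the unit sphere after the optimiser update, so the
# weights stay (numerically) unit-norm throughout training.
with torch.no_grad():
    W_dec.copy_(project_decoder_to_unit_norm(W_dec))
```

After this step, every row of W_dec has norm 1 up to float32 rounding, which matches the ~1e-7 deviations seen in the released checkpoints.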

So you do not need to look for an alternative explanation: what you are seeing is the intended geometric constraint of the training procedure, with tiny deviations attributable to numerical precision.

Thank you!

Thank you, srikanta. I skimmed your reply, and I believe I documented everything I did in the tech report.

See also the README of this HF repo on "Transcoder Input/Output Clarification"
