
[Question] Possible to retrieve layer-wise activations? #166

Closed · pat-alt opened this issue Jan 10, 2024 · 7 comments


pat-alt commented Jan 10, 2024

Thanks for the great package @chengchingwen 🙏🏽

I have a somewhat naive question that you might be able to help me with. For a project I'm currently working on, I'm trying to run linear probes on layer activations. In particular, I'm trying to reproduce the following exercise from this paper:

[image from the paper omitted]

I've naively tried simply applying the Flux.activations() function, with no luck. Here's an example:

using Flux
using Transformers
using Transformers.TextEncoders
using Transformers.HuggingFace

# Load model from HF 🤗:
tkr = hgf"mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis:tokenizer"
mod = hgf"mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis:ForSequenceClassification"
query = [
    "The economy is stagnant.",
    "Output has grown in H2.",
]
a = encode(tkr, query)
julia> Flux.activations(mod.model, a)
ERROR: MethodError: no method matching activations(::Transformers.HuggingFace.HGFRobertaModel{…}, ::@NamedTuple{token::OneHotArray{…}, attention_mask::NeuralAttentionlib.RevLengthMask{…}})

Closest candidates are:
  activations(!Matched::Flux.Chain, ::Any)
   @ Flux ~/.julia/packages/Flux/EHgZm/src/layers/basic.jl:102

Stacktrace:
 [1] top-level scope
   @ REPL[80]:1

Any advice would be much appreciated!

@chengchingwen (Owner)

There is an output_hidden_states configuration that can be set up with HGFConfig:

model_name = "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis"
cfg = HuggingFace.HGFConfig(load_config(model_name); output_hidden_states = true)
mod = load_model(model_name, "ForSequenceClassification"; config = cfg)

then you can access all layer outputs with mod(a).outputs, which is an NTuple{num_layers, @NamedTuple{hidden_state::Array{Float32, 3}}}. Another similar configuration is output_attentions, which would also include the attention scores in the named tuples in .outputs.

BTW, if you don't need the sequence classification head, you can simply use load_model(model_name; config = cfg) which would extract the model part without the classification layers.
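
Putting this together for the probing use case above, a minimal sketch might look like the following. It assumes the .outputs tuple is ordered from the first to the last transformer layer and that the hidden states use the (features, tokens, batch) layout; both are worth double-checking against your input.

using Transformers, Transformers.TextEncoders, Transformers.HuggingFace

model_name = "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis"
cfg = HuggingFace.HGFConfig(load_config(model_name); output_hidden_states = true)
tkr = load_tokenizer(model_name)
mod = load_model(model_name, "ForSequenceClassification"; config = cfg)

query = [
    "The economy is stagnant.",
    "Output has grown in H2.",
]
a = encode(tkr, query)
out = mod(a)

# One hidden-state array per transformer layer (assumed order: first to last),
# each of size (hidden_dim, sequence_length, batch_size):
layerwise = [o.hidden_state for o in out.outputs]

Each element of layerwise can then be flattened and fed to a linear probe.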


pat-alt commented Jan 12, 2024

Amazing, thanks very much for the quick response 👍🏽

(I won't close this since you added the documentation tag.)


pat-alt commented Jan 19, 2024

Small follow-up question: is it also somehow possible to collect outputs for each layer of the classifier head?

Edit: I realize I can just break down the forward pass into layer-by-layer calls as below, but perhaps there's a more streamlined way to do this?

# `clf` is the classification head and `b` the encoder output:
b = clf.layer.layers[1](b).hidden_state |>
        x -> clf.layer.layers[2](x)

@chengchingwen (Owner)

You can try extracting the actual layers in the classifier head, constructing a Flux.Chain from them, and calling Flux.activations on that. Otherwise, I think manual layer-by-layer calls are probably the simplest.
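
For instance, a minimal sketch of the manual approach, reusing clf and b from the snippet above (it assumes the first layer of the head returns a named tuple with a hidden_state field and the remaining layers take and return plain arrays):

# Collect the intermediate output of every layer in the classification head.
head_acts = Any[]
x = clf.layer.layers[1](b).hidden_state
push!(head_acts, x)
for layer in clf.layer.layers[2:end]
    x = layer(x)
    push!(head_acts, x)
end

The Flux.Chain route works the same way once the underlying layers are pulled out of the head, since Flux.activations only has a method for Flux.Chain.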


VarLad commented Feb 20, 2025

@chengchingwen I was trying out the following code:

using Transformers, Transformers.TextEncoders, Transformers.HuggingFace
bert_config = HuggingFace.HGFConfig(load_config("bert-base-uncased"); output_attentions = true)
# Load BERT model and tokenizer
bert_model = load_model("bert-base-uncased"; config=bert_config)
bert_tokenizer = load_tokenizer("bert-base-uncased")

text = [["The cat sat on the mat", "The cat lay on the rug"]]
sample = encode(bert_tokenizer, text)

bert_model(sample).attention_score

This returns a 15×15×12×2 Array{Float32, 4}.

Is that correct? I'm not sure what the 2 means here; I was expecting a 12.

For comparison, the equivalent Python code:

from transformers import BertTokenizer, BertModel

model_version = 'bert-base-uncased'
model = BertModel.from_pretrained(model_version, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained(model_version)
sentence_a = "The cat sat on the mat"
sentence_b = "The cat lay on the rug"
inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt')
input_ids = inputs['input_ids']
token_type_ids = inputs['token_type_ids']
attention = model(input_ids, token_type_ids=token_type_ids)[-1]
print(len(attention), attention[0].shape)

returns:

12 torch.Size([1, 12, 15, 15])

which makes more sense: 12 layers, 12 heads per layer, 15 tokens × 15 tokens.

Would you know what the issue is here?

@chengchingwen (Owner)

@VarLad The output structure is slightly different: bert_model(sample).attention_score is the attention score of the last transformer layer only. What you want is in bert_model(sample).outputs, which should be an NTuple{12, @NamedTuple{hidden_state::Array{Float32, 3}, attention_score::Array{Float32, 4}}}.

OTOH, I'm not sure why you get a 15×15×12×2 Array{Float32, 4}. With your code above, I get:

julia> size(bert_model(sample).outputs[1].attention_score)
(15, 15, 12, 1)

julia> size(bert_model(sample).attention_score)
(15, 15, 12, 1)
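
For completeness, collecting the per-layer scores (the analogue of the attention tuple in the Python snippet) could look like the sketch below; it assumes the .outputs ordering matches the layer order and that the axes are (key length, query length, heads, batch).

out = bert_model(sample)

# One attention-score array per transformer layer; for this input each is
# of size (15, 15, 12, 1).
attn_per_layer = [o.attention_score for o in out.outputs]

length(attn_per_layer)    # 12, one entry per layer
size(attn_per_layer[1])   # (15, 15, 12, 1)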


VarLad commented Feb 21, 2025

@chengchingwen Apologies for the confusion; the 2 came from a typo on my side:

text = ["The cat sat on the mat", "The cat lay on the rug"]

and thanks a lot, that was the correct solution :)
