71 changes: 71 additions & 0 deletions graph_net/test/hack_llm_extractor.py
@@ -0,0 +1,71 @@
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig
import graph_net.torch
from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.integrations.executorch import sdpa_mask_without_vmap
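
# Extract the compute graph of a Hugging Face LLM via graph_net.torch.extract.
# The script works around two tracing obstacles: vmap-based attention masks
# and torch.autocast contexts inside the model's forward pass.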


def get_model_name():
    return "Qwen/Qwen2.5-3B"


def create_model(device):
    config = AutoConfig.from_pretrained(get_model_name())
    # from_config builds a randomly initialized model from the architecture
    # alone, so no checkpoint download is needed for graph extraction.
    model = AutoModel.from_config(config)
    # Register a vmap-free SDPA mask so the attention path stays traceable; see
    # https://github.com/huggingface/transformers/blob/6b5bd117231f969713ed79fd4870903ab3c93edf/docs/source/en/attention_interface.md?plain=1#L194-L195
    ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
    ALL_ATTENTION_FUNCTIONS.register(
        "sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]
    )
    model.config._attn_implementation = "sdpa_without_vmap"
    model.eval()
    return model.to(device)


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(get_model_name())

    text = "Hello world"
    inputs = tokenizer(text, return_tensors="pt")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {k: v.to(device) for k, v in inputs.items()}

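    # DummyAutocast is a no-op stand-in used below to disable torch.autocast
    # during forward, presumably so mixed-precision contexts do not interfere
    # with graph extraction.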
    class DummyAutocast:
        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass

    class PatchedModel(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, *args, **kwargs):
            # Monkey-patch every autocast entry point with the no-op context
            # manager for the duration of this forward call, restoring the
            # originals afterwards.
            original_autocast1 = torch.cuda.amp.autocast
            original_autocast2 = torch.amp.autocast
            original_autocast3 = torch.autocast
            try:
                torch.cuda.amp.autocast = DummyAutocast
                torch.amp.autocast = DummyAutocast
                torch.autocast = DummyAutocast
                return self.model.forward(*args, **kwargs)
            finally:
                torch.cuda.amp.autocast = original_autocast1
                torch.amp.autocast = original_autocast2
                torch.autocast = original_autocast3

    model = create_model(device)
    model = PatchedModel(model)
    # graph_net.torch.extract wraps the module so its compute graph is
    # captured when it runs.
    model = graph_net.torch.extract(name=get_model_name(), dynamic=False)(model)

    print("Running inference...")
    output = model(**inputs)
    print("Inference finished. Output shape:", output.last_hidden_state.shape)
@@ -0,0 +1 @@
70d4001ac71bdeb9de5fe851688032c3ba4cd2b9abf22950fc2061c98358408b
@@ -0,0 +1,6 @@
{
    "framework": "torch",
    "num_devices_required": 1,
    "num_nodes_required": 1,
    "dynamic": false
}
Empty file.
5,275 changes: 5,275 additions & 0 deletions samples/transformers-auto-model/Qwen/Qwen2.5-0.5B/model.py

Large diffs are not rendered by default.
