error from generate_from_string method for BART #17
Thoughts, @shmsw25?
After a bit more digging, it appears that the arguments to the `forward` method in `modeling_bart.py` in transformers 4.4.2 are quite different from the arguments passed to the `forward` method in UnifiedQA's `bart.py`. I think I may need to update `bart.py` to match the latest `modeling_bart.py` to make this work. If I manage to do so, would you like a copy of the updated version?
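To illustrate the mechanism, here is a sketch with made-up stand-in classes (not the actual transformers or UnifiedQA code): if a subclass overrides `forward` with an older signature, newer base-class generation code that calls `forward(..., past_key_values=...)` fails with exactly this kind of `TypeError`. The class names, the `generate_step` method and the `old_cache` parameter are hypothetical; `inspect.signature` is used to check up front whether a `forward` accepts a given keyword.

```python
import inspect


class NewBase:
    """Hypothetical stand-in for a 4.x-style model class."""

    def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=None):
        return "ok"

    def generate_step(self, input_ids):
        # Newer generation code passes the cache under the name past_key_values.
        return self.forward(input_ids, past_key_values=None, use_cache=True)


class OldDerived(NewBase):
    """Hypothetical subclass whose forward() was written against an older API."""

    def forward(self, input_ids, attention_mask=None, old_cache=None, use_cache=False):
        return "ok"


def accepts_kwarg(fn, name):
    """Return True if fn can be called with keyword argument `name`."""
    params = inspect.signature(fn).parameters
    if name in params:
        return True
    # A **kwargs parameter also accepts arbitrary keyword arguments.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


if __name__ == "__main__":
    print(accepts_kwarg(NewBase.forward, "past_key_values"))     # True
    print(accepts_kwarg(OldDerived.forward, "past_key_values"))  # False
    try:
        OldDerived().generate_step([0])
    except TypeError as err:
        print("TypeError:", err)
```

The subclass inherits `generate_step` unchanged, but Python dispatches `self.forward` to the subclass override, which is why the call breaks only in the derived class.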
Sure, thank you! 🙏
I forked your code and updated `bart.py` and `run.py`. I've run it a few times and it seems to work. I've generally marked my changes with comments starting with `#TJH`. You can access it at: https://github.com/timhartill/unifiedqa-tjh
Appreciate it! Will look into your changes.
transformers 4.x brought breaking changes, but this shouldn't be an issue if you use Hugging Face's model class (`BartForConditionalGeneration`) rather than the derived class in this repo. Here is an example of how generation would look:

```python
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

base_model = "facebook/bart-large"
unifiedqa_path = "unifiedQA-uncased/best-model.pt"  # path to the downloaded checkpoint

tokenizer = BartTokenizer.from_pretrained(base_model)
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only machine
model = BartForConditionalGeneration.from_pretrained(
    base_model, state_dict=torch.load(unifiedqa_path, map_location="cpu")
)
model.eval()

def generate_text(text, model, tokenizer):
    inputs = tokenizer([text], max_length=512, truncation=True, return_tensors="pt")
    # Beam search decoding; the attention mask is passed along with the input ids
    output_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        num_beams=4,
        max_length=50,
        early_stopping=True,
    )
    return " ".join(
        tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for g in output_ids
    )

for text in [
    "Which is best conductor? \\n (A) iron (B) feather",
    "What is the sum of 3 and 5? \\n (A) 8 (B) 3 (C) 5 (D) 10",
    "What is 42 ? \\n 42 is the answer to life, the universe and everything",
]:
    print("\n", text)
    print("generated_text:", generate_text(text, model, tokenizer))
```

Or one could also use HF's pipelines as follows:

```python
# Using a pipeline
from transformers import pipeline

text = "What is 42 ? \\n 42 is the answer to life, the universe and everything"
print("\n", text)
text2text_generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
print(text2text_generator(text))
```
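The example inputs above all follow one pattern: the question, a literal `\n` separator (written `\\n` inside the Python strings), then either lettered options or a context paragraph. A small helper can build such strings. `format_unifiedqa_input` is a hypothetical name, and the format is inferred only from the examples in this thread, so it is worth checking against the UnifiedQA README before relying on it.

```python
import string


def format_unifiedqa_input(question, options=None, context=None):
    """Build a UnifiedQA-style input string (format inferred from the examples above).

    The separator is a literal backslash followed by 'n', not a newline character.
    If options are given they take priority over context.
    """
    sep = " \\n "
    if options:
        # Label options (A), (B), (C), ... in order.
        labeled = " ".join(
            f"({label}) {opt}" for label, opt in zip(string.ascii_uppercase, options)
        )
        return question + sep + labeled
    if context:
        return question + sep + context
    return question


if __name__ == "__main__":
    print(format_unifiedqa_input("Which is best conductor?", options=["iron", "feather"]))
    # Which is best conductor? \n (A) iron (B) feather
```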
Hi @timhartill, @tshrjn and @danyaljj, I think keeping the version in the repo as it is would be better, since the HF library will keep being updated and it would not be easy to update the code every time while guaranteeing that the numbers in the paper are reproduced. Alternatively, we could update only the inference code and add a note that fine-tuning is only tested with the version specified in the README. What do you think?
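One way to enforce the "inference on newer versions, fine-tuning only on the tested one" note would be a small guard called at startup. This is only a sketch: `check_transformers_version` is a hypothetical helper, and the tested version must be filled in from the README (it is not given in this thread).

```python
import warnings


def check_transformers_version(installed, tested, finetuning):
    """Raise for fine-tuning on an untested transformers version; only warn for inference.

    `installed` and `tested` are version strings such as "4.4.2".
    """
    if installed == tested:
        return
    msg = (
        f"transformers {installed} is installed, but this code was tested with {tested}; "
        "results may not reproduce the numbers in the paper."
    )
    if finetuning:
        raise RuntimeError(msg + " Fine-tuning is only supported with the tested version.")
    warnings.warn(msg)


# Usage (hypothetical; TESTED_VERSION would come from the README):
# from importlib.metadata import version
# check_transformers_version(version("transformers"), TESTED_VERSION, finetuning=True)
```

Exact string matching is deliberately strict; a looser policy (e.g. same major version) would be a one-line change but gives weaker reproducibility guarantees.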
Hi, I'm attempting to run the BART model example as given in the readme:
```python
import torch
from transformers import BartTokenizer
from bart import MyBart

base_model = "facebook/bart-large"
unifiedqa_path = "unifiedQA-uncased/best-model.pt"  # path to the downloaded checkpoint
tokenizer = BartTokenizer.from_pretrained(base_model)
model = MyBart.from_pretrained(base_model, state_dict=torch.load(unifiedqa_path))
model.eval()
x = model.generate_from_string("Which is best conductor? \n (A) iron (B) feather", tokenizer=tokenizer)
```
The `.from_pretrained` line executes fine, but the `.generate_from_string(..)` line fails with:

```
TypeError: forward() got an unexpected keyword argument 'past_key_values'
```
I tried using the `run_model(..)` method from the main GitHub page and it gives exactly the same error.
Any idea what might be causing this and how to fix it?
I am using Python 3.8.5 with transformers 4.4.2 and PyTorch 1.7.1.
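For version-related errors like this one, pasting the exact installed versions helps. A minimal sketch using only the standard library (`importlib.metadata`, Python 3.8+); `environment_report` is a hypothetical helper name, and packages that are not installed are reported as such instead of raising.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def environment_report(packages=("transformers", "torch")):
    """Return a dict mapping 'python' and each package name to its installed version."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for name, ver in environment_report().items():
        print(f"{name}: {ver}")
```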