Much slower for inference, even when traced? #1477
Comments
Can you fix this sentence? It seems some error slipped in there.
As far as I know, you don't have to pad up to the max sequence length manually, and you can just pad up to the max sequence length per batch. That might save you some time.
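(A minimal sketch of what per-batch padding looks like with the `transformers` tokenizer; this uses the newer `__call__`/`padding='longest'` API, which may be more recent than the version discussed in this issue, and the model name is just an example.)

```python
import transformers

tokenizer = transformers.RobertaTokenizer.from_pretrained('roberta-large-mnli')

premises = ['All work and no play.'] * 8
hypotheses = ['Makes Jack a dull boy.'] * 8

# pad only to the longest sequence in this batch, not to the model's 512-token maximum
batch = tokenizer(premises, hypotheses,
                  padding='longest',
                  truncation=True,
                  max_length=512,
                  return_tensors='pt')

print(batch['input_ids'].shape)  # (8, longest sequence in this batch), not (8, 512)
```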
Yeah, sorry, I meant that decreasing the max-seq-len increases performance a lot. Good point, I should definitely pad up to the max length per batch, although I am not sure this will make a huge difference as most of my inputs are of similar length and close to the max. I guess before I dive deeper I'm looking for a starting place for an investigation of why, say, the implementation of RoBERTa here https://github.com/pytorch/fairseq/tree/master/examples/roberta would be 2x faster on the same GPU than the implementation in transformers. Does transformers make a conscious performance sacrifice in the name of modularity and extensibility? Or are there specific optimizations in fairseq (for example) that I am observing and that have not been ported? Would updating to the new PyTorch modules from 1.12 discussed in #1451 make a difference? It seems like there can be performance improvements from fusing kernels, so PyTorch needs to launch fewer of them to run the same model, although I do not fully understand this: https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
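(As a toy illustration of the kernel-fusion idea from that blog post, not anything specific to the transformers codebase: the TorchScript JIT can compile a chain of pointwise operations into a single kernel, so the GPU launches one kernel instead of several.)

```python
import torch

@torch.jit.script
def gelu_fused(x):
    # a chain of pointwise ops; the TorchScript fuser can combine these into
    # a single CUDA kernel instead of launching one kernel per operation
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(8, 512, 1024, device=device)
y = gelu_fused(x)
```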
I am not sure about any of this, but I do remember that the PyTorch developers make an effort to achieve as much parity as possible between CPU and CUDA, with specific optimisations for both. As an example, their C++ implementation of the activation functions does specific things depending on whether MKL is available. I'm not sure whether
@pertschuk these benchmarks usually depend mostly on things like data processing, the selected float precision, and the specific inference code (are you in a `torch.no_grad()` context?). If you have a (not too big) codebase for benchmarking and clear numbers, we can have a look.
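(As a rough, hypothetical checklist of those knobs; the model and tokenizer names are placeholders and the exact API depends on your transformers version.)

```python
import torch
import transformers

model = transformers.RobertaForSequenceClassification.from_pretrained('roberta-large-mnli')
tokenizer = transformers.RobertaTokenizer.from_pretrained('roberta-large-mnli')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device).eval()          # eval mode: disables dropout
if device.type == 'cuda':
    model.half()                 # fp16 weights, usually a large speedup on T4/V100

batch = tokenizer(['All work and no play.'], ['Makes Jack a dull boy.'],
                  padding='longest', return_tensors='pt').to(device)

with torch.no_grad():            # no autograd bookkeeping during inference
    logits = model(**batch)[0]
```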
I cleaned and consolidated my code, with dynamic padding to the current batch size and a `torch.no_grad()` context. Output is below. It seems like the native fairseq / torch.hub implementation is a little less than 2x as fast as transformers.

```python
import time

import torch
import torch.hub
import torch.nn.functional as F
import transformers
from fairseq.data.data_utils import collate_tokens

MAX_LENGTH = 512
PAD = True


def benchmark_mnli(samples):
    torch_hub_model = time_fn(torch.hub.load, 'pytorch/fairseq', 'roberta.large.mnli')
    torch_hub_model.eval()
    torch_hub_model.cuda()

    try:
        transformers_model = time_fn(transformers.RobertaModel.from_pretrained,
                                     'roberta-large-mnli')
    except Exception:
        transformers_model = time_fn(transformers.RobertaModel.from_pretrained,
                                     'roberta-large-mnli', force_download=True)
    transformers_tokenizer = time_fn(transformers.RobertaTokenizer.from_pretrained,
                                     'roberta-large-mnli')

    pred_functions = {
        'transformers': predict_transformers(transformers_model, transformers_tokenizer),
        'torch_hub': predict_roberta(torch_hub_model),
    }

    for framework, pred_fn in pred_functions.items():
        print(f'Benchmarking {framework} with {samples} samples')
        time_fn(benchmark, pred_fn, samples)


def predict_transformers(model, tokenizer):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    def predict_fn(*args):
        inputs = time_fn(transformers_encode_batch, tokenizer, *args)
        inputs_dict = {
            'input_ids': torch.tensor(inputs[0], dtype=torch.long).to(device),
            'attention_mask': torch.tensor(inputs[1], dtype=torch.long).to(device),
            # 'token_type_ids': torch.tensor(inputs[2], dtype=torch.long)
        }
        outputs = model(**inputs_dict)
        logits = outputs[0]
        preds = F.log_softmax(logits, dim=-1)
        return preds.tolist()

    return predict_fn


def predict_roberta(model):
    def pred_fn(*args):
        batch = time_fn(collate_tokens,
                        [model.encode(*arg)[:MAX_LENGTH] for arg in zip(*args)],
                        pad_idx=1)
        labels = model.predict('mnli', batch).tolist()
        return labels
    return pred_fn


def benchmark(pred_fn, n):
    args = ['All work and no play.'] * 8, ['Make jack a very dull boy.'] * 8
    for _ in range(n):
        assert type(pred_fn(*args)) == list


### HELPERS

def time_fn(fn, *args, **kwargs):
    start = time.time()
    res = fn(*args, **kwargs)
    print(f'Took {time.time() - start} seconds to run {fn.__name__}')
    return res


def transformer_to_features(tokenizer, *args):
    inputs = tokenizer.encode_plus(
        *args,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        truncate_first_sequence=True
    )
    input_ids = inputs["input_ids"][:MAX_LENGTH]
    return input_ids


def pad_up(input_ids, max_length):
    padding_length = max_length - len(input_ids)
    # build the attention mask from the unpadded length, then left-pad the ids,
    # so both lists end up with exactly max_length entries
    attention_mask = ([0] * padding_length) + ([1] * len(input_ids))
    input_ids = ([0] * padding_length) + input_ids
    return input_ids, attention_mask


def transformers_encode_batch(tokenizer, *args):
    assert type(args[0]) == list
    all_input_ids = []
    max_batch_len = 0
    for sample in zip(*args):
        input_ids = transformer_to_features(tokenizer, *sample)
        all_input_ids.append(input_ids)
        max_batch_len = max(max_batch_len, len(input_ids))

    all_input_ids, all_attention_masks = zip(*[
        pad_up(input_ids, max_batch_len) for input_ids in all_input_ids
    ])
    return all_input_ids, all_attention_masks


if __name__ == '__main__':
    with torch.no_grad():
        benchmark_mnli(10)
```

Here is the output:
Or with a longer sample input:
I benchmarked the traced transformer model and it's about the same.
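(For context, tracing a transformers classification model generally looks something like the sketch below; this is a generic illustration with placeholder shapes and model name, not the exact code behind the numbers above.)

```python
import torch
import transformers

# torchscript=True configures the model so it can be traced
model = transformers.RobertaForSequenceClassification.from_pretrained(
    'roberta-large-mnli', torchscript=True)
model.eval()

# dummy inputs fix the shapes the trace is specialized for
input_ids = torch.ones(8, 128, dtype=torch.long)
attention_mask = torch.ones(8, 128, dtype=torch.long)

traced = torch.jit.trace(model, (input_ids, attention_mask))

with torch.no_grad():
    logits = traced(input_ids, attention_mask)[0]
```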
You closed this, but I'm curious to hear about your results and thoughts. So the fairseq/hub implementation is twice as fast as the transformers implementation? Do you have any intuition as to why?
❓ Questions & Help
When running inference with BERT-large on a T4 GPU using bert-as-a-service, I could get well over 100 sentence pairs/s on sentence-pair classification. (I am aware that this relied on TF's graph freezing and pruning.)
When running inference with RoBERTa-large on a T4 GPU using native PyTorch and fairseq, I was able to get 70-80 pairs/s.
Even with TorchScript JIT tracing, I am still only able to get 17/s on a T4 using the transformers implementation of BERT-large, with a batch size of 8 (which fills most of the memory).
Training performance is similarly worse (about 40%-100% longer, even with apex now versus no apex before).
One of the primary differences I can think of is that I am now padding everything up to the max sequence length, and decreasing that length increases performance a lot. Is there a way to not pad in transformers, and instead just pass in a list of dynamically sized PyTorch tensors?
Should I try the tensorflow implementations?
Thank you!