
Much slower for inference, even when traced? #1477

Closed
pertschuk opened this issue Oct 10, 2019 · 7 comments

@pertschuk commented Oct 10, 2019

❓ Questions & Help

When running inference with BERT-large on a T4 GPU via bert-as-a-service, I could get well over 100 sentence pairs/s on sentence-pair classification. (I am aware that this relies on TF's graph freezing and pruning.)

When running inference with RoBERTa-large on a T4 GPU using native PyTorch and fairseq, I was able to get 70-80 pairs/s.

Even with TorchScript JIT tracing, I am still only able to get 17/s on a T4 using the transformers implementation of BERT-large, with a batch size of 8 (which fills most of the memory).

Training performance is similarly worse (roughly 40%-100% longer, even using apex now versus no apex before).

One of the primary differences I can think of is that I am now padding everything up to the max sequence length, and decreasing that length does increase throughput a lot. Is there a way to avoid padding in transformers and just pass in a list of dynamically sized PyTorch tensors?

Should I try the TensorFlow implementations?

Thank you!

@BramVanroy (Collaborator)

Can you fix this sentence? It seems an error slipped in there:

One of the primary differences I can think of is that now I am padding all up to max-seq length, and it does increase performance a lot decrease this.

As far as I know, you don't have to pad up to the max sequence length manually; you can just pad up to the longest sequence in each batch. That might save you some time.
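Something along these lines, roughly (untested sketch; texts_a/texts_b and the tokenizer are placeholders):

import torch

def encode_batch_dynamic(tokenizer, texts_a, texts_b, max_length=512):
    # Encode each pair, truncating to the model's hard limit.
    encoded = [tokenizer.encode(a, b, add_special_tokens=True)[:max_length]
               for a, b in zip(texts_a, texts_b)]
    # Pad only to the longest sequence in this batch, not to max_length.
    batch_max = max(len(ids) for ids in encoded)
    pad_id = tokenizer.pad_token_id  # 1 for RoBERTa, 0 for BERT
    input_ids = [ids + [pad_id] * (batch_max - len(ids)) for ids in encoded]
    attention_mask = [[1] * len(ids) + [0] * (batch_max - len(ids)) for ids in encoded]
    return (torch.tensor(input_ids, dtype=torch.long),
            torch.tensor(attention_mask, dtype=torch.long))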

@pertschuk (Author) commented Oct 10, 2019

Yeah, sorry: I meant that decreasing the max-seq-len increases performance a lot.

Good point. I should definitely pad up to the max length per batch, although I am not sure it will make a huge difference, since most of my inputs are of similar length and close to the max.

I guess before I dive deeper I'm looking for a starting point for investigating why, say, the RoBERTa implementation here https://github.com/pytorch/fairseq/tree/master/examples/roberta would be 2x faster on the same GPU than the implementation in transformers.

Does transformers make a conscious performance sacrifice in the name of modularity and extensibility? Or are there specific optimizations in fairseq (for example) that explain what I am observing and that have not been ported?

Would using the new PyTorch modules from 1.2 discussed in #1451 make a difference? (It seems kernel fusion can bring performance improvements by letting PyTorch run the same model with fewer kernels, although I do not fully understand this: https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/.)
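For context, the TorchScript tracing I mentioned above is roughly the following (not my exact benchmark code, just a sketch; the model name and shapes are examples):

import torch
import transformers

# Load the model with torchscript=True so it can be traced.
model = transformers.BertModel.from_pretrained('bert-large-uncased', torchscript=True)
model.eval().cuda()

# Dummy inputs with the shapes used at inference time (batch of 8, 512 tokens).
dummy_ids = torch.ones(8, 512, dtype=torch.long).cuda()
dummy_mask = torch.ones(8, 512, dtype=torch.long).cuda()

with torch.no_grad():
    traced_model = torch.jit.trace(model, (dummy_ids, dummy_mask))
    outputs = traced_model(dummy_ids, dummy_mask)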

@BramVanroy (Collaborator)

I am not sure about any of this, but I do remember that the PyTorch developers make an effort to keep as much parity as possible between CPU and CUDA, with specific optimisations for both. As an example, their C++ implementation of the activation functions does different things depending on whether MKL is available. I'm not sure whether nn.Transformer and nn.MultiheadAttention were optimised as intensively as well.

@thomwolf (Member)

@pertschuk these benchmarks usually depend mostly on things like data processing, the selected float precision, and the specific inference code (are you in a torch.no_grad context, for instance), i.e. basically everything that sits outside the models themselves (whose computational graphs are pretty much identical across frameworks).

If you have a (not too big) codebase for benchmarking and clear numbers, we can have a look.
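For reference, the kind of inference setup I mean looks roughly like this (just a sketch; the model name is only an example):

import torch
import transformers

model = transformers.BertForSequenceClassification.from_pretrained('bert-large-uncased')
model.eval().cuda()
model.half()  # optional fp16 inference; whether it helps depends on the GPU

input_ids = torch.ones(8, 128, dtype=torch.long).cuda()
attention_mask = torch.ones(8, 128, dtype=torch.long).cuda()

# Disabling gradient tracking is usually the single biggest factor.
with torch.no_grad():
    logits = model(input_ids, attention_mask=attention_mask)[0]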

@pertschuk

This comment has been minimized.

@pertschuk reopened this Oct 10, 2019

@pertschuk (Author) commented Oct 10, 2019

I cleaned and consolidated my code, with dynamic padding to the current batch's max length and a torch.no_grad() context. Output is below. It seems like the native fairseq/torch.hub implementation is a little less than 2x as fast as transformers.

import time

import torch
import torch.hub
import torch.nn.functional as F
import transformers
from fairseq.data.data_utils import collate_tokens

MAX_LENGTH = 512
PAD = True


def benchmark_mnli(samples):
    torch_hub_model = time_fn(torch.hub.load, 'pytorch/fairseq','roberta.large.mnli')
    torch_hub_model.eval()
    torch_hub_model.cuda()
    try:
        # Note: transformers.RobertaModel is the base encoder without the MNLI
        # classification head; RobertaForSequenceClassification would be the
        # closer match for fairseq's model.predict('mnli', ...).
        transformers_model = time_fn(transformers.RobertaModel.from_pretrained,
                                     'roberta-large-mnli')
    except Exception:
        transformers_model = time_fn(transformers.RobertaModel.from_pretrained,
                                     'roberta-large-mnli', force_download=True)
    transformers_tokenizer = time_fn(transformers.RobertaTokenizer.from_pretrained, 'roberta-large-mnli')
    pred_functions = {
        'transformers' : predict_transformers(transformers_model, transformers_tokenizer),
        'torch_hub' : predict_roberta(torch_hub_model)
    }
    for framework, pred_fn in pred_functions.items():
        print(f'Benchmarking {framework} with {samples} samples')
        time_fn(benchmark, pred_fn, samples)


def predict_transformers(model, tokenizer):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    def predict_fn(*args):
        inputs = time_fn(transformers_encode_batch, tokenizer, *args)
        inputs_dict = {
            'input_ids': torch.tensor(inputs[0],  dtype=torch.long).to(device),
            'attention_mask': torch.tensor(inputs[1],  dtype=torch.long).to(device),
            # 'token_type_ids': torch.tensor(inputs[2],  dtype=torch.long)
        }
        outputs = model(**inputs_dict)
        logits = outputs[0]
        preds = F.log_softmax(logits, dim=-1)
        return preds.tolist()
    return predict_fn


def predict_roberta(model):
    def pred_fn(*args):
        batch = time_fn(collate_tokens, [model.encode(*arg)[:MAX_LENGTH] for arg in zip(*args)], pad_idx=1)
        labels = model.predict('mnli', batch).tolist()
        return labels
    return pred_fn


def benchmark(pred_fn, n):
    args = ['All work and no play.'] * 8, ['Make jack a very dull boy.'] * 8
    for i in range(0, n):
        assert(type(pred_fn(*args)) == list)

### HELPERS

def time_fn(fn, *args, **kwargs):
    start = time.time()
    res = fn(*args, **kwargs)
    print(f'Took {time.time() - start} seconds to run {fn.__name__}')
    return res


def transformer_to_features(tokenizer, *args):
    inputs = tokenizer.encode_plus(
        *args,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        truncate_first_sequence=True
    )
    input_ids = inputs["input_ids"][:MAX_LENGTH]

    return input_ids

def pad_up(input_ids, max_length):
    # Left-pad with RoBERTa's pad token id (1) and build the attention mask
    # from the unpadded length, so both outputs end up with length max_length.
    padding_length = max_length - len(input_ids)
    attention_mask = ([0] * padding_length) + ([1] * len(input_ids))
    input_ids = ([1] * padding_length) + input_ids
    return (input_ids, attention_mask)


def transformers_encode_batch(tokenizer, *args):
    assert(type(args[0]) == list)
    all_input_ids = []
    max_batch_len = 0

    for sample in zip(*args):
        input_ids = transformer_to_features(tokenizer, *sample)
        all_input_ids.append(input_ids)
        max_batch_len = max(max_batch_len, len(input_ids))

    all_input_ids, all_attention_masks = zip(*[
        pad_up(input_ids, max_batch_len) for input_ids in all_input_ids
    ])
    return all_input_ids, all_attention_masks


if __name__ == '__main__':
    with torch.no_grad():
        benchmark_mnli(10)

Here is the output:

Took 11.221294641494751 seconds to run load
Took 10.316125392913818 seconds to run from_pretrained
Took 0.3631258010864258 seconds to run from_pretrained
Benchmarking transformers with 10 samples
Took 0.00434112548828125 seconds to run transformers_encode_batch
Took 0.0039653778076171875 seconds to run transformers_encode_batch
Took 0.003747701644897461 seconds to run transformers_encode_batch
Took 0.0035974979400634766 seconds to run transformers_encode_batch
Took 0.0037157535552978516 seconds to run transformers_encode_batch
Took 0.003725767135620117 seconds to run transformers_encode_batch
Took 0.0038688182830810547 seconds to run transformers_encode_batch
Took 0.004169464111328125 seconds to run transformers_encode_batch
Took 0.003767728805541992 seconds to run transformers_encode_batch
Took 0.003550291061401367 seconds to run transformers_encode_batch
Took 0.7687280178070068 seconds to run benchmark
Benchmarking torch_hub with 10 samples
Took 0.0001957416534423828 seconds to run collate_tokens
Took 8.797645568847656e-05 seconds to run collate_tokens
Took 6.890296936035156e-05 seconds to run collate_tokens
Took 6.961822509765625e-05 seconds to run collate_tokens
Took 6.914138793945312e-05 seconds to run collate_tokens
Took 6.961822509765625e-05 seconds to run collate_tokens
Took 7.05718994140625e-05 seconds to run collate_tokens
Took 9.202957153320312e-05 seconds to run collate_tokens
Took 6.961822509765625e-05 seconds to run collate_tokens
Took 7.700920104980469e-05 seconds to run collate_tokens
Took 0.4018120765686035 seconds to run benchmark

Or with a longer sample input:

Took 10.34562063217163 seconds to run load
Took 10.523965835571289 seconds to run from_pretrained
Took 0.4653303623199463 seconds to run from_pretrained
Benchmarking transformers with 10 samples
Took 0.007193565368652344 seconds to run transformers_encode_batch
Took 0.005567789077758789 seconds to run transformers_encode_batch
Took 0.005621671676635742 seconds to run transformers_encode_batch
Took 0.006003141403198242 seconds to run transformers_encode_batch
Took 0.0061550140380859375 seconds to run transformers_encode_batch
Took 0.005508899688720703 seconds to run transformers_encode_batch
Took 0.005594730377197266 seconds to run transformers_encode_batch
Took 0.005545854568481445 seconds to run transformers_encode_batch
Took 0.005563259124755859 seconds to run transformers_encode_batch
Took 0.0059223175048828125 seconds to run transformers_encode_batch
Took 1.5394785404205322 seconds to run benchmark
Benchmarking torch_hub with 10 samples
Took 0.0001571178436279297 seconds to run collate_tokens
Took 9.131431579589844e-05 seconds to run collate_tokens
Took 9.322166442871094e-05 seconds to run collate_tokens
Took 8.7738037109375e-05 seconds to run collate_tokens
Took 8.726119995117188e-05 seconds to run collate_tokens
Took 8.726119995117188e-05 seconds to run collate_tokens
Took 8.869171142578125e-05 seconds to run collate_tokens
Took 8.96453857421875e-05 seconds to run collate_tokens
Took 8.58306884765625e-05 seconds to run collate_tokens
Took 8.869171142578125e-05 seconds to run collate_tokens
Took 0.9851493835449219 seconds to run benchmark

I benchmarked the traced transformer model and it's about the same.
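One note on the timings: the benchmark totals above include the .tolist() calls, which copy results back to the CPU and therefore synchronize the GPU, so they should cover the full forward pass. Timing the model call in isolation would need explicit synchronization, roughly like this (sketch):

import time
import torch

def time_fn_sync(fn, *args, **kwargs):
    # CUDA kernels run asynchronously, so synchronize before and after
    # to make sure the measured interval covers the actual GPU work.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    res = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print(f'Took {time.time() - start:.4f} seconds to run {fn.__name__}')
    return res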

@BramVanroy (Collaborator)

You closed this, but I'm curious to hear about your results and thoughts. So the fairseq/hub implementation is twice as fast as the transformers implementation? Do you have any intuition as to why?
