
INT64 clamping to INT32 creates overhead while using TensorRT #627

Closed

JingyaHuang opened this issue Dec 21, 2022 · 2 comments · Fixed by #655
JingyaHuang commented Dec 21, 2022

System Info

optimum: dev
CUDA: 11.3
cuDNN: 8.3.2
TensorRT: 8.4.1.5

Who can help?

@JingyaHuang

Reproduction

from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer

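# Export the model to ONNX on the fly and load it with the TensorRT execution provider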
ort_model = ORTModelForSequenceClassification.from_pretrained(
    "philschmid/tiny-bert-sst2-distilled",
    from_transformers=True,
    provider="TensorrtExecutionProvider",
)

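# The tokenizer returns int64 (torch.long) input_ids and attention_mask by default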
tokenizer = AutoTokenizer.from_pretrained("philschmid/tiny-bert-sst2-distilled")
inp = tokenizer("expectations were low, actual enjoyment was high", return_tensors="pt", padding=True)

result = ort_model(**inp)
assert ort_model.providers == ["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]

Log

2022-12-21 09:40:48.208932701 [W:onnxruntime:Default, tensorrt_execution_provider.h:60 log] [2022-12-21 09:40:48 WARNING] external/onnx-tensorrt/onnx2trt_utils.cpp:367: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32

Expected behavior

Clamping should be done ahead of inference, not left to TensorRT at session creation time.

I will work on a convert_int64_to_int32() function to solve it.
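
Roughly, such a helper could look like the following (a minimal sketch only, handling just the graph initializers; the final implementation may differ):

import numpy as np
import onnx
from onnx import numpy_helper

def convert_int64_to_int32(model: onnx.ModelProto) -> onnx.ModelProto:
    # Clamp every INT64 initializer to the INT32 range and downcast it to INT32
    int32_min = np.iinfo(np.int32).min
    int32_max = np.iinfo(np.int32).max
    for initializer in model.graph.initializer:
        if initializer.data_type == onnx.TensorProto.INT64:
            weights = numpy_helper.to_array(initializer)
            clamped = np.clip(weights, int32_min, int32_max).astype(np.int32)
            initializer.CopyFrom(numpy_helper.from_array(clamped, initializer.name))
    return model

The converted model could then be saved with onnx.save() before creating the TensorRT session.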

@JingyaHuang JingyaHuang added the bug Something isn't working label Dec 21, 2022
@JingyaHuang JingyaHuang self-assigned this Dec 21, 2022

fxmarty commented Dec 21, 2022

@JingyaHuang the warning

2022-12-21 11:03:15.553857281 [W:onnxruntime:Default, tensorrt_execution_provider.h:60 log] [2022-12-21 10:03:15 WARNING] external/onnx-tensorrt/onnx2trt_utils.cpp:367: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.

is raised at model initialization, so I'm not sure there is much more we can do. The inputs themselves are int64, since that's what the tokenizer returns.
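
For example, a quick check with the same tokenizer as in the reproduction above:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("philschmid/tiny-bert-sst2-distilled")
inp = tokenizer("expectations were low, actual enjoyment was high", return_tensors="pt")

# Tokenizers return torch.int64 tensors by default
print(inp["input_ids"].dtype)       # torch.int64
print(inp["attention_mask"].dtype)  # torch.int64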


fxmarty commented Dec 30, 2022

I have a working prototype for this. trtexec comes in very handy for testing. The warnings come from Slice operators whose end indices are set to the maximum representable int64 value; clamping them to np.iinfo(np.int32).max and casting with astype(np.int32) removes the warnings.
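
To illustrate the idea (sketch only; the exact handling in the fix may differ, and this assumes the Slice ends come from graph initializers rather than Constant nodes):

import numpy as np
import onnx
from onnx import numpy_helper

model = onnx.load("model.onnx")

# The "ends" input of a Slice node is its third input (data, starts, ends, axes, steps)
slice_ends = {
    node.input[2]
    for node in model.graph.node
    if node.op_type == "Slice" and len(node.input) > 2
}

for initializer in model.graph.initializer:
    if initializer.name in slice_ends and initializer.data_type == onnx.TensorProto.INT64:
        ends = numpy_helper.to_array(initializer)
        # INT64_MAX sentinels become INT32_MAX so the downcast to INT32 is safe
        clamped = np.clip(ends, np.iinfo(np.int32).min, np.iinfo(np.int32).max).astype(np.int32)
        initializer.CopyFrom(numpy_helper.from_array(clamped, initializer.name))

onnx.save(model, "model_clamped.onnx")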

