From 4a04a81d0e14e3ab5ab79a1a597a057ef9f246c8 Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Mon, 22 Apr 2024 16:12:59 +0200 Subject: [PATCH] [ModelRunner] Fix stop & bad word list pointer offset. --- tensorrt_llm/runtime/generation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index 5937c4112..53caf0fa5 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -3067,7 +3067,7 @@ def decode(self, stop_words_list_ptrs = torch.zeros((batch_size), dtype=torch.int64) for bi in range(batch_size): stop_words_list_ptrs[bi] = stop_words_list.data_ptr( - ) + bi * 2 * max_stop_words_len + ) + bi * 2 * max_stop_words_len * stop_words_list.element_size() stop_words_list_ptrs = stop_words_list_ptrs.to('cuda') stop_words_data = (stop_words_list_ptrs, stop_words_lens, max_stop_words_len) @@ -3083,7 +3083,7 @@ def decode(self, bad_words_list_ptrs = torch.zeros((batch_size), dtype=torch.int64) for bi in range(batch_size): bad_words_list_ptrs[bi] = bad_words_list.data_ptr( - ) + bi * 2 * max_bad_words_len + ) + bi * 2 * max_bad_words_len * bad_words_list.element_size() bad_words_list_ptrs = bad_words_list_ptrs.to('cuda') bad_words_data = (bad_words_list_ptrs, bad_words_lens, max_bad_words_len)