diff --git a/llava/mm_utils.py b/llava/mm_utils.py index 5c6627d..dc8ba56 100644 --- a/llava/mm_utils.py +++ b/llava/mm_utils.py @@ -269,18 +269,7 @@ def call_for_batch( keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids ] for keyword_id in self.keyword_ids: - if keyword_id.ndim == 3: - if (output_ids[0, -keyword_id.shape[0] :, None] == keyword_id).all(): - return True - elif keyword_id.ndim == 2: - if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all(): - return True - else: - raise ValueError( - "Keyword tensor should have 2 or 3 dimensions, got {}".format( - keyword_id.ndim - ) - ) + if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all(): return True outputs = self.tokenizer.batch_decode( output_ids[:, -offset:], skip_special_tokens=True