From 5da335eb3d11a0e2a3d4513f0bce073e1997562b Mon Sep 17 00:00:00 2001
From: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
Date: Tue, 17 Dec 2024 09:34:43 -0800
Subject: [PATCH 1/3] Model: Robust request length checking in generator

* Ensure that length of positive/negative prompt + max_tokens does not exceed max_seq_len
* Ensure that total required pages for CFG request does not exceed allocated cache_size
---
 backends/exllamav2/model.py | 46 ++++++++++++++++++++++++++++++++-----
 1 file changed, 40 insertions(+), 6 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 50cef42..f46f4f9 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1301,17 +1301,51 @@ async def generate_gen(
 
         # The first index will always be the positive prompt
         context_len = input_ids[0].size(dim=-1)
-        if context_len > self.config.max_seq_len:
-            raise ValueError(
-                f"Context length {context_len} is greater than max_seq_len "
-                f"{self.config.max_seq_len}"
-            )
+
+        # The second index will be the negative prompt if CFG is enabled
+        if negative_prompt is not None:
+            negative_context_len = input_ids[1].size(dim=-1)
+        else:
+            negative_context_len = 0
 
         # Automatically set max_tokens to fill up the context
         # This should be an OK default, but may be changed in the future
         max_tokens = unwrap(
-            kwargs.get("max_tokens"), self.config.max_seq_len - context_len
+            kwargs.get("max_tokens"),
+            self.config.max_seq_len - max(context_len, negative_context_len),
         )
+        if max_tokens < 1:
+            logger.warning("max_tokens must be a positive integer, " "setting to 1.")
+            max_tokens = 1
+
+        # Check total length of request
+        if context_len + max_tokens > self.config.max_seq_len:
+            raise ValueError(
+                f"Request length {context_len} + {max_tokens} is greater than "
+                f"max_seq_len {self.config.max_seq_len}"
+            )
+
+        # Check total length of negative prompt request if CFG is enabled
+        if negative_prompt is not None:
+            if context_len + max_tokens > self.config.max_seq_len:
+                raise ValueError(
+                    f"Request length for negative prompt "
+                    f"{negative_context_len} + {max_tokens} is greater than "
+                    f"max_seq_len {self.config.max_seq_len}"
+                )
+            # Check total required pages for CFG request
+            if (
+                sum(
+                    256 * math.ceil((context + max_tokens) / 256)
+                    for context in (context_len, negative_context_len)
+                )
+                > self.cache_size
+            ):
+                raise ValueError(
+                    f"Total required page size for request "
+                    f"{context_len} + {negative_context_len} + {max_tokens} * 2 "
+                    f"is greater than cache_size {self.cache_size}"
+                )
 
         # Set min_tokens to generate while keeping EOS banned
         min_tokens = unwrap(kwargs.get("min_tokens"), 0)

From 4d11323c17805b61833eafcc869d5221a89ff5fb Mon Sep 17 00:00:00 2001
From: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
Date: Tue, 17 Dec 2024 09:37:33 -0800
Subject: [PATCH 2/3] Tree: Format

---
 backends/exllamav2/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index f46f4f9..3478820 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1315,7 +1315,7 @@ async def generate_gen(
             self.config.max_seq_len - max(context_len, negative_context_len),
         )
         if max_tokens < 1:
-            logger.warning("max_tokens must be a positive integer, " "setting to 1.")
+            logger.warning("max_tokens must be a positive integer, setting to 1.")
             max_tokens = 1
 
         # Check total length of request
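For readers following the arithmetic in PATCH 1/3: the cache check rounds each sequence (prompt length plus max_tokens) up to the 256-token page granularity the patch assumes for the paged cache, sums the two rounded sequences, and compares the total against cache_size. Below is a minimal standalone sketch of that calculation; it is illustrative only, and the function name, PAGE_SIZE constant, and example numbers are assumptions rather than part of the patch or the repository.

import math

# Page granularity assumed by the patch's cache check (256 tokens per page)
PAGE_SIZE = 256


def required_cache_tokens(context_len, negative_context_len, max_tokens):
    """Cache tokens needed when CFG runs the positive and negative prompts together."""
    return sum(
        PAGE_SIZE * math.ceil((context + max_tokens) / PAGE_SIZE)
        for context in (context_len, negative_context_len)
    )


# Example: a 1000-token prompt and a 900-token negative prompt generating 200 tokens
# need ceil(1200 / 256) * 256 + ceil(1100 / 256) * 256 = 1280 + 1280 = 2560 cache tokens,
# which must not exceed cache_size.
print(required_cache_tokens(1000, 900, 200))  # 2560
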
From b994aae99591f286cbc185493ed079ddd16da973 Mon Sep 17 00:00:00 2001
From: kingbri <8082010+bdashore3@users.noreply.github.com>
Date: Thu, 26 Dec 2024 23:13:08 -0500
Subject: [PATCH 3/3] Model: Cleanup generation length and page checks

Reduce the number of if statements and combine parts of code.

Signed-off-by: kingbri <8082010+bdashore3@users.noreply.github.com>
---
 backends/exllamav2/model.py | 54 ++++++++++++++++++-------------------
 1 file changed, 26 insertions(+), 28 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index d9a52ac..2e20e91 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1309,10 +1309,7 @@ async def generate_gen(
         context_len = input_ids[0].size(dim=-1)
 
         # The second index will be the negative prompt if CFG is enabled
-        if negative_prompt is not None:
-            negative_context_len = input_ids[1].size(dim=-1)
-        else:
-            negative_context_len = 0
+        negative_context_len = input_ids[1].size(dim=-1) if negative_prompt else 0
 
         # Automatically set max_tokens to fill up the context
         # This should be an OK default, but may be changed in the future
@@ -1324,34 +1321,35 @@ async def generate_gen(
             logger.warning("max_tokens must be a positive integer, setting to 1.")
             max_tokens = 1
 
-        # Check total length of request
-        if context_len + max_tokens > self.config.max_seq_len:
+        # Determine if the negative context or the context length is bigger
+        context_to_check = max(negative_context_len, context_len)
+
+        # Check highest possible total length of request
+        if context_to_check + max_tokens > self.config.max_seq_len:
+            preamble = (
+                "Negative prompt request"
+                if negative_context_len > context_len
+                else "Request"
+            )
+
             raise ValueError(
-                f"Request length {context_len} + {max_tokens} is greater than "
+                f"{preamble} length {context_to_check} + {max_tokens} is greater than "
                 f"max_seq_len {self.config.max_seq_len}"
             )
 
-        # Check total length of negative prompt request if CFG is enabled
-        if negative_prompt is not None:
-            if context_len + max_tokens > self.config.max_seq_len:
-                raise ValueError(
-                    f"Request length for negative prompt "
-                    f"{negative_context_len} + {max_tokens} is greater than "
-                    f"max_seq_len {self.config.max_seq_len}"
-                )
-            # Check total required pages for CFG request
-            if (
-                sum(
-                    256 * math.ceil((context + max_tokens) / 256)
-                    for context in (context_len, negative_context_len)
-                )
-                > self.cache_size
-            ):
-                raise ValueError(
-                    f"Total required page size for request "
-                    f"{context_len} + {negative_context_len} + {max_tokens} * 2 "
-                    f"is greater than cache_size {self.cache_size}"
-                )
+        # Check total required pages for CFG request to avoid overallocation
+        if negative_prompt and (
+            sum(
+                256 * math.ceil((context + max_tokens) / 256)
+                for context in (context_len, negative_context_len)
+            )
+            > self.cache_size
+        ):
+            raise ValueError(
+                f"Total required page size for request "
+                f"{context_len} + {negative_context_len} + {max_tokens} * 2 "
+                f"is greater than cache_size {self.cache_size}"
+            )
 
         # Set min_tokens to generate while keeping EOS banned
         min_tokens = unwrap(kwargs.get("min_tokens"), 0)
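Taken together, the series leaves max_tokens defaulting to whatever room the longer of the two prompts leaves in the context window, clamped to a minimum of 1. The sketch below illustrates that default; the standalone function name and example window size are assumptions for illustration, not the repository's API.

def default_max_tokens(context_len, negative_context_len, max_seq_len):
    """Fill the remaining context window, but never return less than 1."""
    remaining = max_seq_len - max(context_len, negative_context_len)
    return max(remaining, 1)


# Example with a 4096-token window: a 4000-token prompt and a 3900-token negative
# prompt leave room for 96 generated tokens by default.
assert default_max_tokens(4000, 3900, 4096) == 96

Because the default is clamped to 1, a prompt that already fills the window trips the subsequent length check (context plus at least one generated token exceeds max_seq_len) instead of proceeding with a non-positive max_tokens.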