Skip to content

Commit

Permalink
set configurable llm retry sleep time
Browse files Browse the repository at this point in the history
Reviewed By: SimonWan

Differential Revision: D66383761

fbshipit-source-id: f58e32d281c2fb03b081c92c739f53330b7d5c72
  • Loading branch information
Adam Bali authored and facebook-github-bot committed Nov 25, 2024
1 parent 08f42a2 commit 774f273
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions CybersecurityBenchmarks/benchmark/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@


NUM_LLM_RETRIES = 100
MAX_RETRY_TIMEOUT = 600

MAX_TOKENS = 1000

Expand Down Expand Up @@ -135,6 +136,7 @@ def _query_with_retries(
func: Callable[..., str],
*args: Union[str, bool, Optional[str]],
retries: int = NUM_LLM_RETRIES,
max_retry_timeout: int = MAX_RETRY_TIMEOUT,
backoff_factor: float = 0.5,
) -> str:
last_exception = None
Expand All @@ -145,8 +147,8 @@ def _query_with_retries(
last_exception = exception
sleep_time = (
backoff_factor * (2**retry)
if backoff_factor * (2**retry) <= 600
else 600
if backoff_factor * (2**retry) <= max_retry_timeout
else max_retry_timeout
)
time.sleep(sleep_time)
LOG.error(
Expand All @@ -161,6 +163,7 @@ def _chat_with_retries(
func: Callable[..., str],
*args: Union[List[str], bool, Optional[str]],
retries: int = NUM_LLM_RETRIES,
max_retry_timeout: int = MAX_RETRY_TIMEOUT,
backoff_factor: float = 0.5,
) -> str:
last_exception = None
Expand All @@ -171,8 +174,8 @@ def _chat_with_retries(
last_exception = exception
sleep_time = (
backoff_factor * (2**retry)
if backoff_factor * (2**retry) <= 600
else 600
if backoff_factor * (2**retry) <= max_retry_timeout
else max_retry_timeout
)
time.sleep(sleep_time)
LOG.error(
Expand Down Expand Up @@ -332,6 +335,7 @@ def query_multimodal_with_retries(
image_path: Optional[str] = None,
audio_path: Optional[str] = None,
num_retries: int = NUM_LLM_RETRIES,
max_retry_timeout: int = MAX_RETRY_TIMEOUT,
max_tokens: int = DEFAULT_MAX_TOKENS,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
Expand Down Expand Up @@ -393,8 +397,8 @@ def query_multimodal_with_retries(
last_exception = exception
sleep_time = (
backoff_factor * (2**retry)
if backoff_factor * (2**retry) <= 600
else 600
if backoff_factor * (2**retry) <= max_retry_timeout
else max_retry_timeout
)
time.sleep(sleep_time)
LOG.error(
Expand Down

0 comments on commit 774f273

Please sign in to comment.