Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add retry logic and proxy support to the NeMo LLM Service #1544

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions morpheus/llm/services/nemo_llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,22 +114,34 @@ async def generate_batch_async(self, inputs: dict[str, list[str]]) -> list[str]:
Inputs containing prompt data.
"""
prompts = inputs[self._prompt_key]
futures = [
asyncio.wrap_future(
self._parent._conn.generate(self._model_name, p, return_type="async", **self._model_kwargs))
for p in prompts
]

results = await asyncio.gather(*futures)
async def process_one(p: str):
mdemoret-nv marked this conversation as resolved.
Show resolved Hide resolved

iterations = 0
errors = []

while iterations < 10:
fut = await asyncio.wrap_future(
self._parent._conn.generate(self._model_name, p, return_type="async", **self._model_kwargs))

result = nemollm.NemoLLM.post_process_generate_response(fut, return_text_completion_only=False)
if result.get('status', None) == 'fail':
iterations += 1
errors.append(result.get('msg', 'Unknown error'))
continue

return result['text']

raise RuntimeError(f"Failed to generate response for prompt '{p}' after 3 attempts. Errors: {errors}")

futures = [process_one(p) for p in prompts]

results = await asyncio.gather(*futures, return_exceptions=True)

responses = []

for result in results:
result = nemollm.NemoLLM.post_process_generate_response(result, return_text_completion_only=False)
if result.get('status', None) == 'fail':
raise RuntimeError(result.get('msg', 'Unknown error'))

responses.append(result['text'])
responses.append(result)

return responses
mdemoret-nv marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -159,6 +171,7 @@ def __init__(self, *, api_key: str = None, org_id: str = None) -> None:
org_id = org_id if org_id is not None else os.environ.get("NGC_ORG_ID", None)

self._conn = nemollm.NemoLLM(
api_host=os.environ.get("NGC_API_BASE", None),
# The client must configure the authentication and authorization parameters
# in accordance with the API server security policy.
# Configure Bearer authorization
Expand Down
Loading