Add retry logic and proxy support to the NeMo LLM Service #1544

Merged
16 changes: 8 additions & 8 deletions morpheus.code-workspace
@@ -83,7 +83,7 @@
"program": "${workspaceFolder}/morpheus/cli/run.py",
"request": "launch",
"subProcess": true,
"type": "python"
"type": "debugpy"
},
{
"args": [
@@ -139,7 +139,7 @@
"program": "${workspaceFolder}/morpheus/cli/run.py",
"request": "launch",
"subProcess": true,
"type": "python"
"type": "debugpy"
},
{
"args": [
@@ -201,7 +201,7 @@
"program": "${workspaceFolder}/morpheus/cli/run.py",
"request": "launch",
"subProcess": true,
"type": "python"
"type": "debugpy"
},
{
"args": [
@@ -266,7 +266,7 @@
"program": "${workspaceFolder}/morpheus/cli/run.py",
"request": "launch",
"subProcess": true,
"type": "python"
"type": "debugpy"
},
{
"args": [
@@ -285,7 +285,7 @@
"name": "Python: Anomaly Detection Example",
"program": "${workspaceFolder}/examples/abp_pcap_detection/run.py",
"request": "launch",
"type": "python"
"type": "debugpy"
},
{
"args": [
@@ -303,7 +303,7 @@
"module": "sphinx.cmd.build",
"name": "Python: Sphinx",
"request": "launch",
"type": "python"
"type": "debugpy"
},
{
"MIMode": "gdb",
@@ -598,7 +598,7 @@
"name": "Python: GNN DGL inference",
"program": "${workspaceFolder}/examples/gnn_fraud_detection_pipeline/run.py",
"request": "launch",
"type": "python"
"type": "debugpy"
},
{
"args": [
@@ -614,7 +614,7 @@
"name": "Python: GNN model training",
"program": "${workspaceFolder}/models/training-tuning-scripts/fraud-detection-models/training.py",
"request": "launch",
"type": "python"
"type": "debugpy"
}
]
},
41 changes: 37 additions & 4 deletions morpheus/llm/services/llm_service.py
@@ -13,6 +13,7 @@
# limitations under the License.

import logging
import typing
from abc import ABC
from abc import abstractmethod

@@ -33,7 +34,7 @@ def get_input_names(self) -> list[str]:
pass

@abstractmethod
def generate(self, input_dict: dict[str, str]) -> str:
def generate(self, **input_dict) -> str:
"""
Issue a request to generate a response based on a given prompt.

@@ -45,7 +46,7 @@ def generate(self, input_dict: dict[str, str]) -> str:
pass

@abstractmethod
async def generate_async(self, input_dict: dict[str, str]) -> str:
async def generate_async(self, **input_dict) -> str:
"""
Issue an asynchronous request to generate a response based on a given prompt.

@@ -56,27 +57,59 @@ async def generate_async(self, input_dict: dict[str, str]) -> str:
"""
pass

@typing.overload
@abstractmethod
def generate_batch(self, inputs: dict[str, list[str]]) -> list[str]:
def generate_batch(self,
inputs: dict[str, list],
return_exceptions: typing.Literal[True] = True) -> list[str | BaseException]:
...

@typing.overload
@abstractmethod
def generate_batch(self, inputs: dict[str, list], return_exceptions: typing.Literal[False] = False) -> list[str]:
...

@abstractmethod
def generate_batch(self, inputs: dict[str, list], return_exceptions=False) -> list[str] | list[str | BaseException]:
"""
Issue a request to generate a list of responses based on a list of prompts.

Parameters
----------
inputs : dict
Inputs containing prompt data.
return_exceptions : bool
Whether to return exceptions in the output list or raise them immediately.
"""
pass

@typing.overload
@abstractmethod
async def generate_batch_async(self,
inputs: dict[str, list],
return_exceptions: typing.Literal[True] = True) -> list[str | BaseException]:
...

@typing.overload
@abstractmethod
async def generate_batch_async(self,
inputs: dict[str, list],
return_exceptions: typing.Literal[False] = False) -> list[str]:
...

@abstractmethod
async def generate_batch_async(self, inputs: dict[str, list[str]]) -> list[str]:
async def generate_batch_async(self,
inputs: dict[str, list],
return_exceptions=False) -> list[str] | list[str | BaseException]:
"""
Issue an asynchronous request to generate a list of responses based on a list of prompts.

Parameters
----------
inputs : dict
Inputs containing prompt data.
return_exceptions : bool
Whether to return exceptions in the output list or raise them immediately.
"""
pass

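For context on this interface change, here is a hypothetical caller sketch (an editor's addition, not part of this diff; the `client` object and the "prompt" input name are assumptions) showing how the new keyword-argument `generate` signature and the `typing.Literal` overloads on `generate_batch` are meant to be used. A type checker narrows the batch result to list[str] when return_exceptions=False and to list[str | BaseException] when it is True.

# Hypothetical caller sketch, not part of this PR. Assumes `client` is an
# LLMClient implementation whose single prompt input is named "prompt".
single: str = client.generate(prompt="What is Morpheus?")

# return_exceptions=False (the default): any failure raises immediately,
# so the result narrows to list[str].
strict: list[str] = client.generate_batch(
    {"prompt": ["What is Morpheus?", "What is NeMo?"]}, return_exceptions=False)

# return_exceptions=True: per-prompt failures are returned in place,
# so the result narrows to list[str | BaseException].
mixed: list[str | BaseException] = client.generate_batch(
    {"prompt": ["What is Morpheus?", "What is NeMo?"]}, return_exceptions=True)

for item in mixed:
    if isinstance(item, BaseException):
        print(f"request failed: {item}")
    else:
        print(item)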
126 changes: 93 additions & 33 deletions morpheus/llm/services/nemo_llm_service.py
@@ -16,6 +16,7 @@
import logging
import os
import typing
import warnings

from morpheus.llm.services.llm_service import LLMClient
from morpheus.llm.services.llm_service import LLMService
@@ -66,7 +67,7 @@ def __init__(self, parent: "NeMoLLMService", *, model_name: str, **model_kwargs)
def get_input_names(self) -> list[str]:
return [self._prompt_key]

def generate(self, input_dict: dict[str, str]) -> str:
def generate(self, **input_dict) -> str:
"""
Issue a request to generate a response based on a given prompt.

@@ -75,9 +76,9 @@ def generate(self, input_dict: dict[str, str]) -> str:
input_dict : dict
Input containing prompt data.
"""
return self.generate_batch({self._prompt_key: [input_dict[self._prompt_key]]})[0]
return self.generate_batch({self._prompt_key: [input_dict[self._prompt_key]]}, return_exceptions=False)[0]

async def generate_async(self, input_dict: dict[str, str]) -> str:
async def generate_async(self, **input_dict) -> str:
"""
Issue an asynchronous request to generate a response based on a given prompt.

@@ -86,79 +87,138 @@ async def generate_async(self, input_dict: dict[str, str]) -> str:
input_dict : dict
Input containing prompt data.
"""
return (await self.generate_batch_async({self._prompt_key: [input_dict[self._prompt_key]]}))[0]
return (await self.generate_batch_async({self._prompt_key: [input_dict[self._prompt_key]]},
return_exceptions=False))[0]

def generate_batch(self, inputs: dict[str, list[str]]) -> list[str]:
@typing.overload
def generate_batch(self,
inputs: dict[str, list],
return_exceptions: typing.Literal[True] = True) -> list[str | BaseException]:
...

@typing.overload
def generate_batch(self, inputs: dict[str, list], return_exceptions: typing.Literal[False] = False) -> list[str]:
...

def generate_batch(self, inputs: dict[str, list], return_exceptions=False) -> list[str] | list[str | BaseException]:
"""
Issue a request to generate a list of responses based on a list of prompts.

Parameters
----------
inputs : dict
Inputs containing prompt data.
return_exceptions : bool
Whether to return exceptions in the output list or raise them immediately.
"""

# Note: We don't want to use the generate_multiple implementation from nemollm because there is no retry logic.
# As soon as one of the requests fails, the entire batch fails. Instead, we need to implement the functionality
# listed in issue #1555. For now, we generate a warning if `return_exceptions` is True.
if (return_exceptions):
warnings.warn("return_exceptions==True is not currently supported by the NeMoLLMClient. "
"If an exception is raised for any item, the function will exit and raise that exception.")

return typing.cast(
list[str],
self._parent._conn.generate_multiple(model=self._model_name,
prompts=inputs[self._prompt_key],
return_type="text",
**self._model_kwargs))

async def generate_batch_async(self, inputs: dict[str, list[str]]) -> list[str]:
async def _process_one_async(self, prompt: str) -> str:
iterations = 0
errors = []

while iterations < self._parent._retry_count:
fut = await asyncio.wrap_future(
self._parent._conn.generate(model=self._model_name,
prompt=prompt,
return_type="async",
**self._model_kwargs)) # type: ignore

result: dict = nemollm.NemoLLM.post_process_generate_response(
fut, return_text_completion_only=False) # type: ignore

if result.get('status', None) == 'fail':
iterations += 1
errors.append(result.get('msg', 'Unknown error'))
continue

return result['text']

raise RuntimeError(
f"Failed to generate response for prompt '{prompt}' after {self._parent._retry_count} attempts. "
f"Errors: {errors}")

@typing.overload
async def generate_batch_async(self,
inputs: dict[str, list],
return_exceptions: typing.Literal[True] = True) -> list[str | BaseException]:
...

@typing.overload
async def generate_batch_async(self,
inputs: dict[str, list],
return_exceptions: typing.Literal[False] = False) -> list[str]:
...

async def generate_batch_async(self,
inputs: dict[str, list],
return_exceptions=False) -> list[str] | list[str | BaseException]:
"""
Issue an asynchronous request to generate a list of responses based on a list of prompts.

Parameters
----------
inputs : dict
Inputs containing prompt data.
return_exceptions : bool
Whether to return exceptions in the output list or raise them immediately.
"""
prompts = inputs[self._prompt_key]
futures = [
asyncio.wrap_future(
self._parent._conn.generate(self._model_name, p, return_type="async", **self._model_kwargs))
for p in prompts
]

results = await asyncio.gather(*futures)

responses = []

for result in results:
result = nemollm.NemoLLM.post_process_generate_response(result, return_text_completion_only=False)
if result.get('status', None) == 'fail':
raise RuntimeError(result.get('msg', 'Unknown error'))
futures = [self._process_one_async(p) for p in prompts]

responses.append(result['text'])
results = await asyncio.gather(*futures, return_exceptions=return_exceptions)

return responses
return results


class NeMoLLMService(LLMService):
"""
A service for interacting with NeMo LLM models. This class should be used to create a client for a specific model.

Parameters
----------
api_key : str, optional
The API key for the LLM service, by default None. If `None` the API key will be read from the `NGC_API_KEY`
environment variable. If neither is present, an error will be raised.

org_id : str, optional
The organization ID for the LLM service, by default None. If `None` the organization ID will be read from the
`NGC_ORG_ID` environment variable. This value is only required if the account associated with the `api_key` is
a member of multiple NGC organizations.
"""

def __init__(self, *, api_key: str = None, org_id: str = None) -> None:
def __init__(self, *, api_key: str = None, org_id: str = None, retry_count=5) -> None:
"""
Creates a service for interacting with NeMo LLM models.

Parameters
----------
api_key : str, optional
The API key for the LLM service, by default None. If `None` the API key will be read from the `NGC_API_KEY`
environment variable. If neither is present, an error will be raised.
org_id : str, optional
The organization ID for the LLM service, by default None. If `None` the organization ID will be read from
the `NGC_ORG_ID` environment variable. This value is only required if the account associated with the
`api_key` is a member of multiple NGC organizations.
retry_count : int, optional
The number of times to retry a request before raising an exception, by default 5

"""

if IMPORT_EXCEPTION is not None:
raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION

super().__init__()
api_key = api_key if api_key is not None else os.environ.get("NGC_API_KEY", None)
org_id = org_id if org_id is not None else os.environ.get("NGC_ORG_ID", None)

self._retry_count = retry_count

self._conn = nemollm.NemoLLM(
api_host=os.environ.get("NGC_API_BASE", None),
# The client must configure the authentication and authorization parameters
# in accordance with the API server security policy.
# Configure Bearer authorization
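An end-to-end usage sketch (an editor's addition, not part of this diff). It assumes the service exposes a get_client(model_name=...) factory as in the existing LLMService interface, that NGC_API_KEY is set in the environment, that the client's prompt input is named "prompt", and that the model name is illustrative only. It exercises the new retry_count option together with return_exceptions on the async batch path, where a prompt that exhausts its retries surfaces as a RuntimeError entry in the result list rather than failing the whole batch.

# Hypothetical usage sketch, not part of this PR.
import asyncio

from morpheus.llm.services.nemo_llm_service import NeMoLLMService


async def main():
    # The API key is read from NGC_API_KEY when not passed explicitly; each
    # request is retried up to retry_count times before a RuntimeError is raised.
    service = NeMoLLMService(retry_count=3)

    # Assumed factory method; the model name is illustrative only.
    client = service.get_client(model_name="gpt-43b-002")

    results = await client.generate_batch_async(
        {"prompt": ["Summarize Morpheus in one sentence.", "What is NGC?"]},
        return_exceptions=True)

    for result in results:
        if isinstance(result, BaseException):
            print(f"prompt failed after retries: {result}")
        else:
            print(result)


asyncio.run(main())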