From 744ba79562343c1d7dd700c3156855b518496a5c Mon Sep 17 00:00:00 2001 From: Michael Demoret <42954918+mdemoret-nv@users.noreply.github.com> Date: Mon, 11 Mar 2024 10:45:05 -0400 Subject: [PATCH 1/3] Add retry logic and proxy support to the NeMo LLM Service (#1544) - Adds the ability to retry NeMo failures more than one time - Adds an argument to configure the retry count - Adds support for proxying requests to the NeMo service using the `NGC_API_BASE` environment variable - Changes the API for the base `LLMService` to improve the type hints and allow arguments other than strings to be used. ## By Submitting this PR I confirm: - I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md). - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. Authors: - Michael Demoret (https://github.com/mdemoret-nv) Approvers: - David Gardner (https://github.com/dagardner-nv) URL: https://github.com/nv-morpheus/Morpheus/pull/1544 --- morpheus.code-workspace | 16 +-- morpheus/llm/services/llm_service.py | 41 +++++- morpheus/llm/services/nemo_llm_service.py | 126 +++++++++++++----- morpheus/llm/services/openai_chat_service.py | 76 +++++++++-- tests/_utils/environment.py | 58 ++++++++ tests/conftest.py | 2 +- tests/llm/conftest.py | 28 ++-- tests/llm/services/test_llm_service_pipe.py | 20 ++- tests/llm/services/test_nemo_llm_client.py | 119 +++++++++-------- tests/llm/services/test_nemo_llm_service.py | 5 +- tests/llm/services/test_openai_chat_client.py | 4 +- tests/llm/test_completion_pipe.py | 29 ++-- 12 files changed, 371 insertions(+), 153 deletions(-) create mode 100644 tests/_utils/environment.py diff --git a/morpheus.code-workspace b/morpheus.code-workspace index dd9ed51763..4e0da64a64 100644 --- a/morpheus.code-workspace +++ b/morpheus.code-workspace @@ -83,7 +83,7 @@ "program": "${workspaceFolder}/morpheus/cli/run.py", "request": "launch", "subProcess": true, - "type": "python" + "type": "debugpy" }, { "args": [ @@ -139,7 +139,7 @@ "program": "${workspaceFolder}/morpheus/cli/run.py", "request": "launch", "subProcess": true, - "type": "python" + "type": "debugpy" }, { "args": [ @@ -201,7 +201,7 @@ "program": "${workspaceFolder}/morpheus/cli/run.py", "request": "launch", "subProcess": true, - "type": "python" + "type": "debugpy" }, { "args": [ @@ -266,7 +266,7 @@ "program": "${workspaceFolder}/morpheus/cli/run.py", "request": "launch", "subProcess": true, - "type": "python" + "type": "debugpy" }, { "args": [ @@ -285,7 +285,7 @@ "name": "Python: Anomaly Detection Example", "program": "${workspaceFolder}/examples/abp_pcap_detection/run.py", "request": "launch", - "type": "python" + "type": "debugpy" }, { "args": [ @@ -303,7 +303,7 @@ "module": "sphinx.cmd.build", "name": "Python: Sphinx", "request": "launch", - "type": "python" + "type": "debugpy" }, { "MIMode": "gdb", @@ -598,7 +598,7 @@ "name": "Python: GNN DGL inference", "program": "${workspaceFolder}/examples/gnn_fraud_detection_pipeline/run.py", "request": "launch", - "type": "python" + "type": "debugpy" }, { "args": [ @@ -614,7 +614,7 @@ "name": "Python: GNN model training", "program": "${workspaceFolder}/models/training-tuning-scripts/fraud-detection-models/training.py", "request": "launch", - "type": "python" + "type": "debugpy" } ] }, diff --git a/morpheus/llm/services/llm_service.py b/morpheus/llm/services/llm_service.py index 1d11481345..07c777c547 
100644 --- a/morpheus/llm/services/llm_service.py +++ b/morpheus/llm/services/llm_service.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +import typing from abc import ABC from abc import abstractmethod @@ -33,7 +34,7 @@ def get_input_names(self) -> list[str]: pass @abstractmethod - def generate(self, input_dict: dict[str, str]) -> str: + def generate(self, **input_dict) -> str: """ Issue a request to generate a response based on a given prompt. @@ -45,7 +46,7 @@ def generate(self, input_dict: dict[str, str]) -> str: pass @abstractmethod - async def generate_async(self, input_dict: dict[str, str]) -> str: + async def generate_async(self, **input_dict) -> str: """ Issue an asynchronous request to generate a response based on a given prompt. @@ -56,8 +57,20 @@ async def generate_async(self, input_dict: dict[str, str]) -> str: """ pass + @typing.overload @abstractmethod - def generate_batch(self, inputs: dict[str, list[str]]) -> list[str]: + def generate_batch(self, + inputs: dict[str, list], + return_exceptions: typing.Literal[True] = True) -> list[str | BaseException]: + ... + + @typing.overload + @abstractmethod + def generate_batch(self, inputs: dict[str, list], return_exceptions: typing.Literal[False] = False) -> list[str]: + ... + + @abstractmethod + def generate_batch(self, inputs: dict[str, list], return_exceptions=False) -> list[str] | list[str | BaseException]: """ Issue a request to generate a list of responses based on a list of prompts. @@ -65,11 +78,29 @@ def generate_batch(self, inputs: dict[str, list[str]]) -> list[str]: ---------- inputs : dict Inputs containing prompt data. + return_exceptions : bool + Whether to return exceptions in the output list or raise them immediately. """ pass + @typing.overload + @abstractmethod + async def generate_batch_async(self, + inputs: dict[str, list], + return_exceptions: typing.Literal[True] = True) -> list[str | BaseException]: + ... + + @typing.overload + @abstractmethod + async def generate_batch_async(self, + inputs: dict[str, list], + return_exceptions: typing.Literal[False] = False) -> list[str]: + ... + @abstractmethod - async def generate_batch_async(self, inputs: dict[str, list[str]]) -> list[str]: + async def generate_batch_async(self, + inputs: dict[str, list], + return_exceptions=False) -> list[str] | list[str | BaseException]: """ Issue an asynchronous request to generate a list of responses based on a list of prompts. @@ -77,6 +108,8 @@ async def generate_batch_async(self, inputs: dict[str, list[str]]) -> list[str]: ---------- inputs : dict Inputs containing prompt data. + return_exceptions : bool + Whether to return exceptions in the output list or raise them immediately. """ pass diff --git a/morpheus/llm/services/nemo_llm_service.py b/morpheus/llm/services/nemo_llm_service.py index 0173354df9..7f3bb5c6a0 100644 --- a/morpheus/llm/services/nemo_llm_service.py +++ b/morpheus/llm/services/nemo_llm_service.py @@ -16,6 +16,7 @@ import logging import os import typing +import warnings from morpheus.llm.services.llm_service import LLMClient from morpheus.llm.services.llm_service import LLMService @@ -66,7 +67,7 @@ def __init__(self, parent: "NeMoLLMService", *, model_name: str, **model_kwargs) def get_input_names(self) -> list[str]: return [self._prompt_key] - def generate(self, input_dict: dict[str, str]) -> str: + def generate(self, **input_dict) -> str: """ Issue a request to generate a response based on a given prompt. 
@@ -75,9 +76,9 @@ def generate(self, input_dict: dict[str, str]) -> str: input_dict : dict Input containing prompt data. """ - return self.generate_batch({self._prompt_key: [input_dict[self._prompt_key]]})[0] + return self.generate_batch({self._prompt_key: [input_dict[self._prompt_key]]}, return_exceptions=False)[0] - async def generate_async(self, input_dict: dict[str, str]) -> str: + async def generate_async(self, **input_dict) -> str: """ Issue an asynchronous request to generate a response based on a given prompt. @@ -86,9 +87,20 @@ async def generate_async(self, input_dict: dict[str, str]) -> str: input_dict : dict Input containing prompt data. """ - return (await self.generate_batch_async({self._prompt_key: [input_dict[self._prompt_key]]}))[0] + return (await self.generate_batch_async({self._prompt_key: [input_dict[self._prompt_key]]}, + return_exceptions=False))[0] - def generate_batch(self, inputs: dict[str, list[str]]) -> list[str]: + @typing.overload + def generate_batch(self, + inputs: dict[str, list], + return_exceptions: typing.Literal[True] = True) -> list[str | BaseException]: + ... + + @typing.overload + def generate_batch(self, inputs: dict[str, list], return_exceptions: typing.Literal[False] = False) -> list[str]: + ... + + def generate_batch(self, inputs: dict[str, list], return_exceptions=False) -> list[str] | list[str | BaseException]: """ Issue a request to generate a list of responses based on a list of prompts. @@ -96,7 +108,17 @@ def generate_batch(self, inputs: dict[str, list[str]]) -> list[str]: ---------- inputs : dict Inputs containing prompt data. + return_exceptions : bool + Whether to return exceptions in the output list or raise them immediately. """ + + # Note: We dont want to use the generate_multiple implementation from nemollm because there is no retry logic. + # As soon as one of the requests fails, the entire batch fails. Instead, we need to implement the functionality + # listed in issue #1555 For now, we generate a warning if `return_exceptions` is True. + if (return_exceptions): + warnings.warn("return_exceptions==True is not currently supported by the NeMoLLMClient. " + "If an exception is raised for any item, the function will exit and raise that exception.") + return typing.cast( list[str], self._parent._conn.generate_multiple(model=self._model_name, @@ -104,7 +126,46 @@ def generate_batch(self, inputs: dict[str, list[str]]) -> list[str]: return_type="text", **self._model_kwargs)) - async def generate_batch_async(self, inputs: dict[str, list[str]]) -> list[str]: + async def _process_one_async(self, prompt: str) -> str: + iterations = 0 + errors = [] + + while iterations < self._parent._retry_count: + fut = await asyncio.wrap_future( + self._parent._conn.generate(model=self._model_name, + prompt=prompt, + return_type="async", + **self._model_kwargs)) # type: ignore + + result: dict = nemollm.NemoLLM.post_process_generate_response( + fut, return_text_completion_only=False) # type: ignore + + if result.get('status', None) == 'fail': + iterations += 1 + errors.append(result.get('msg', 'Unknown error')) + continue + + return result['text'] + + raise RuntimeError( + f"Failed to generate response for prompt '{prompt}' after {self._parent._retry_count} attempts. " + f"Errors: {errors}") + + @typing.overload + async def generate_batch_async(self, + inputs: dict[str, list], + return_exceptions: typing.Literal[True] = True) -> list[str | BaseException]: + ... 
+ + @typing.overload + async def generate_batch_async(self, + inputs: dict[str, list], + return_exceptions: typing.Literal[False] = False) -> list[str]: + ... + + async def generate_batch_async(self, + inputs: dict[str, list], + return_exceptions=False) -> list[str] | list[str | BaseException]: """ Issue an asynchronous request to generate a list of responses based on a list of prompts. @@ -112,45 +173,41 @@ async def generate_batch_async(self, inputs: dict[str, list[str]]) -> list[str]: ---------- inputs : dict Inputs containing prompt data. + return_exceptions : bool + Whether to return exceptions in the output list or raise them immediately. """ prompts = inputs[self._prompt_key] - futures = [ - asyncio.wrap_future( - self._parent._conn.generate(self._model_name, p, return_type="async", **self._model_kwargs)) - for p in prompts - ] - - results = await asyncio.gather(*futures) - responses = [] - - for result in results: - result = nemollm.NemoLLM.post_process_generate_response(result, return_text_completion_only=False) - if result.get('status', None) == 'fail': - raise RuntimeError(result.get('msg', 'Unknown error')) + futures = [self._process_one_async(p) for p in prompts] - responses.append(result['text']) + results = await asyncio.gather(*futures, return_exceptions=return_exceptions) - return responses + return results class NeMoLLMService(LLMService): """ A service for interacting with NeMo LLM models, this class should be used to create a client for a specific model. - - Parameters - ---------- - api_key : str, optional - The API key for the LLM service, by default None. If `None` the API key will be read from the `NGC_API_KEY` - environment variable. If neither are present an error will be raised. - - org_id : str, optional - The organization ID for the LLM service, by default None. If `None` the organization ID will be read from the - `NGC_ORG_ID` environment variable. This value is only required if the account associated with the `api_key` is - a member of multiple NGC organizations. """ - def __init__(self, *, api_key: str = None, org_id: str = None) -> None: + def __init__(self, *, api_key: str = None, org_id: str = None, retry_count=5) -> None: + """ + Creates a service for interacting with NeMo LLM models. + + Parameters + ---------- + api_key : str, optional + The API key for the LLM service, by default None. If `None` the API key will be read from the `NGC_API_KEY` + environment variable. If neither are present an error will be raised., by default None + org_id : str, optional + The organization ID for the LLM service, by default None. If `None` the organization ID will be read from + the `NGC_ORG_ID` environment variable. This value is only required if the account associated with the + `api_key` is a member of multiple NGC organizations., by default None + retry_count : int, optional + The number of times to retry a request before raising an exception, by default 5 + + """ + if IMPORT_EXCEPTION is not None: raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION @@ -158,7 +215,10 @@ def __init__(self, *, api_key: str = None, org_id: str = None) -> None: api_key = api_key if api_key is not None else os.environ.get("NGC_API_KEY", None) org_id = org_id if org_id is not None else os.environ.get("NGC_ORG_ID", None) + self._retry_count = retry_count + self._conn = nemollm.NemoLLM( + api_host=os.environ.get("NGC_API_BASE", None), # The client must configure the authentication and authorization parameters # in accordance with the API server security policy. 
# Configure Bearer authorization diff --git a/morpheus/llm/services/openai_chat_service.py b/morpheus/llm/services/openai_chat_service.py index 446b9a0ee9..5b52db311a 100644 --- a/morpheus/llm/services/openai_chat_service.py +++ b/morpheus/llm/services/openai_chat_service.py @@ -164,15 +164,34 @@ def _extract_completion(self, completion: "openai.types.chat.chat_completion.Cha return content - def _generate(self, prompt: str, assistant: str = None) -> str: - messages = self._create_messages(prompt, assistant) + @typing.overload + def _generate(self, + prompt: str, + assistant: str = None, + return_exceptions: typing.Literal[True] = True) -> str | BaseException: + ... - output: openai.types.chat.chat_completion.ChatCompletion = self._client.chat.completions.create( - model=self._model_name, messages=messages, **self._model_kwargs) + @typing.overload + def _generate(self, prompt: str, assistant: str = None, return_exceptions: typing.Literal[False] = False) -> str: + ... - return self._extract_completion(output) + def _generate(self, prompt: str, assistant: str = None, return_exceptions: bool = False): + + try: + messages = self._create_messages(prompt, assistant) + + output: openai.types.chat.chat_completion.ChatCompletion = self._client.chat.completions.create( + model=self._model_name, messages=messages, **self._model_kwargs) + + return self._extract_completion(output) + except BaseException as e: - def generate(self, input_dict: dict[str, str]) -> str: + if return_exceptions: + return e + + raise + + def generate(self, **input_dict) -> str: """ Issue a request to generate a response based on a given prompt. @@ -181,7 +200,9 @@ def generate(self, input_dict: dict[str, str]) -> str: input_dict : dict Input containing prompt data. """ - return self._generate(input_dict[self._prompt_key], input_dict.get(self._assistant_key)) + return self._generate(input_dict[self._prompt_key], + input_dict.get(self._assistant_key), + return_exceptions=False) async def _generate_async(self, prompt: str, assistant: str = None) -> str: @@ -201,7 +222,7 @@ async def _generate_async(self, prompt: str, assistant: str = None) -> str: return self._extract_completion(output) - async def generate_async(self, input_dict: dict[str, str]) -> str: + async def generate_async(self, **input_dict) -> str: """ Issue an asynchronous request to generate a response based on a given prompt. @@ -212,7 +233,17 @@ async def generate_async(self, input_dict: dict[str, str]) -> str: """ return await self._generate_async(input_dict[self._prompt_key], input_dict.get(self._assistant_key)) - def generate_batch(self, inputs: dict[str, list[str]]) -> list[str]: + @typing.overload + def generate_batch(self, + inputs: dict[str, list], + return_exceptions: typing.Literal[True] = True) -> list[str | BaseException]: + ... + + @typing.overload + def generate_batch(self, inputs: dict[str, list], return_exceptions: typing.Literal[False] = False) -> list[str]: + ... + + def generate_batch(self, inputs: dict[str, list], return_exceptions=False) -> list[str] | list[str | BaseException]: """ Issue a request to generate a list of responses based on a list of prompts. @@ -220,6 +251,8 @@ def generate_batch(self, inputs: dict[str, list[str]]) -> list[str]: ---------- inputs : dict Inputs containing prompt data. + return_exceptions : bool + Whether to return exceptions in the output list or raise them immediately. 
""" prompts = inputs[self._prompt_key] assistants = None @@ -231,11 +264,28 @@ def generate_batch(self, inputs: dict[str, list[str]]) -> list[str]: results = [] for (i, prompt) in enumerate(prompts): assistant = assistants[i] if assistants is not None else None - results.append(self._generate(prompt, assistant)) + if (return_exceptions): + results.append(self._generate(prompt, assistant, return_exceptions=True)) + else: + results.append(self._generate(prompt, assistant, return_exceptions=False)) return results - async def generate_batch_async(self, inputs: dict[str, list[str]]) -> list[str]: + @typing.overload + async def generate_batch_async(self, + inputs: dict[str, list], + return_exceptions: typing.Literal[True] = True) -> list[str | BaseException]: + ... + + @typing.overload + async def generate_batch_async(self, + inputs: dict[str, list], + return_exceptions: typing.Literal[False] = False) -> list[str]: + ... + + async def generate_batch_async(self, + inputs: dict[str, list], + return_exceptions=False) -> list[str] | list[str | BaseException]: """ Issue an asynchronous request to generate a list of responses based on a list of prompts. @@ -243,6 +293,8 @@ async def generate_batch_async(self, inputs: dict[str, list[str]]) -> list[str]: ---------- inputs : dict Inputs containing prompt data. + return_exceptions : bool + Whether to return exceptions in the output list or raise them immediately. """ prompts = inputs[self._prompt_key] assistants = None @@ -256,7 +308,7 @@ async def generate_batch_async(self, inputs: dict[str, list[str]]) -> list[str]: assistant = assistants[i] if assistants is not None else None coros.append(self._generate_async(prompt, assistant)) - return await asyncio.gather(*coros) + return await asyncio.gather(*coros, return_exceptions=return_exceptions) class OpenAIChatService(LLMService): diff --git a/tests/_utils/environment.py b/tests/_utils/environment.py new file mode 100644 index 0000000000..b4ddea4866 --- /dev/null +++ b/tests/_utils/environment.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for modifying the environment for testing Morpheus""" + +import contextlib +import os + + +@contextlib.contextmanager +def set_env(**env_vars): + """ + Temporarily updates the ``os.environ`` dictionary in-place. + + The ``os.environ`` dictionary is updated in-place so that the modification + is sure to work in all situations. + + Setting a value to ``None`` will cause the key to be removed from the environment. + """ + # Taken from https://stackoverflow.com/a/34333710 + env = os.environ + + # Remove any which are set to None + remove = [k for k, v in env_vars.items() if v is None] + + # Save the remaining environment variables to set + update = {k: v for k, v in env_vars.items() if v is not None} + + # List of environment variables being updated or removed. 
+ stomped = (set(update.keys()) | set(remove)) & set(env.keys()) + # Environment variables and values to restore on exit. + update_after = {k: env[k] for k in stomped} + # Environment variables and values to remove on exit. + remove_after = frozenset(k for k in update if k not in env) + + try: + env.update(update) + + for k in remove: + env.pop(k, None) + + yield + finally: + env.update(update_after) + + for k in remove_after: + env.pop(k) diff --git a/tests/conftest.py b/tests/conftest.py index c422c18ea9..4d67976dc3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1065,7 +1065,7 @@ def mock_chat_completion_fixture(): @pytest.mark.usefixtures("nemollm") @pytest.fixture(name="mock_nemollm") def mock_nemollm_fixture(): - with mock.patch("nemollm.NemoLLM") as mock_nemollm: + with mock.patch("nemollm.NemoLLM", autospec=True) as mock_nemollm: mock_nemollm.return_value = mock_nemollm mock_nemollm.generate_multiple.return_value = ["test_output"] mock_nemollm.post_process_generate_response.return_value = {"text": "test_output"} diff --git a/tests/llm/conftest.py b/tests/llm/conftest.py index f92a16d148..3519166635 100644 --- a/tests/llm/conftest.py +++ b/tests/llm/conftest.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -import typing from unittest import mock import pytest @@ -104,16 +102,26 @@ def serpapi_api_key_fixture(): @pytest.fixture(name="mock_nemollm") def mock_nemollm_fixture(mock_nemollm: mock.MagicMock): - # The generate function is a blocking call that returns a future when return_type="async" - async def mock_task(fut: asyncio.Future, value: typing.Any = mock.DEFAULT): - fut.set_result(value) + from concurrent.futures import Future + + def generate_mock(*_, **kwargs): + + fut = Future() + + fut.set_result(kwargs["prompt"]) - def create_future(*args, **kwargs) -> asyncio.Future: # pylint: disable=unused-argument - event_loop = asyncio.get_event_loop() - fut = event_loop.create_future() - event_loop.create_task(mock_task(fut, mock.DEFAULT)) return fut - mock_nemollm.generate.side_effect = create_future + mock_nemollm.generate.side_effect = generate_mock + + def generate_multiple_mock(*_, **kwargs): + + assert kwargs["return_type"] == "text", "Only text return type is supported for mocking." 
+ + prompts: list[str] = kwargs["prompts"] + + return list(prompts) + + mock_nemollm.generate_multiple.side_effect = generate_multiple_mock yield mock_nemollm diff --git a/tests/llm/services/test_llm_service_pipe.py b/tests/llm/services/test_llm_service_pipe.py index fa6c1ac0c7..e6e2f8bbf3 100644 --- a/tests/llm/services/test_llm_service_pipe.py +++ b/tests/llm/services/test_llm_service_pipe.py @@ -35,7 +35,7 @@ from morpheus.stages.preprocess.deserialize_stage import DeserializeStage -def _build_engine(llm_service_cls: LLMService): +def _build_engine(llm_service_cls: type[LLMService]): llm_service = llm_service_cls() llm_clinet = llm_service.get_client(model_name="test_model") @@ -47,7 +47,9 @@ def _build_engine(llm_service_cls: LLMService): return engine -def _run_pipeline(config: Config, llm_service_cls: LLMService, country_prompts: list[str], +def _run_pipeline(config: Config, + llm_service_cls: type[LLMService], + country_prompts: list[str], capital_responses: list[str]): """ Loosely patterned after `examples/llm/completion` @@ -72,16 +74,10 @@ def _run_pipeline(config: Config, llm_service_cls: LLMService, country_prompts: assert_results(sink.get_results()) -@mock.patch("asyncio.wrap_future") -@mock.patch("asyncio.gather", new_callable=mock.AsyncMock) -def test_completion_pipe_nemo( - mock_asyncio_gather: mock.AsyncMock, - mock_asyncio_wrap_future: mock.MagicMock, # pylint: disable=unused-argument - config: Config, - mock_nemollm: mock.MagicMock, - country_prompts: list[str], - capital_responses: list[str]): - mock_asyncio_gather.return_value = [mock.MagicMock() for _ in range(len(country_prompts))] +def test_completion_pipe_nemo(config: Config, + mock_nemollm: mock.MagicMock, + country_prompts: list[str], + capital_responses: list[str]): mock_nemollm.post_process_generate_response.side_effect = [{"text": response} for response in capital_responses] _run_pipeline(config, NeMoLLMService, country_prompts, capital_responses) diff --git a/tests/llm/services/test_nemo_llm_client.py b/tests/llm/services/test_nemo_llm_client.py index 5a7993006c..d6cc295ea4 100644 --- a/tests/llm/services/test_nemo_llm_client.py +++ b/tests/llm/services/test_nemo_llm_client.py @@ -13,95 +13,104 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import asyncio from unittest import mock import pytest from morpheus.llm.services.llm_service import LLMClient -from morpheus.llm.services.nemo_llm_service import NeMoLLMClient +from morpheus.llm.services.nemo_llm_service import NeMoLLMService -def test_constructor(mock_nemollm: mock.MagicMock, mock_nemo_service: mock.MagicMock): - client = NeMoLLMClient(mock_nemo_service, model_name="test_model", additional_arg="test_arg") +def test_constructor(): + client = NeMoLLMService(api_key="dummy").get_client(model_name="test_model", additional_arg="test_arg") + assert isinstance(client, LLMClient) - mock_nemollm.assert_not_called() -def test_get_input_names(mock_nemollm: mock.MagicMock, mock_nemo_service: mock.MagicMock): - client = NeMoLLMClient(mock_nemo_service, model_name="test_model", additional_arg="test_arg") +def test_get_input_names(): + client = NeMoLLMService(api_key="dummy").get_client(model_name="test_model", additional_arg="test_arg") + assert client.get_input_names() == ["prompt"] - mock_nemollm.assert_not_called() -def test_generate(mock_nemollm: mock.MagicMock, mock_nemo_service: mock.MagicMock): - client = NeMoLLMClient(mock_nemo_service, model_name="test_model", additional_arg="test_arg") - assert client.generate({'prompt': "test_prompt"}) == "test_output" +def test_generate(mock_nemollm: mock.MagicMock): + + client = NeMoLLMService(api_key="dummy").get_client(model_name="test_model", customization_id="test_custom_id") + + assert client.generate(prompt="test_prompt") == "test_prompt" + mock_nemollm.generate_multiple.assert_called_once_with(model="test_model", prompts=["test_prompt"], return_type="text", - additional_arg="test_arg") + customization_id="test_custom_id") + +def test_generate_batch(mock_nemollm: mock.MagicMock): -def test_generate_batch(mock_nemollm: mock.MagicMock, mock_nemo_service: mock.MagicMock): - mock_nemollm.generate_multiple.return_value = ["output1", "output2"] + client = NeMoLLMService(api_key="dummy").get_client(model_name="test_model", customization_id="test_custom_id") + + assert client.generate_batch({'prompt': ["prompt1", "prompt2"]}) == ["prompt1", "prompt2"] - client = NeMoLLMClient(mock_nemo_service, model_name="test_model", additional_arg="test_arg") - assert client.generate_batch({'prompt': ["prompt1", "prompt2"]}) == ["output1", "output2"] mock_nemollm.generate_multiple.assert_called_once_with(model="test_model", prompts=["prompt1", "prompt2"], return_type="text", - additional_arg="test_arg") + customization_id="test_custom_id") + + +async def test_generate_async(mock_nemollm: mock.MagicMock): + client = NeMoLLMService(api_key="dummy").get_client(model_name="test_model", customization_id="test_custom_id") -@mock.patch("asyncio.wrap_future") -@mock.patch("asyncio.gather", new_callable=mock.AsyncMock) -def test_generate_async( - mock_asyncio_gather: mock.AsyncMock, - mock_asyncio_wrap_future: mock.MagicMock, # pylint: disable=unused-argument - mock_nemollm: mock.MagicMock, - mock_nemo_service: mock.MagicMock): - mock_asyncio_gather.return_value = [mock.MagicMock()] + results = await client.generate_async(prompt="test_prompt") - client = NeMoLLMClient(mock_nemo_service, model_name="test_model", additional_arg="test_arg") - results = asyncio.run(client.generate_async({'prompt': "test_prompt"})) assert results == "test_output" + mock_nemollm.generate.assert_called_once_with("test_model", "test_prompt", return_type="async", - additional_arg="test_arg") - - -@mock.patch("asyncio.wrap_future") -@mock.patch("asyncio.gather", 
new_callable=mock.AsyncMock) -def test_generate_batch_async( - mock_asyncio_gather: mock.AsyncMock, - mock_asyncio_wrap_future: mock.MagicMock, # pylint: disable=unused-argument - mock_nemollm: mock.MagicMock, - mock_nemo_service: mock.MagicMock): - mock_asyncio_gather.return_value = [mock.MagicMock(), mock.MagicMock()] - mock_nemollm.post_process_generate_response.side_effect = [{"text": "output1"}, {"text": "output2"}] - - client = NeMoLLMClient(mock_nemo_service, model_name="test_model", additional_arg="test_arg") - results = asyncio.run(client.generate_batch_async({'prompt': ["prompt1", "prompt2"]})) - assert results == ["output1", "output2"] + customization_id="test_custom_id") + + +async def test_generate_batch_async(mock_nemollm: mock.MagicMock): + # mock_nemollm.post_process_generate_response.side_effect = [{"text": "output1"}, {"text": "output2"}] + + client = NeMoLLMService(api_key="dummy").get_client(model_name="test_model", customization_id="test_custom_id") + + results = await client.generate_batch_async({'prompt': ["prompt1", "prompt2"]}) + + assert results == ["test_output", "test_output"] + mock_nemollm.generate.assert_has_calls([ - mock.call("test_model", "prompt1", return_type="async", additional_arg="test_arg"), - mock.call("test_model", "prompt2", return_type="async", additional_arg="test_arg") + mock.call("test_model", "prompt1", return_type="async", customization_id="test_custom_id"), + mock.call("test_model", "prompt2", return_type="async", customization_id="test_custom_id") ]) -@mock.patch("asyncio.wrap_future") -@mock.patch("asyncio.gather", new_callable=mock.AsyncMock) -def test_generate_batch_async_error( - mock_asyncio_gather: mock.AsyncMock, - mock_asyncio_wrap_future: mock.MagicMock, # pylint: disable=unused-argument - mock_nemollm: mock.MagicMock, - mock_nemo_service: mock.MagicMock): - mock_asyncio_gather.return_value = [mock.MagicMock(), mock.MagicMock()] +async def test_generate_batch_async_error(mock_nemollm: mock.MagicMock): mock_nemollm.post_process_generate_response.return_value = {"status": "fail", "msg": "unittest"} - client = NeMoLLMClient(mock_nemo_service, model_name="test_model", additional_arg="test_arg") + client = NeMoLLMService(api_key="dummy").get_client(model_name="test_model", customization_id="test_custom_id") with pytest.raises(RuntimeError, match="unittest"): - asyncio.run(client.generate_batch_async({'prompt': ["prompt1", "prompt2"]})) + await client.generate_batch_async({'prompt': ["prompt1", "prompt2"]}) + + +async def test_generate_batch_async_error_retry(mock_nemollm: mock.MagicMock): + + count = 0 + + def mock_post_process_generate_response(*args, **_): + nonlocal count + if count < 2: + count += 1 + return {"status": "fail", "msg": "unittest"} + return {"status": "success", "text": args[0]} + + mock_nemollm.post_process_generate_response.side_effect = mock_post_process_generate_response + + client = NeMoLLMService(api_key="dummy", retry_count=2).get_client(model_name="test_model", + customization_id="test_custom_id") + + results = await client.generate_batch_async({'prompt': ["prompt1", "prompt2"]}) + + assert results == ["prompt1", "prompt2"] diff --git a/tests/llm/services/test_nemo_llm_service.py b/tests/llm/services/test_nemo_llm_service.py index d91a6f7351..739d44e31b 100644 --- a/tests/llm/services/test_nemo_llm_service.py +++ b/tests/llm/services/test_nemo_llm_service.py @@ -38,7 +38,10 @@ def test_constructor(mock_nemollm: mock.MagicMock, api_key: str, org_id: str): expected_org_id = org_id or env_org_id 
NeMoLLMService(api_key=api_key, org_id=org_id) - mock_nemollm.assert_called_once_with(api_key=expected_api_key, org_id=expected_org_id) + _, kwargs = mock_nemollm.call_args_list[-1] + + assert kwargs["api_key"] == expected_api_key + assert kwargs["org_id"] == expected_org_id def test_get_client(): diff --git a/tests/llm/services/test_openai_chat_client.py b/tests/llm/services/test_openai_chat_client.py index b4f5529edd..577c83c7bb 100644 --- a/tests/llm/services/test_openai_chat_client.py +++ b/tests/llm/services/test_openai_chat_client.py @@ -61,14 +61,14 @@ def test_generate(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock], set_assistant=set_assistant, temperature=temperature) if use_async: - results = asyncio.run(client.generate_async(input_dict)) + results = asyncio.run(client.generate_async(**input_dict)) mock_async_client.chat.completions.create.assert_called_once_with(model="test_model", messages=expected_messages, temperature=temperature) mock_client.chat.completions.create.assert_not_called() else: - results = client.generate(input_dict) + results = client.generate(**input_dict) mock_client.chat.completions.create.assert_called_once_with(model="test_model", messages=expected_messages, temperature=temperature) diff --git a/tests/llm/test_completion_pipe.py b/tests/llm/test_completion_pipe.py index 39c16d7e3b..106eb39586 100644 --- a/tests/llm/test_completion_pipe.py +++ b/tests/llm/test_completion_pipe.py @@ -21,6 +21,7 @@ import cudf from _utils import assert_results +from _utils.environment import set_env from _utils.llm import mk_mock_openai_response from morpheus.config import Config from morpheus.llm import LLMEngine @@ -41,7 +42,7 @@ logger = logging.getLogger(__name__) -def _build_engine(llm_service_cls: LLMService, model_name: str = "test_model"): +def _build_engine(llm_service_cls: type[LLMService], model_name: str = "test_model"): llm_service = llm_service_cls() llm_client = llm_service.get_client(model_name=model_name) @@ -57,7 +58,7 @@ def _build_engine(llm_service_cls: LLMService, model_name: str = "test_model"): def _run_pipeline(config: Config, - llm_service_cls: LLMService, + llm_service_cls: type[LLMService], countries: list[str], capital_responses: list[str], model_name: str = "test_model") -> dict: @@ -90,19 +91,17 @@ def _run_pipeline(config: Config, @pytest.mark.usefixtures("nemollm") -@mock.patch("asyncio.wrap_future") -@mock.patch("asyncio.gather", new_callable=mock.AsyncMock) -def test_completion_pipe_nemo( - mock_asyncio_gather: mock.AsyncMock, - mock_asyncio_wrap_future: mock.MagicMock, # pylint: disable=unused-argument - config: Config, - mock_nemollm: mock.MagicMock, - countries: list[str], - capital_responses: list[str]): - mock_asyncio_gather.return_value = [mock.MagicMock() for _ in range(len(countries))] - mock_nemollm.post_process_generate_response.side_effect = [{"text": response} for response in capital_responses] - results = _run_pipeline(config, NeMoLLMService, countries=countries, capital_responses=capital_responses) - assert_results(results) +def test_completion_pipe_nemo(config: Config, + mock_nemollm: mock.MagicMock, + countries: list[str], + capital_responses: list[str]): + + # Set a dummy key to bypass the API key check + with set_env(NGC_API_KEY="test"): + + mock_nemollm.post_process_generate_response.side_effect = [{"text": response} for response in capital_responses] + results = _run_pipeline(config, NeMoLLMService, countries=countries, capital_responses=capital_responses) + assert_results(results) 
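For reference, a minimal usage sketch of the reworked service API from this patch (the configurable retry count, the `NGC_API_BASE` proxy variable, and the new `return_exceptions` flag on the batch methods). The model name and prompts are placeholders, and `NGC_API_KEY` (plus optionally `NGC_API_BASE`) is assumed to be set in the environment:

import asyncio

from morpheus.llm.services.nemo_llm_service import NeMoLLMService


async def main():
    # NGC_API_KEY (and optionally NGC_API_BASE for a proxy) are read from the environment.
    service = NeMoLLMService(retry_count=3)

    # "example-model" is a placeholder model name.
    client = service.get_client(model_name="example-model")

    # With return_exceptions=True a failed prompt comes back as an exception object
    # instead of aborting the whole batch.
    results = await client.generate_batch_async(
        {"prompt": ["What is the capital of France?", "What is the capital of Spain?"]},
        return_exceptions=True)

    for result in results:
        if isinstance(result, BaseException):
            print(f"request failed: {result}")
        else:
            print(result)


asyncio.run(main())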
@pytest.mark.usefixtures("openai") From a84479484e0e5afd99748698ffe5a5944b5aeabe Mon Sep 17 00:00:00 2001 From: David Gardner <96306125+dagardner-nv@users.noreply.github.com> Date: Mon, 11 Mar 2024 15:09:32 -0700 Subject: [PATCH 2/3] Catch langchain agent errors (#1539) * Catch any uncaught/un-handled langchain agent errors, return an error string so that we don't end up with a missing result. * Fix version conflicts for nemollm and dgl * Expand existing tests for the `LangChainAgentNode` ## By Submitting this PR I confirm: - I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md). - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. Authors: - David Gardner (https://github.com/dagardner-nv) Approvers: - Christopher Harris (https://github.com/cwharris) - Michael Demoret (https://github.com/mdemoret-nv) URL: https://github.com/nv-morpheus/Morpheus/pull/1539 --- .../all_cuda-121_arch-x86_64.yaml | 4 +- .../dev_cuda-121_arch-x86_64.yaml | 2 + .../examples_cuda-121_arch-x86_64.yaml | 4 +- .../runtime_cuda-121_arch-x86_64.yaml | 2 + dependencies.yaml | 4 +- morpheus/llm/nodes/langchain_agent_node.py | 11 ++- tests/_utils/llm.py | 32 ++++++- tests/llm/nodes/test_langchain_agent_node.py | 88 +++++++++++++++++++ 8 files changed, 139 insertions(+), 8 deletions(-) diff --git a/conda/environments/all_cuda-121_arch-x86_64.yaml b/conda/environments/all_cuda-121_arch-x86_64.yaml index 4d18afbac2..d2fb377b2a 100644 --- a/conda/environments/all_cuda-121_arch-x86_64.yaml +++ b/conda/environments/all_cuda-121_arch-x86_64.yaml @@ -93,7 +93,9 @@ dependencies: - pytorch=*=*cuda* - rapidjson=1.1.0 - rdma-core>=48 +- requests - requests-cache=1.1 +- requests-toolbelt - s3fs=2023.12.2 - scikit-build=0.17.6 - scikit-learn=1.3.2 @@ -117,7 +119,7 @@ dependencies: - --find-links https://data.dgl.ai/wheels/cu121/repo.html - PyMuPDF==1.23.21 - databricks-connect - - dgl + - dgl==2.0.0 - dglgo - google-search-results==2.4 - langchain==0.1.9 diff --git a/conda/environments/dev_cuda-121_arch-x86_64.yaml b/conda/environments/dev_cuda-121_arch-x86_64.yaml index 124fc59e17..5ca2285176 100644 --- a/conda/environments/dev_cuda-121_arch-x86_64.yaml +++ b/conda/environments/dev_cuda-121_arch-x86_64.yaml @@ -73,7 +73,9 @@ dependencies: - pytorch-cuda - pytorch=*=*cuda* - rapidjson=1.1.0 +- requests - requests-cache=1.1 +- requests-toolbelt - scikit-build=0.17.6 - scikit-learn=1.3.2 - sphinx diff --git a/conda/environments/examples_cuda-121_arch-x86_64.yaml b/conda/environments/examples_cuda-121_arch-x86_64.yaml index ba479b45f9..b8e0655e69 100644 --- a/conda/environments/examples_cuda-121_arch-x86_64.yaml +++ b/conda/environments/examples_cuda-121_arch-x86_64.yaml @@ -46,7 +46,9 @@ dependencies: - python=3.10 - pytorch-cuda - pytorch=*=*cuda* +- requests - requests-cache=1.1 +- requests-toolbelt - s3fs=2023.12.2 - scikit-learn=1.3.2 - sentence-transformers @@ -61,7 +63,7 @@ dependencies: - --find-links https://data.dgl.ai/wheels/cu121/repo.html - PyMuPDF==1.23.21 - databricks-connect - - dgl + - dgl==2.0.0 - dglgo - google-search-results==2.4 - langchain==0.1.9 diff --git a/conda/environments/runtime_cuda-121_arch-x86_64.yaml b/conda/environments/runtime_cuda-121_arch-x86_64.yaml index 7593e4a951..cd65265434 100644 --- a/conda/environments/runtime_cuda-121_arch-x86_64.yaml +++ b/conda/environments/runtime_cuda-121_arch-x86_64.yaml @@ 
-28,7 +28,9 @@ dependencies: - python=3.10 - pytorch-cuda - pytorch=*=*cuda* +- requests - requests-cache=1.1 +- requests-toolbelt - scikit-learn=1.3.2 - sqlalchemy - tqdm=4 diff --git a/dependencies.yaml b/dependencies.yaml index 2d243c4e47..fbfdc4b31c 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -261,7 +261,9 @@ dependencies: - python-graphviz - pytorch-cuda - pytorch=*=*cuda* + - requests - requests-cache=1.1 + - requests-toolbelt # Transitive dep needed by nemollm, specified here to ensure we get a compatible version - sqlalchemy - tqdm=4 - typing_utils=0.1 @@ -311,7 +313,7 @@ dependencies: - pip: - --find-links https://data.dgl.ai/wheels/cu121/repo.html - --find-links https://data.dgl.ai/wheels-test/repo.html - - dgl + - dgl==2.0.0 - dglgo example-llm-agents: diff --git a/morpheus/llm/nodes/langchain_agent_node.py b/morpheus/llm/nodes/langchain_agent_node.py index 86ee526af6..ab2671dae8 100644 --- a/morpheus/llm/nodes/langchain_agent_node.py +++ b/morpheus/llm/nodes/langchain_agent_node.py @@ -66,9 +66,14 @@ async def _run_single(self, **kwargs: dict[str, typing.Any]) -> dict[str, typing return results # We are not dealing with a list, so run single - return await self._agent_executor.arun(**kwargs) - - async def execute(self, context: LLMContext) -> LLMContext: + try: + return await self._agent_executor.arun(**kwargs) + except Exception as e: + error_msg = f"Error running agent: {e}" + logger.exception(error_msg) + return error_msg + + async def execute(self, context: LLMContext) -> LLMContext: # pylint: disable=invalid-overridden-method input_dict = context.get_inputs() diff --git a/tests/_utils/llm.py b/tests/_utils/llm.py index 9c48583b7e..9d9802478d 100644 --- a/tests/_utils/llm.py +++ b/tests/_utils/llm.py @@ -87,7 +87,35 @@ def mk_mock_openai_response(messages: list[str]) -> mock.MagicMock: Creates a mocked openai.types.chat.chat_completion.ChatCompletion response with the given messages. """ response = mock.MagicMock() - mock_choices = [_mk_mock_choice(message) for message in messages] - response.choices = mock_choices + + response.choices = [_mk_mock_choice(message) for message in messages] + response.dict.return_value = { + "choices": [{ + 'message': { + 'role': 'assistant', 'content': message + } + } for message in messages] + } return response + + +def mk_mock_langchain_tool(responses: list[str]) -> mock.MagicMock: + """ + Creates a mocked LangChainTestTool with the given responses. 
+ """ + + # Langchain will call inspect.signature on the tool methods, typically mock objects don't have a signature, + # explicitly providing one here + async def _arun_spec(*_, **__): + pass + + def run_spec(*_, **__): + pass + + tool = mock.MagicMock() + tool.arun = mock.create_autospec(spec=_arun_spec) + tool.arun.side_effect = responses + tool.run = mock.create_autospec(run_spec) + tool.run.side_effect = responses + return tool diff --git a/tests/llm/nodes/test_langchain_agent_node.py b/tests/llm/nodes/test_langchain_agent_node.py index e86ae0b3f5..c3846f8465 100644 --- a/tests/llm/nodes/test_langchain_agent_node.py +++ b/tests/llm/nodes/test_langchain_agent_node.py @@ -16,8 +16,14 @@ from unittest import mock import pytest +from langchain.agents import AgentType +from langchain.agents import Tool +from langchain.agents import initialize_agent +from langchain.chat_models import ChatOpenAI # pylint: disable=no-name-in-module from _utils.llm import execute_node +from _utils.llm import mk_mock_langchain_tool +from _utils.llm import mk_mock_openai_response from morpheus.llm import LLMNodeBase from morpheus.llm.nodes.langchain_agent_node import LangChainAgentNode @@ -50,8 +56,90 @@ def test_execute( expected_output: list, expected_calls: list[mock.call], ): + # Tests the execute method of the LangChainAgentNode with a mocked agent_executor mock_agent_executor.arun.return_value = arun_return node = LangChainAgentNode(agent_executor=mock_agent_executor) assert execute_node(node, **values) == expected_output mock_agent_executor.arun.assert_has_calls(expected_calls) + + +def test_execute_tools(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock]): + # Tests the execute method of the LangChainAgentNode with a a mocked tools and chat completion + (_, mock_async_client) = mock_chat_completion + chat_responses = [ + 'I should check Tool1\nAction: Tool1\nAction Input: "name a reptile"', + 'I should check Tool2\nAction: Tool2\nAction Input: "name of a day of the week"', + 'I should check Tool1\nAction: Tool1\nAction Input: "name a reptile"', + 'I should check Tool2\nAction: Tool2\nAction Input: "name of a day of the week"', + 'Observation: Answer: Yes!\nI now know the final answer.\nFinal Answer: Yes!' + ] + mock_responses = [mk_mock_openai_response([response]) for response in chat_responses] + mock_async_client.chat.completions.create.side_effect = mock_responses + + llm_chat = ChatOpenAI(model="fake-model", openai_api_key="fake-key") + + mock_tool1 = mk_mock_langchain_tool(["lizard", "frog"]) + mock_tool2 = mk_mock_langchain_tool(["Tuesday", "Thursday"]) + + tools = [ + Tool(name="Tool1", + func=mock_tool1.run, + coroutine=mock_tool1.arun, + description="useful for when you need to know the name of a reptile"), + Tool(name="Tool2", + func=mock_tool2.run, + coroutine=mock_tool2.arun, + description="useful for when you need to know the day of the week") + ] + + agent = initialize_agent(tools, + llm_chat, + agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, + verbose=True, + handle_parsing_errors=True, + early_stopping_method="generate", + return_intermediate_steps=False) + + node = LangChainAgentNode(agent_executor=agent) + + assert execute_node(node, input="input1") == "Yes!" 
+ + +def test_execute_error(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock]): + # Tests the execute method of the LangChainAgentNode with a a mocked tools and chat completion + (_, mock_async_client) = mock_chat_completion + chat_responses = [ + 'I should check Tool1\nAction: Tool1\nAction Input: "name a reptile"', + 'I should check Tool2\nAction: Tool2\nAction Input: "name of a day of the week"', + 'Observation: Answer: Yes!\nI now know the final answer.\nFinal Answer: Yes!' + ] + mock_responses = [mk_mock_openai_response([response]) for response in chat_responses] + mock_async_client.chat.completions.create.side_effect = mock_responses + + llm_chat = ChatOpenAI(model="fake-model", openai_api_key="fake-key") + + mock_tool1 = mk_mock_langchain_tool(["lizard"]) + mock_tool2 = mk_mock_langchain_tool(RuntimeError("unittest")) + + tools = [ + Tool(name="Tool1", + func=mock_tool1.run, + coroutine=mock_tool1.arun, + description="useful for when you need to know the name of a reptile"), + Tool(name="Tool2", + func=mock_tool2.run, + coroutine=mock_tool2.arun, + description="useful for when you need to test tool errors") + ] + + agent = initialize_agent(tools, + llm_chat, + agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, + verbose=True, + handle_parsing_errors=True, + early_stopping_method="generate", + return_intermediate_steps=False) + + node = LangChainAgentNode(agent_executor=agent) + assert execute_node(node, input="input1") == "Error running agent: unittest" From faa3d1e18e052b5bacc6c0c20fbc302336f4d88d Mon Sep 17 00:00:00 2001 From: eagostini Date: Mon, 11 Mar 2024 22:34:27 +0000 Subject: [PATCH 3/3] Last comments applied Signed-off-by: eagostini --- morpheus/_lib/doca/src/doca_context.cpp | 2 +- morpheus/_lib/doca/src/doca_source.cpp | 3 ++- morpheus/stages/doca/doca_source_stage.py | 7 +++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/morpheus/_lib/doca/src/doca_context.cpp b/morpheus/_lib/doca/src/doca_context.cpp index 6cb6eb75bd..c184621cd3 100644 --- a/morpheus/_lib/doca/src/doca_context.cpp +++ b/morpheus/_lib/doca/src/doca_context.cpp @@ -60,7 +60,7 @@ static doca_error_t open_doca_device_with_pci(const char *pcie_value, struct doc res = doca_devinfo_create_list(&dev_list, &nb_devs); if (res != DOCA_SUCCESS) { - MORPHEUS_FAIL("Failed to load doca devices list"); + LOG(ERROR) << "Failed to load doca devices list"; return res; } diff --git a/morpheus/_lib/doca/src/doca_source.cpp b/morpheus/_lib/doca/src/doca_source.cpp index 45d4edd15e..a4f6066d4a 100644 --- a/morpheus/_lib/doca/src/doca_source.cpp +++ b/morpheus/_lib/doca/src/doca_source.cpp @@ -38,6 +38,7 @@ #include #include #include +#include "morpheus/utilities/error.hpp" #include #include @@ -125,7 +126,7 @@ DocaSourceStage::subscriber_fn_t DocaSourceStage::build() int thread_idx = mrc::runnable::Context::get_runtime_context().rank(); if (thread_idx >= MAX_QUEUE) { - MORPHEUS_LOCAL(MORPHEUS_CONCAT_STR("Thread ID " << thread_idx << " bigger than MAX_QUEUE " << MAX_QUEUE)); + MORPHEUS_FAIL("Thread ID bigger than MAX_QUEUE"); return; } diff --git a/morpheus/stages/doca/doca_source_stage.py b/morpheus/stages/doca/doca_source_stage.py index 3844ff897b..8bde222048 100644 --- a/morpheus/stages/doca/doca_source_stage.py +++ b/morpheus/stages/doca/doca_source_stage.py @@ -68,9 +68,9 @@ def __init__( self._max_concurrent = c.num_threads self._nic_pci_address = nic_pci_address self._gpu_pci_address = gpu_pci_address - self._traffic_type = traffic_type - if self._traffic_type != 'udp' and self._traffic_type != 
'tcp': - raise NotImplementedError("The Morpheus DOCA source stage allows a only udp or tcp types of traffic flow " + self._traffic_type) + self._traffic_type = traffic_type.lower() + if self._traffic_type not in ('udp', 'tcp'): + raise NotImplementedError("The Morpheus DOCA source stage supports only udp or tcp traffic types: " + traffic_type) @property def name(self) -> str: @@ -90,7 +90,6 @@ def supports_cpp_node(self): def _build_source(self, builder: mrc.Builder) -> mrc.SegmentObject: if self._build_cpp_node(): - # return self._doca_source_class(builder, self.unique_name, self._nic_pci_address, self._gpu_pci_address, self._traffic_type) node = self._doca_source_class(builder, self.unique_name, self._nic_pci_address, self._gpu_pci_address, self._traffic_type) node.launch_options.pe_count = self._max_concurrent return node
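A rough illustration of the stricter `traffic_type` handling in the DOCA source stage after this change. This is a sketch only: it assumes a Morpheus build with DOCA support, the `DocaSourceStage` constructor as shown in the diff above, and placeholder PCI addresses:

from morpheus.config import Config
from morpheus.stages.doca.doca_source_stage import DocaSourceStage

config = Config()

# "UDP" is normalized by the new .lower() call, so mixed-case values are accepted.
stage = DocaSourceStage(config,
                        nic_pci_address="0000:3b:00.0",  # placeholder NIC PCI address
                        gpu_pci_address="0000:5e:00.0",  # placeholder GPU PCI address
                        traffic_type="UDP")

# Any value other than "udp" or "tcp" now raises NotImplementedError.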