Skip to content

[AQUA] Add API endpoints to manage chat templates via model custom metadata #1213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ads/aqua/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
SUPPORTED_FILE_FORMATS = ["jsonl"]
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"

AQUA_CHAT_TEMPLATE_METADATA_KEY = "chat_template"

CONSOLE_LINK_RESOURCE_TYPE_MAPPING = {
"datasciencemodel": "models",
"datasciencemodeldeployment": "model-deployments",
Expand Down
58 changes: 50 additions & 8 deletions ads/aqua/extension/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
from ads.aqua.common.enums import CustomInferenceContainerTypeFamily
from ads.aqua.common.errors import AquaRuntimeError
from ads.aqua.common.utils import get_hf_model_info, is_valid_ocid, list_hf_models
from ads.aqua.constants import AQUA_CHAT_TEMPLATE_METADATA_KEY
from ads.aqua.extension.base_handler import AquaAPIhandler
from ads.aqua.extension.errors import Errors
from ads.aqua.model import AquaModelApp
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
from ads.config import SERVICE
from ads.model import DataScienceModel
from ads.model.common.utils import MetadataArtifactPathType
from ads.model.service.oci_datascience_model import OCIDataScienceModel


class AquaModelHandler(AquaAPIhandler):
Expand Down Expand Up @@ -320,26 +323,65 @@ def post(self, *args, **kwargs): # noqa: ARG002
)


class AquaModelTokenizerConfigHandler(AquaAPIhandler):
class AquaModelChatTemplateHandler(AquaAPIhandler):
def get(self, model_id):
"""
Handles requests for retrieving the Hugging Face tokenizer configuration of a specified model.
Expected request format: GET /aqua/models/<model-ocid>/tokenizer
Handles requests for retrieving the chat template from custom metadata of a specified model.
Expected request format: GET /aqua/models/<model-ocid>/chat-template

"""

path_list = urlparse(self.request.path).path.strip("/").split("/")
# Path should be /aqua/models/ocid1.iad.ahdxxx/tokenizer
# path_list=['aqua','models','<model-ocid>','tokenizer']
# Path should be /aqua/models/ocid1.iad.ahdxxx/chat-template
# path_list=['aqua','models','<model-ocid>','chat-template']
if (
len(path_list) == 4
and is_valid_ocid(path_list[2])
and path_list[3] == "tokenizer"
and path_list[3] == "chat-template"
):
return self.finish(AquaModelApp().get_hf_tokenizer_config(model_id))
try:
oci_data_science_model = OCIDataScienceModel.from_id(model_id)
except Exception as e:
raise HTTPError(404, f"Model not found for id: {model_id}. Details: {str(e)}")
return self.finish(oci_data_science_model.get_custom_metadata_artifact("chat_template"))

raise HTTPError(400, f"The request {self.request.path} is invalid.")

@handle_exceptions
def post(self, model_id: str):
"""
Handles POST requests to add a custom chat_template metadata artifact to a model.

Expected request format:
POST /aqua/models/<model-ocid>/chat-template
Body: { "chat_template": "<your_template_string>" }

"""
try:
input_body = self.get_json_body()
except Exception as e:
raise HTTPError(400, f"Invalid JSON body: {str(e)}")

chat_template = input_body.get("chat_template")
if not chat_template:
raise HTTPError(400, "Missing required field: 'chat_template'")

try:
data_science_model = DataScienceModel.from_id(model_id)
except Exception as e:
raise HTTPError(404, f"Model not found for id: {model_id}. Details: {str(e)}")

try:
result = data_science_model.create_custom_metadata_artifact(
metadata_key_name=AQUA_CHAT_TEMPLATE_METADATA_KEY,
path_type=MetadataArtifactPathType.CONTENT,
artifact_path_or_content=chat_template.encode()
)
except Exception as e:
raise HTTPError(500, f"Failed to create metadata artifact: {str(e)}")

return self.finish(result)


class AquaModelDefinedMetadataArtifactHandler(AquaAPIhandler):
"""
Expand Down Expand Up @@ -381,7 +423,7 @@ def post(self, model_id: str, metadata_key: str):
("model/?([^/]*)", AquaModelHandler),
("model/?([^/]*)/license", AquaModelLicenseHandler),
("model/?([^/]*)/readme", AquaModelReadmeHandler),
("model/?([^/]*)/tokenizer", AquaModelTokenizerConfigHandler),
("model/?([^/]*)/chat-template", AquaModelChatTemplateHandler),
("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
(
"model/?([^/]*)/definedMetadata/?([^/]*)",
Expand Down
119 changes: 97 additions & 22 deletions tests/unitary/with_extras/aqua/test_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
from unicodedata import category
from unittest import TestCase
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, patch, ANY

import pytest
from huggingface_hub.hf_api import HfApi, ModelInfo
Expand All @@ -14,13 +14,13 @@

from ads.aqua.common.errors import AquaRuntimeError
from ads.aqua.common.utils import get_hf_model_info
from ads.aqua.constants import AQUA_TROUBLESHOOTING_LINK, STATUS_CODE_MESSAGES
from ads.aqua.constants import AQUA_TROUBLESHOOTING_LINK, STATUS_CODE_MESSAGES, AQUA_CHAT_TEMPLATE_METADATA_KEY
from ads.aqua.extension.errors import ReplyDetails
from ads.aqua.extension.model_handler import (
AquaHuggingFaceHandler,
AquaModelHandler,
AquaModelLicenseHandler,
AquaModelTokenizerConfigHandler,
AquaModelChatTemplateHandler
)
from ads.aqua.model import AquaModelApp
from ads.aqua.model.entities import AquaModel, AquaModelSummary, HFModelSummary
Expand Down Expand Up @@ -254,39 +254,114 @@ def test_get(self, mock_load_license):
mock_load_license.assert_called_with("test_model_id")


class ModelTokenizerConfigHandlerTestCase(TestCase):
class AquaModelChatTemplateHandlerTestCase(TestCase):
@patch.object(IPythonHandler, "__init__")
def setUp(self, ipython_init_mock) -> None:
ipython_init_mock.return_value = None
self.model_tokenizer_config_handler = AquaModelTokenizerConfigHandler(
self.model_chat_template_handler = AquaModelChatTemplateHandler(
MagicMock(), MagicMock()
)
self.model_tokenizer_config_handler.finish = MagicMock()
self.model_tokenizer_config_handler.request = MagicMock()
self.model_chat_template_handler.finish = MagicMock()
self.model_chat_template_handler.request = MagicMock()
self.model_chat_template_handler._headers = {}

@patch.object(AquaModelApp, "get_hf_tokenizer_config")
@patch("ads.aqua.extension.model_handler.OCIDataScienceModel.from_id")
@patch("ads.aqua.extension.model_handler.urlparse")
def test_get(self, mock_urlparse, mock_get_hf_tokenizer_config):
request_path = MagicMock(path="aqua/model/ocid1.xx./tokenizer")
def test_get_valid_path(self, mock_urlparse, mock_from_id):
request_path = MagicMock(path="/aqua/models/ocid1.xx./chat-template")
mock_urlparse.return_value = request_path
self.model_tokenizer_config_handler.get(model_id="test_model_id")
self.model_tokenizer_config_handler.finish.assert_called_with(
mock_get_hf_tokenizer_config.return_value
)
mock_get_hf_tokenizer_config.assert_called_with("test_model_id")

@patch.object(AquaModelApp, "get_hf_tokenizer_config")
model_mock = MagicMock()
model_mock.get_custom_metadata_artifact.return_value = "chat_template_string"
mock_from_id.return_value = model_mock

self.model_chat_template_handler.get(model_id="test_model_id")
self.model_chat_template_handler.finish.assert_called_with("chat_template_string")
model_mock.get_custom_metadata_artifact.assert_called_with("chat_template")

@patch("ads.aqua.extension.model_handler.urlparse")
def test_get_invalid_path(self, mock_urlparse, mock_get_hf_tokenizer_config):
"""Test invalid request path should raise HTTPError(400)"""
request_path = MagicMock(path="/invalid/path")
def test_get_invalid_path(self, mock_urlparse):
request_path = MagicMock(path="/wrong/path")
mock_urlparse.return_value = request_path

with self.assertRaises(HTTPError) as context:
self.model_tokenizer_config_handler.get(model_id="test_model_id")
self.model_chat_template_handler.get("ocid1.test.chat")
self.assertEqual(context.exception.status_code, 400)
self.model_tokenizer_config_handler.finish.assert_not_called()
mock_get_hf_tokenizer_config.assert_not_called()

@patch("ads.aqua.extension.model_handler.OCIDataScienceModel.from_id", side_effect=Exception("Not found"))
@patch("ads.aqua.extension.model_handler.urlparse")
def test_get_model_not_found(self, mock_urlparse, mock_from_id):
request_path = MagicMock(path="/aqua/models/ocid1.invalid/chat-template")
mock_urlparse.return_value = request_path

with self.assertRaises(HTTPError) as context:
self.model_chat_template_handler.get("ocid1.invalid")
self.assertEqual(context.exception.status_code, 404)

@patch("ads.aqua.extension.model_handler.DataScienceModel.from_id")
def test_post_valid(self, mock_from_id):
model_mock = MagicMock()
model_mock.create_custom_metadata_artifact.return_value = {"result": "success"}
mock_from_id.return_value = model_mock

self.model_chat_template_handler.get_json_body = MagicMock(return_value={"chat_template": "Hello <|user|>"})
result = self.model_chat_template_handler.post("ocid1.valid")
self.model_chat_template_handler.finish.assert_called_with({"result": "success"})

model_mock.create_custom_metadata_artifact.assert_called_with(
metadata_key_name=AQUA_CHAT_TEMPLATE_METADATA_KEY,
path_type=ANY,
artifact_path_or_content=b"Hello <|user|>"
)

@patch.object(AquaModelChatTemplateHandler, "write_error")
def test_post_invalid_json(self, mock_write_error):
self.model_chat_template_handler.get_json_body = MagicMock(side_effect=Exception("Invalid JSON"))
self.model_chat_template_handler._headers = {}
self.model_chat_template_handler.post("ocid1.test.invalidjson")

mock_write_error.assert_called_once()

kwargs = mock_write_error.call_args.kwargs
exc_info = kwargs.get("exc_info")

assert exc_info is not None
exc_type, exc_instance, _ = exc_info

assert isinstance(exc_instance, HTTPError)
assert exc_instance.status_code == 400
assert "Invalid JSON body" in str(exc_instance)

@patch.object(AquaModelChatTemplateHandler, "write_error")
def test_post_missing_chat_template(self, mock_write_error):
self.model_chat_template_handler.get_json_body = MagicMock(return_value={})
self.model_chat_template_handler._headers = {}

self.model_chat_template_handler.post("ocid1.test.model")

mock_write_error.assert_called_once()
exc_info = mock_write_error.call_args.kwargs.get("exc_info")
assert exc_info is not None
_, exc_instance, _ = exc_info
assert isinstance(exc_instance, HTTPError)
assert exc_instance.status_code == 400
assert "Missing required field: 'chat_template'" in str(exc_instance)

@patch("ads.aqua.extension.model_handler.DataScienceModel.from_id", side_effect=Exception("Not found"))
@patch.object(AquaModelChatTemplateHandler, "write_error")
def test_post_model_not_found(self, mock_write_error, mock_from_id):
self.model_chat_template_handler.get_json_body = MagicMock(return_value={"chat_template": "test template"})
self.model_chat_template_handler._headers = {}

self.model_chat_template_handler.post("ocid1.invalid.model")

mock_write_error.assert_called_once()
exc_info = mock_write_error.call_args.kwargs.get("exc_info")
assert exc_info is not None
_, exc_instance, _ = exc_info
assert isinstance(exc_instance, HTTPError)
assert exc_instance.status_code == 404
assert "Model not found" in str(exc_instance)


class TestAquaHuggingFaceHandler:
Expand Down