Skip to content

Commit 32b8bcc

Browse files
committed
fixed unit tests
1 parent 0d7ade4 commit 32b8bcc

File tree

2 files changed

+87
-99
lines changed

2 files changed

+87
-99
lines changed

tests/unitary/with_extras/aqua/test_deployment_handler.py

Lines changed: 19 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
from parameterized import parameterized
1414

1515
import ads.aqua
16-
from ads.aqua.modeldeployment.entities import AquaDeploymentDetail
1716
import ads.config
1817
from ads.aqua.extension.deployment_handler import (
1918
AquaDeploymentHandler,
2019
AquaDeploymentParamsHandler,
2120
AquaDeploymentStreamingInferenceHandler,
2221
AquaModelListHandler,
2322
)
23+
from ads.aqua.modeldeployment.entities import AquaDeploymentDetail
2424

2525

2626
class TestDataset:
@@ -91,6 +91,24 @@ def test_get_deployment_config_without_id(self, mock_error):
9191
mock_error.assert_called_once()
9292
assert result["status"] == 400
9393

94+
@patch("ads.aqua.modeldeployment.AquaDeploymentApp.recommend_shape")
95+
def test_get_recommend_shape(self, mock_recommend_shape):
96+
"""Test get method to return deployment config"""
97+
self.deployment_handler.request.path = "aqua/deployments/recommend_shapes"
98+
self.deployment_handler.get(id="mock-model-id")
99+
mock_recommend_shape.assert_called()
100+
101+
@unittest.skip("fix this test after exception handler is updated.")
102+
@patch("ads.aqua.extension.base_handler.AquaAPIhandler.write_error")
103+
def test_get_recommend_shape_without_id(self, mock_error):
104+
"""Test get method to return deployment config"""
105+
# todo: exception handler needs to be revisited
106+
self.deployment_handler.request.path = "aqua/deployments/recommend_shape"
107+
mock_error.return_value = MagicMock(status=400)
108+
result = self.deployment_handler.get(id="")
109+
mock_error.assert_called_once()
110+
assert result["status"] == 400
111+
94112
@patch(
95113
"ads.aqua.modeldeployment.AquaDeploymentApp.get_multimodel_deployment_config"
96114
)
@@ -284,74 +302,3 @@ def test_get_model_list(self, mock_get, mock_finish):
284302
mock_finish.side_effect = lambda x: x
285303
result = self.aqua_model_list_handler.get(model_id="test_model_id")
286304
mock_get.assert_called()
287-
288-
from unittest.mock import MagicMock, patch
289-
290-
import pytest
291-
from tornado.web import HTTPError
292-
293-
from ads.aqua.extension.base_handler import AquaAPIhandler
294-
from ads.aqua.extension.errors import Errors
295-
from ads.aqua.extension.recommend_handler import AquaRecommendHandler
296-
297-
298-
@pytest.fixture
299-
def handler():
300-
# Patch AquaAPIhandler.__init__ for unit test stubbing
301-
AquaAPIhandler.__init__ = lambda self, *args, **kwargs: None
302-
h = AquaRecommendHandler(MagicMock(), MagicMock())
303-
h.finish = MagicMock()
304-
h.request = MagicMock()
305-
# Set required Tornado internal fields
306-
h._headers = {}
307-
h._write_buffer = []
308-
return h
309-
310-
311-
def test_post_valid_input(monkeypatch, handler):
312-
input_data = {"model_ocid": "ocid1.datasciencemodel.oc1.XYZ"}
313-
expected = {"recommendations": ["VM.GPU.A10.1"], "troubleshoot": ""}
314-
315-
# Patch class on correct import path, so handler sees our fake implementation
316-
class FakeAquaRecommendApp:
317-
def which_gpu(self, **kwargs):
318-
return expected
319-
320-
monkeypatch.setattr(
321-
"ads.aqua.extension.recommend_handler.AquaRecommendApp", FakeAquaRecommendApp
322-
)
323-
324-
handler.get_json_body = MagicMock(return_value=input_data)
325-
handler.post()
326-
handler.finish.assert_called_once_with(expected)
327-
328-
329-
def test_post_no_input(handler):
330-
handler.get_json_body = MagicMock(return_value=None)
331-
handler._headers = {}
332-
handler._write_buffer = []
333-
handler.write_error = MagicMock()
334-
handler.post()
335-
handler.write_error.assert_called_once()
336-
exc_info = handler.write_error.call_args.kwargs.get("exc_info")
337-
assert exc_info is not None
338-
exc_type, exc_value, _ = exc_info
339-
assert exc_type is HTTPError
340-
assert exc_value.status_code == 400
341-
assert exc_value.log_message == Errors.NO_INPUT_DATA
342-
343-
344-
def test_post_invalid_input(handler):
345-
handler.get_json_body = MagicMock(side_effect=Exception("bad input"))
346-
handler._headers = {}
347-
handler._write_buffer = []
348-
handler.write_error = MagicMock()
349-
handler.post()
350-
handler.write_error.assert_called_once()
351-
exc_info = handler.write_error.call_args.kwargs.get("exc_info")
352-
assert exc_info is not None
353-
exc_type, exc_value, _ = exc_info
354-
assert exc_type is HTTPError
355-
assert exc_value.status_code == 400
356-
assert exc_value.log_message == Errors.INVALID_INPUT_DATA_FORMAT
357-

tests/unitary/with_extras/aqua/test_recommend.py

Lines changed: 68 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ModelConfig,
2121
ModelDetail,
2222
RequestRecommend,
23+
ShapeRecommendationReport,
2324
ShapeReport,
2425
)
2526
from ads.model.model_metadata import ModelCustomMetadata, ModelProvenanceMetadata
@@ -233,9 +234,10 @@ def __init__(self):
233234
local_shapes = local_data.get("shapes", {})
234235
self.shapes = local_shapes
235236

237+
236238
class MockDataScienceModel:
237239
@staticmethod
238-
def create(config_file = ""):
240+
def create(config_file=""):
239241
mock_model = MagicMock()
240242
mock_model.model_file_description = {"test_key": "test_value"}
241243
mock_model.display_name = re.sub(r"\.json$", "", config_file)
@@ -245,7 +247,7 @@ def create(config_file = ""):
245247
"license": "test_license",
246248
"organization": "test_organization",
247249
"task": "text-generation",
248-
"model_format" : "SAFETENSORS",
250+
"model_format": "SAFETENSORS",
249251
"ready_to_fine_tune": "true",
250252
"aqua_custom_base_model": "true",
251253
}
@@ -261,36 +263,68 @@ def create(config_file = ""):
261263

262264

263265
class TestAquaShapeRecommend:
264-
265-
def test_which_gpu_valid(self, monkeypatch, **kwargs):
266+
@pytest.mark.parametrize(
267+
"config, expected_recs, expected_troubleshoot",
268+
[
269+
( # decoder-only model
270+
{
271+
"num_hidden_layers": 2,
272+
"hidden_size": 64,
273+
"vocab_size": 1000,
274+
"num_attention_heads": 4,
275+
"head_dim": 16,
276+
"max_position_embeddings": 2048,
277+
},
278+
[],
279+
"",
280+
),
281+
( # encoder-decoder model
282+
{
283+
"num_hidden_layers": 2,
284+
"hidden_size": 64,
285+
"vocab_size": 1000,
286+
"num_attention_heads": 4,
287+
"head_dim": 16,
288+
"max_position_embeddings": 2048,
289+
"is_encoder_decoder": True,
290+
},
291+
[],
292+
"Please provide a decoder-only text-generation model (ex. Llama, Falcon, etc). Encoder-decoder models (ex. T5, Gemma) and encoder-only (BERT) are not supported at this time.",
293+
),
294+
],
295+
)
296+
def test_which_shapes_valid(
297+
self, monkeypatch, config, expected_recs, expected_troubleshoot
298+
):
266299
app = AquaShapeRecommend()
267300
mock_model = MockDataScienceModel.create()
268301

269302
monkeypatch.setattr(
270-
"ads.aqua.app.DataScienceModel.from_id",
271-
lambda _: mock_model
303+
"ads.aqua.app.DataScienceModel.from_id", lambda _: mock_model
272304
)
273305

274-
config = {
275-
"num_hidden_layers": 2,
276-
"hidden_size": 64,
277-
"vocab_size": 1000,
278-
"num_attention_heads": 4,
279-
"head_dim": 16,
280-
"max_position_embeddings": 2048,
281-
}
282-
306+
expected_result = ShapeRecommendationReport(
307+
recommendations=expected_recs, troubleshoot=expected_troubleshoot
308+
)
283309
app._get_model_config = MagicMock(return_value=config)
284310
app.valid_compute_shapes = MagicMock(return_value=[])
285-
app._summarize_shapes_for_seq_lens = MagicMock(return_value="mocked_report")
311+
app._summarize_shapes_for_seq_lens = MagicMock(return_value=expected_result)
286312

287-
request = RequestRecommend(model_id="ocid1.datasciencemodel.oc1.TEST")
313+
request = RequestRecommend(
314+
model_id="ocid1.datasciencemodel.oc1.TEST", generate_table=False
315+
)
288316
result = app.which_shapes(request)
317+
assert result == expected_result
289318

290-
app.valid_compute_shapes.assert_called_once()
291-
llm_config = LLMConfig.from_raw_config(config)
292-
app._summarize_shapes_for_seq_lens.assert_called_once_with(llm_config, [], "")
293-
assert result == "mocked_report"
319+
# If troubleshoot is populated (error case), _summarize_shapes_for_seq_lens should not have been called
320+
if expected_troubleshoot:
321+
app._summarize_shapes_for_seq_lens.assert_not_called()
322+
else:
323+
# For non-error case, summarize should have been called
324+
llm_config = LLMConfig.from_raw_config(config)
325+
app._summarize_shapes_for_seq_lens.assert_called_once_with(
326+
llm_config, [], ""
327+
)
294328

295329
@pytest.mark.parametrize(
296330
"config_file, result_file",
@@ -303,7 +337,9 @@ def test_which_gpu_valid(self, monkeypatch, **kwargs):
303337
),
304338
],
305339
)
306-
def test_which_gpu_valid_from_file(self, monkeypatch, config_file, result_file, **kwargs):
340+
def test_which_shapes_valid_from_file(
341+
self, monkeypatch, config_file, result_file, **kwargs
342+
):
307343
raw = load_config(config_file)
308344
app = AquaShapeRecommend()
309345
mock_model = MockDataScienceModel.create(config_file)
@@ -317,9 +353,14 @@ def test_which_gpu_valid_from_file(self, monkeypatch, config_file, result_file,
317353
ComputeShapeSummary(name=name, shape_series="GPU", gpu_specs=spec)
318354
for name, spec in shapes_index.shapes.items()
319355
]
320-
monkeypatch.setattr(app, "valid_compute_shapes", lambda *args, **kwargs: real_shapes)
356+
monkeypatch.setattr(
357+
app, "valid_compute_shapes", lambda *args, **kwargs: real_shapes
358+
)
321359

322-
result = app.which_gpu(model_ocid="ocid1.datasciencemodel.oc1.TEST")
360+
request = RequestRecommend(
361+
model_id="ocid1.datasciencemodel.oc1.TEST", generate_table=False
362+
)
363+
result = app.which_shapes(request=request)
323364

324365
expected_result = load_config(result_file)
325366
assert result.model_dump() == expected_result
@@ -349,7 +390,7 @@ def test_shape_report_pareto_front(self):
349390
model_size_gb=1, kv_cache_size_gb=1, total_model_gb=2
350391
),
351392
deployment_params=DeploymentParams(
352-
quantization="8bit", max_model_len=2048, params = ""
393+
quantization="8bit", max_model_len=2048, params=""
353394
),
354395
recommendation="ok",
355396
)
@@ -363,7 +404,7 @@ def test_shape_report_pareto_front(self):
363404
model_size_gb=1, kv_cache_size_gb=1, total_model_gb=2
364405
),
365406
deployment_params=DeploymentParams(
366-
quantization="8bit", max_model_len=2048, params = ""
407+
quantization="8bit", max_model_len=2048, params=""
367408
),
368409
recommendation="ok",
369410
)
@@ -377,7 +418,7 @@ def test_shape_report_pareto_front(self):
377418
model_size_gb=1, kv_cache_size_gb=1, total_model_gb=2
378419
),
379420
deployment_params=DeploymentParams(
380-
quantization="bfloat16", max_model_len=2048, params = ""
421+
quantization="bfloat16", max_model_len=2048, params=""
381422
),
382423
recommendation="ok",
383424
)

0 commit comments

Comments
 (0)