20
20
ModelConfig ,
21
21
ModelDetail ,
22
22
RequestRecommend ,
23
+ ShapeRecommendationReport ,
23
24
ShapeReport ,
24
25
)
25
26
from ads .model .model_metadata import ModelCustomMetadata , ModelProvenanceMetadata
@@ -233,9 +234,10 @@ def __init__(self):
233
234
local_shapes = local_data .get ("shapes" , {})
234
235
self .shapes = local_shapes
235
236
237
+
236
238
class MockDataScienceModel :
237
239
@staticmethod
238
- def create (config_file = "" ):
240
+ def create (config_file = "" ):
239
241
mock_model = MagicMock ()
240
242
mock_model .model_file_description = {"test_key" : "test_value" }
241
243
mock_model .display_name = re .sub (r"\.json$" , "" , config_file )
@@ -245,7 +247,7 @@ def create(config_file = ""):
245
247
"license" : "test_license" ,
246
248
"organization" : "test_organization" ,
247
249
"task" : "text-generation" ,
248
- "model_format" : "SAFETENSORS" ,
250
+ "model_format" : "SAFETENSORS" ,
249
251
"ready_to_fine_tune" : "true" ,
250
252
"aqua_custom_base_model" : "true" ,
251
253
}
@@ -261,36 +263,68 @@ def create(config_file = ""):
261
263
262
264
263
265
class TestAquaShapeRecommend :
264
-
265
- def test_which_gpu_valid (self , monkeypatch , ** kwargs ):
266
+ @pytest .mark .parametrize (
267
+ "config, expected_recs, expected_troubleshoot" ,
268
+ [
269
+ ( # decoder-only model
270
+ {
271
+ "num_hidden_layers" : 2 ,
272
+ "hidden_size" : 64 ,
273
+ "vocab_size" : 1000 ,
274
+ "num_attention_heads" : 4 ,
275
+ "head_dim" : 16 ,
276
+ "max_position_embeddings" : 2048 ,
277
+ },
278
+ [],
279
+ "" ,
280
+ ),
281
+ ( # encoder-decoder model
282
+ {
283
+ "num_hidden_layers" : 2 ,
284
+ "hidden_size" : 64 ,
285
+ "vocab_size" : 1000 ,
286
+ "num_attention_heads" : 4 ,
287
+ "head_dim" : 16 ,
288
+ "max_position_embeddings" : 2048 ,
289
+ "is_encoder_decoder" : True ,
290
+ },
291
+ [],
292
+ "Please provide a decoder-only text-generation model (ex. Llama, Falcon, etc). Encoder-decoder models (ex. T5, Gemma) and encoder-only (BERT) are not supported at this time." ,
293
+ ),
294
+ ],
295
+ )
296
+ def test_which_shapes_valid (
297
+ self , monkeypatch , config , expected_recs , expected_troubleshoot
298
+ ):
266
299
app = AquaShapeRecommend ()
267
300
mock_model = MockDataScienceModel .create ()
268
301
269
302
monkeypatch .setattr (
270
- "ads.aqua.app.DataScienceModel.from_id" ,
271
- lambda _ : mock_model
303
+ "ads.aqua.app.DataScienceModel.from_id" , lambda _ : mock_model
272
304
)
273
305
274
- config = {
275
- "num_hidden_layers" : 2 ,
276
- "hidden_size" : 64 ,
277
- "vocab_size" : 1000 ,
278
- "num_attention_heads" : 4 ,
279
- "head_dim" : 16 ,
280
- "max_position_embeddings" : 2048 ,
281
- }
282
-
306
+ expected_result = ShapeRecommendationReport (
307
+ recommendations = expected_recs , troubleshoot = expected_troubleshoot
308
+ )
283
309
app ._get_model_config = MagicMock (return_value = config )
284
310
app .valid_compute_shapes = MagicMock (return_value = [])
285
- app ._summarize_shapes_for_seq_lens = MagicMock (return_value = "mocked_report" )
311
+ app ._summarize_shapes_for_seq_lens = MagicMock (return_value = expected_result )
286
312
287
- request = RequestRecommend (model_id = "ocid1.datasciencemodel.oc1.TEST" )
313
+ request = RequestRecommend (
314
+ model_id = "ocid1.datasciencemodel.oc1.TEST" , generate_table = False
315
+ )
288
316
result = app .which_shapes (request )
317
+ assert result == expected_result
289
318
290
- app .valid_compute_shapes .assert_called_once ()
291
- llm_config = LLMConfig .from_raw_config (config )
292
- app ._summarize_shapes_for_seq_lens .assert_called_once_with (llm_config , [], "" )
293
- assert result == "mocked_report"
319
+ # If troubleshoot is populated (error case), _summarize_shapes_for_seq_lens should not have been called
320
+ if expected_troubleshoot :
321
+ app ._summarize_shapes_for_seq_lens .assert_not_called ()
322
+ else :
323
+ # For non-error case, summarize should have been called
324
+ llm_config = LLMConfig .from_raw_config (config )
325
+ app ._summarize_shapes_for_seq_lens .assert_called_once_with (
326
+ llm_config , [], ""
327
+ )
294
328
295
329
@pytest .mark .parametrize (
296
330
"config_file, result_file" ,
@@ -303,7 +337,9 @@ def test_which_gpu_valid(self, monkeypatch, **kwargs):
303
337
),
304
338
],
305
339
)
306
- def test_which_gpu_valid_from_file (self , monkeypatch , config_file , result_file , ** kwargs ):
340
+ def test_which_shapes_valid_from_file (
341
+ self , monkeypatch , config_file , result_file , ** kwargs
342
+ ):
307
343
raw = load_config (config_file )
308
344
app = AquaShapeRecommend ()
309
345
mock_model = MockDataScienceModel .create (config_file )
@@ -317,9 +353,14 @@ def test_which_gpu_valid_from_file(self, monkeypatch, config_file, result_file,
317
353
ComputeShapeSummary (name = name , shape_series = "GPU" , gpu_specs = spec )
318
354
for name , spec in shapes_index .shapes .items ()
319
355
]
320
- monkeypatch .setattr (app , "valid_compute_shapes" , lambda * args , ** kwargs : real_shapes )
356
+ monkeypatch .setattr (
357
+ app , "valid_compute_shapes" , lambda * args , ** kwargs : real_shapes
358
+ )
321
359
322
- result = app .which_gpu (model_ocid = "ocid1.datasciencemodel.oc1.TEST" )
360
+ request = RequestRecommend (
361
+ model_id = "ocid1.datasciencemodel.oc1.TEST" , generate_table = False
362
+ )
363
+ result = app .which_shapes (request = request )
323
364
324
365
expected_result = load_config (result_file )
325
366
assert result .model_dump () == expected_result
@@ -349,7 +390,7 @@ def test_shape_report_pareto_front(self):
349
390
model_size_gb = 1 , kv_cache_size_gb = 1 , total_model_gb = 2
350
391
),
351
392
deployment_params = DeploymentParams (
352
- quantization = "8bit" , max_model_len = 2048 , params = ""
393
+ quantization = "8bit" , max_model_len = 2048 , params = ""
353
394
),
354
395
recommendation = "ok" ,
355
396
)
@@ -363,7 +404,7 @@ def test_shape_report_pareto_front(self):
363
404
model_size_gb = 1 , kv_cache_size_gb = 1 , total_model_gb = 2
364
405
),
365
406
deployment_params = DeploymentParams (
366
- quantization = "8bit" , max_model_len = 2048 , params = ""
407
+ quantization = "8bit" , max_model_len = 2048 , params = ""
367
408
),
368
409
recommendation = "ok" ,
369
410
)
@@ -377,7 +418,7 @@ def test_shape_report_pareto_front(self):
377
418
model_size_gb = 1 , kv_cache_size_gb = 1 , total_model_gb = 2
378
419
),
379
420
deployment_params = DeploymentParams (
380
- quantization = "bfloat16" , max_model_len = 2048 , params = ""
421
+ quantization = "bfloat16" , max_model_len = 2048 , params = ""
381
422
),
382
423
recommendation = "ok" ,
383
424
)
0 commit comments