 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import List
+import base64
+import mimetypes
+import os
 
 import pytest
 
 from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.post_training import (
-    Checkpoint,
     DataConfig,
     LoraFinetuningConfig,
     OptimizerConfig,
-    PostTrainingJob,
-    PostTrainingJobArtifactsResponse,
-    PostTrainingJobStatusResponse,
     TrainingConfig,
 )
 
 # -v -s --tb=short --disable-warnings
 
 
-@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API")
+def data_url_from_file(file_path: str) -> str:
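+    """Read a local file and return its contents as a base64-encoded data: URL."""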
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    with open(file_path, "rb") as file:
+        file_content = file.read()
+
+    base64_content = base64.b64encode(file_content).decode("utf-8")
+    mime_type, _ = mimetypes.guess_type(file_path)
+
+    data_url = f"data:{mime_type};base64,{base64_content}"
+
+    return data_url
+
+
 class TestPostTraining:
-    @pytest.mark.asyncio
-    async def test_supervised_fine_tune(self, post_training_stack):
+    def test_supervised_fine_tune(self, client_with_models):
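+        # Register the test CSV as a dataset via an inline base64 data: URL.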
+        dataset = client_with_models.datasets.register(
+            purpose="post-training/messages",
+            source={
+                "type": "uri",
+                "uri": data_url_from_file(
+                    os.path.join(os.path.dirname(__file__), "../datasets/test_dataset.csv")
+                ),
+            },
+        )
+
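+        # Minimal LoRA adapter (rank/alpha = 1) keeps this smoke test cheap.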
         algorithm_config = LoraFinetuningConfig(
             type="LoRA",
             lora_attn_modules=["q_proj", "v_proj", "output_proj"],
             apply_lora_to_mlp=True,
             apply_lora_to_output=False,
-            rank=8,
-            alpha=16,
+            rank=1,
+            alpha=1,
         )
 
         data_config = DataConfig(
-            dataset_id="alpaca",
+            dataset_id=dataset.identifier,
+            data_format="instruct",
             batch_size=1,
             shuffle=False,
         )
@@ -50,18 +74,19 @@ async def test_supervised_fine_tune(self, post_training_stack):
             lr=3e-4,
             lr_min=3e-5,
             weight_decay=0.1,
-            num_warmup_steps=100,
+            num_warmup_steps=1,
         )
 
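+        # One epoch with a single train/validation step: exercises the API end to end.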
         training_config = TrainingConfig(
             n_epochs=1,
             data_config=data_config,
             optimizer_config=optimizer_config,
             max_steps_per_epoch=1,
+            max_validation_steps=1,
             gradient_accumulation_steps=1,
+            dtype="fp32",
         )
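+        # Launch the fine-tuning job through the client-side post_training API.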
-        post_training_impl = post_training_stack
-        response = await post_training_impl.supervised_fine_tune(
+        job = client_with_models.post_training.supervised_fine_tune(
             job_uuid="1234",
             model="Llama3.2-3B-Instruct",
             algorithm_config=algorithm_config,
@@ -70,32 +95,24 @@ async def test_supervised_fine_tune(self, post_training_stack):
             logger_config={},
             checkpoint_dir="null",
         )
-        assert isinstance(response, PostTrainingJob)
-        assert response.job_uuid == "1234"
+        assert job.job_uuid == "1234"
 
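+    # The job tests below query job "1234" created by test_supervised_fine_tune.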
-    @pytest.mark.asyncio
-    async def test_get_training_jobs(self, post_training_stack):
-        post_training_impl = post_training_stack
-        jobs_list = await post_training_impl.get_training_jobs()
-        assert isinstance(jobs_list, List)
+    def test_get_training_jobs(self, client_with_models):
+        jobs_list = client_with_models.post_training.job.list()
+        assert len(jobs_list) == 1
         assert jobs_list[0].job_uuid == "1234"
 
-    @pytest.mark.asyncio
-    async def test_get_training_job_status(self, post_training_stack):
-        post_training_impl = post_training_stack
-        job_status = await post_training_impl.get_training_job_status("1234")
-        assert isinstance(job_status, PostTrainingJobStatusResponse)
+    def test_get_training_job_status(self, client_with_models):
+        job_status = client_with_models.post_training.job.status(job_uuid="1234")
         assert job_status.job_uuid == "1234"
-        assert job_status.status == JobStatus.completed
-        assert isinstance(job_status.checkpoints[0], Checkpoint)
+        assert job_status.status == JobStatus.completed.value
 
-    @pytest.mark.asyncio
-    async def test_get_training_job_artifacts(self, post_training_stack):
-        post_training_impl = post_training_stack
-        job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
-        assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
+    def test_get_training_job_artifacts(self, client_with_models):
+        job_artifacts = client_with_models.post_training.job.artifacts(job_uuid="1234")
         assert job_artifacts.job_uuid == "1234"
-        assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
-        assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
-        assert job_artifacts.checkpoints[0].epoch == 0
-        assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
+        assert job_artifacts.checkpoints[0]["identifier"] == "Llama3.2-3B-Instruct-sft-0"
+        assert job_artifacts.checkpoints[0]["epoch"] == 0
+        assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0]["path"]