| 5 | 5 | # the root directory of this source tree. | 
| 6 | 6 | 
 | 
| 7 | 7 | import logging | 
|  | 8 | +import os | 
| 8 | 9 | import sys | 
| 9 | 10 | import time | 
| 10 | 11 | import uuid | 
|  | 12 | +from datetime import datetime, timezone | 
| 11 | 13 | 
 | 
| 12 | 14 | import pytest | 
| 13 | 15 | 
 | 
|  | 16 | +from llama_stack.apis.common.job_types import JobStatus | 
| 14 | 17 | from llama_stack.apis.post_training import ( | 
| 15 | 18 |     DataConfig, | 
| 16 | 19 |     LoraFinetuningConfig, | 
| @@ -44,6 +47,15 @@ def capture_output(capsys): | 
| 44 | 47 | 
 | 
| 45 | 48 | 
 | 
| 46 | 49 | class TestPostTraining: | 
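|  |  | +    # class-level attributes so the job list/status/artifacts tests below | 
|  |  | +    # can reference the same fine-tuning job started in test_supervised_fine_tune | 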
|  | 50 | +    job_uuid = f"test-job{uuid.uuid4()}" | 
|  | 51 | +    model = "ibm-granite/granite-3.3-2b-instruct" | 
|  | 52 | + | 
|  | 53 | +    def _validate_checkpoints(self, checkpoints): | 
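|  |  | +        # the SFT run is expected to emit exactly one merged checkpoint | 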
|  | 54 | +        assert len(checkpoints) == 1 | 
|  | 55 | +        assert checkpoints[0]["identifier"] == f"{self.model}-sft-1" | 
|  | 56 | +        assert checkpoints[0]["epoch"] == 1 | 
|  | 57 | +        assert "/.llama/checkpoints/merged_model" in checkpoints[0]["path"] | 
|  | 58 | + | 
| 47 | 59 |     @pytest.mark.integration | 
| 48 | 60 |     @pytest.mark.parametrize( | 
| 49 | 61 |         "purpose, source", | 
| @@ -91,60 +103,62 @@ def test_supervised_fine_tune(self, llama_stack_client, purpose, source): | 
| 91 | 103 |             gradient_accumulation_steps=1, | 
| 92 | 104 |         ) | 
| 93 | 105 | 
 | 
| 94 |  | -        job_uuid = f"test-job{uuid.uuid4()}" | 
| 95 |  | -        logger.info(f"Starting training job with UUID: {job_uuid}") | 
|  | 106 | +        logger.info(f"Starting training job with UUID: {self.job_uuid}") | 
| 96 | 107 | 
 | 
| 97 | 108 |         # train with HF trl SFTTrainer as the default | 
|  | 109 | +        os.makedirs(os.path.expanduser("~/.llama/checkpoints/"), exist_ok=True) | 
|  | 110 | + | 
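|  |  | +        # capture a UTC timestamp before launching so completed_at can be sanity-checked | 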
|  | 111 | +        started = datetime.now(timezone.utc) | 
| 98 | 112 |         _ = llama_stack_client.post_training.supervised_fine_tune( | 
| 99 |  | -            job_uuid=job_uuid, | 
| 100 |  | -            model="ibm-granite/granite-3.3-2b-instruct", | 
|  | 113 | +            job_uuid=self.job_uuid, | 
|  | 114 | +            model=self.model, | 
| 101 | 115 |             algorithm_config=algorithm_config, | 
| 102 | 116 |             training_config=training_config, | 
| 103 | 117 |             hyperparam_search_config={}, | 
| 104 | 118 |             logger_config={}, | 
| 105 |  | -            checkpoint_dir=None, | 
|  | 119 | +            checkpoint_dir=os.path.expanduser("~/.llama/checkpoints/"), | 
| 106 | 120 |         ) | 
| 107 | 121 | 
 | 
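|  |  | +        # poll the job status until it reaches a terminal state | 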
| 108 | 122 |         while True: | 
| 109 |  | -            status = llama_stack_client.post_training.job.status(job_uuid=job_uuid) | 
|  | 123 | +            status = llama_stack_client.post_training.job.status(job_uuid=self.job_uuid) | 
| 110 | 124 |             if not status: | 
| 111 | 125 |                 logger.error("Job not found") | 
| 112 | 126 |                 break | 
| 113 | 127 | 
 | 
| 114 | 128 |             logger.info(f"Current status: {status}") | 
| 115 | 129 |             if status.status == "completed": | 
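|  |  | +                # completed_at should fall between our recorded start time and now | 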
|  | 130 | +                completed = datetime.now(timezone.utc) | 
|  | 131 | +                assert status.completed_at is not None | 
|  | 132 | +                assert status.completed_at >= started | 
|  | 133 | +                assert status.completed_at <= completed | 
| 116 | 134 |                 break | 
| 117 | 135 | 
 | 
| 118 | 136 |             logger.info("Waiting for job to complete...") | 
| 119 | 137 |             time.sleep(10)  # Increased sleep time to reduce polling frequency | 
| 120 | 138 | 
 | 
| 121 |  | -        artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid) | 
| 122 |  | -        logger.info(f"Job artifacts: {artifacts}") | 
| 123 |  | - | 
| 124 |  | -    # TODO: Fix these tests to properly represent the Jobs API in training | 
| 125 |  | -    # @pytest.mark.asyncio | 
| 126 |  | -    # async def test_get_training_jobs(self, post_training_stack): | 
| 127 |  | -    #     post_training_impl = post_training_stack | 
| 128 |  | -    #     jobs_list = await post_training_impl.get_training_jobs() | 
| 129 |  | -    #     assert isinstance(jobs_list, list) | 
| 130 |  | -    #     assert jobs_list[0].job_uuid == "1234" | 
| 131 |  | - | 
| 132 |  | -    # @pytest.mark.asyncio | 
| 133 |  | -    # async def test_get_training_job_status(self, post_training_stack): | 
| 134 |  | -    #     post_training_impl = post_training_stack | 
| 135 |  | -    #     job_status = await post_training_impl.get_training_job_status("1234") | 
| 136 |  | -    #     assert isinstance(job_status, PostTrainingJobStatusResponse) | 
| 137 |  | -    #     assert job_status.job_uuid == "1234" | 
| 138 |  | -    #     assert job_status.status == JobStatus.completed | 
| 139 |  | -    #     assert isinstance(job_status.checkpoints[0], Checkpoint) | 
| 140 |  | - | 
| 141 |  | -    # @pytest.mark.asyncio | 
| 142 |  | -    # async def test_get_training_job_artifacts(self, post_training_stack): | 
| 143 |  | -    #     post_training_impl = post_training_stack | 
| 144 |  | -    #     job_artifacts = await post_training_impl.get_training_job_artifacts("1234") | 
| 145 |  | -    #     assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse) | 
| 146 |  | -    #     assert job_artifacts.job_uuid == "1234" | 
| 147 |  | -    #     assert isinstance(job_artifacts.checkpoints[0], Checkpoint) | 
| 148 |  | -    #     assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab" | 
| 149 |  | -    #     assert job_artifacts.checkpoints[0].epoch == 0 | 
| 150 |  | -    # assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path | 
|  | 139 | +    @pytest.mark.integration | 
|  | 140 | +    def test_get_training_jobs(self, client_with_models): | 
|  | 141 | +        jobs_list = client_with_models.post_training.job.list() | 
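|  |  | +        # only the job launched in test_supervised_fine_tune should be listed | 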
|  | 142 | +        assert len(jobs_list) == 1 | 
|  | 143 | +        assert jobs_list[0].job_uuid == self.job_uuid | 
|  | 144 | + | 
|  | 145 | +    @pytest.mark.integration | 
|  | 146 | +    def test_get_training_job_status(self, client_with_models): | 
|  | 147 | +        job_status = client_with_models.post_training.job.status(job_uuid=self.job_uuid) | 
|  | 148 | +        assert job_status.job_uuid == self.job_uuid | 
|  | 149 | +        assert job_status.status == JobStatus.completed.value | 
|  | 150 | +        assert isinstance(job_status.resources_allocated, dict) | 
|  | 151 | +        self._validate_checkpoints(job_status.checkpoints) | 
|  | 152 | + | 
|  | 153 | +        assert job_status.scheduled_at is not None | 
|  | 154 | +        assert job_status.started_at is not None | 
|  | 155 | +        assert job_status.completed_at is not None | 
|  | 156 | + | 
|  | 157 | +        assert job_status.scheduled_at <= job_status.started_at | 
|  | 158 | +        assert job_status.started_at <= job_status.completed_at | 
|  | 159 | + | 
|  | 160 | +    @pytest.mark.integration | 
|  | 161 | +    def test_get_training_job_artifacts(self, client_with_models): | 
|  | 162 | +        job_artifacts = client_with_models.post_training.job.artifacts(job_uuid=self.job_uuid) | 
|  | 163 | +        assert job_artifacts.job_uuid == self.job_uuid | 
|  | 164 | +        self._validate_checkpoints(job_artifacts.checkpoints) | 