
Commit 667a05f (parent: b64c6f9)

fix(tests): enable post-training tests

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>

1 file changed: 50 additions, 36 deletions

tests/integration/post_training/test_post_training.py
@@ -5,12 +5,15 @@
 # the root directory of this source tree.
 
 import logging
+import os
 import sys
 import time
 import uuid
+from datetime import datetime, timezone
 
 import pytest
 
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.post_training import (
     DataConfig,
     LoraFinetuningConfig,
@@ -44,6 +47,15 @@ def capture_output(capsys):
 
 
 class TestPostTraining:
+    job_uuid = f"test-job{uuid.uuid4()}"
+    model = "ibm-granite/granite-3.3-2b-instruct"
+
+    def _validate_checkpoints(self, checkpoints):
+        assert len(checkpoints) == 1
+        assert checkpoints[0]["identifier"] == f"{self.model}-sft-1"
+        assert checkpoints[0]["epoch"] == 1
+        assert "/.llama/checkpoints/merged_model" in checkpoints[0]["path"]
+
     @pytest.mark.integration
     @pytest.mark.parametrize(
         "purpose, source",
@@ -91,60 +103,62 @@ def test_supervised_fine_tune(self, llama_stack_client, purpose, source):
             gradient_accumulation_steps=1,
         )
 
-        job_uuid = f"test-job{uuid.uuid4()}"
-        logger.info(f"Starting training job with UUID: {job_uuid}")
+        logger.info(f"Starting training job with UUID: {self.job_uuid}")
 
         # train with HF trl SFTTrainer as the default
+        os.makedirs("~/.llama/checkpoints/", exist_ok=True)
+
+        started = datetime.now(timezone.utc)
         _ = llama_stack_client.post_training.supervised_fine_tune(
-            job_uuid=job_uuid,
-            model="ibm-granite/granite-3.3-2b-instruct",
+            job_uuid=self.job_uuid,
+            model=self.model,
             algorithm_config=algorithm_config,
             training_config=training_config,
             hyperparam_search_config={},
             logger_config={},
-            checkpoint_dir=None,
+            checkpoint_dir="~/.llama/checkpoints/",
         )
 
         while True:
-            status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
+            status = llama_stack_client.post_training.job.status(job_uuid=self.job_uuid)
             if not status:
                 logger.error("Job not found")
                 break
 
             logger.info(f"Current status: {status}")
             if status.status == "completed":
+                completed = datetime.now(timezone.utc)
+                assert status.completed_at is not None
+                assert status.completed_at >= started
+                assert status.completed_at <= completed
                 break
 
             logger.info("Waiting for job to complete...")
             time.sleep(10)  # Increased sleep time to reduce polling frequency
 
-        artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
-        logger.info(f"Job artifacts: {artifacts}")
-
-    # TODO: Fix these tests to properly represent the Jobs API in training
-    # @pytest.mark.asyncio
-    # async def test_get_training_jobs(self, post_training_stack):
-    #     post_training_impl = post_training_stack
-    #     jobs_list = await post_training_impl.get_training_jobs()
-    #     assert isinstance(jobs_list, list)
-    #     assert jobs_list[0].job_uuid == "1234"
-
-    # @pytest.mark.asyncio
-    # async def test_get_training_job_status(self, post_training_stack):
-    #     post_training_impl = post_training_stack
-    #     job_status = await post_training_impl.get_training_job_status("1234")
-    #     assert isinstance(job_status, PostTrainingJobStatusResponse)
-    #     assert job_status.job_uuid == "1234"
-    #     assert job_status.status == JobStatus.completed
-    #     assert isinstance(job_status.checkpoints[0], Checkpoint)
-
-    # @pytest.mark.asyncio
-    # async def test_get_training_job_artifacts(self, post_training_stack):
-    #     post_training_impl = post_training_stack
-    #     job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
-    #     assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
-    #     assert job_artifacts.job_uuid == "1234"
-    #     assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
-    #     assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab"
-    #     assert job_artifacts.checkpoints[0].epoch == 0
-    #     assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
+    @pytest.mark.asyncio
+    def test_get_training_jobs(self, client_with_models):
+        jobs_list = client_with_models.post_training.job.list()
+        assert len(jobs_list) == 1
+        assert jobs_list[0].job_uuid == self.job_uuid
+
+    @pytest.mark.asyncio
+    def test_get_training_job_status(self, client_with_models):
+        job_status = client_with_models.post_training.job.status(job_uuid=self.job_uuid)
+        assert job_status.job_uuid == self.job_uuid
+        assert job_status.status == JobStatus.completed.value
+        assert isinstance(job_status.resources_allocated, dict)
+        self._validate_checkpoints(job_status.checkpoints)
+
+        assert job_status.scheduled_at is not None
+        assert job_status.started_at is not None
+        assert job_status.completed_at is not None
+
+        assert job_status.scheduled_at <= job_status.started_at
+        assert job_status.started_at <= job_status.completed_at
+
+    @pytest.mark.asyncio
+    def test_get_training_job_artifacts(self, client_with_models):
+        job_artifacts = client_with_models.post_training.job.artifacts(job_uuid=self.job_uuid)
+        assert job_artifacts.job_uuid == self.job_uuid
+        self._validate_checkpoints(job_artifacts.checkpoints)
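
Note on the re-enabled Jobs API tests: test_get_training_jobs, test_get_training_job_status, and test_get_training_job_artifacts all read the class-level self.job_uuid, so they implicitly assume test_supervised_fine_tune has already run in the same session and created that job. A minimal standalone sketch of the polling pattern the fine-tune test uses, built only from the client calls visible in this diff (the wait_for_completed helper and its timeout are illustrative additions, not part of the test file):

    import time

    def wait_for_completed(client, job_uuid, poll_interval=10, timeout=3600):
        # Poll the post-training Jobs API until the job reports "completed".
        # `client` is assumed to expose post_training.job.status() exactly as
        # llama_stack_client does in test_supervised_fine_tune above.
        deadline = time.time() + timeout
        while time.time() < deadline:
            status = client.post_training.job.status(job_uuid=job_uuid)
            if not status:
                raise RuntimeError(f"Job {job_uuid} not found")
            if status.status == "completed":
                return status
            time.sleep(poll_interval)  # throttle polling, as the test does
        raise TimeoutError(f"Job {job_uuid} did not complete within {timeout}s")

One behavior worth knowing when reading the diff: os.makedirs("~/.llama/checkpoints/", exist_ok=True) does not expand the tilde, so it creates a literal directory named "~" under the current working directory; expanding first with os.path.expanduser("~/.llama/checkpoints/") would target the home directory instead.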
