Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ jobs:
fail-fast: false # we want to run all tests regardless of failure

steps:
# Huggingface trainer provider will download a model to train and save
# checkpoints, so need more space than other flows for training.
- name: Free disk space
if: ${{ matrix.test-type == 'post_training' }}
uses: jlumbroso/free-disk-space@54081f138730dfa15788a46383842cd2f914a1be # v1.3.1

- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

Expand Down Expand Up @@ -85,6 +91,7 @@ jobs:
echo "Ollama health check failed"
exit 1
fi

- name: Check Storage and Memory Available Before Tests
if: ${{ always() }}
run: |
Expand Down
87 changes: 51 additions & 36 deletions tests/integration/post_training/test_post_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
# the root directory of this source tree.

import logging
import os
import sys
import time
import uuid
from datetime import datetime, timezone

import pytest

from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.post_training import (
DataConfig,
LoraFinetuningConfig,
Expand Down Expand Up @@ -44,6 +47,15 @@ def capture_output(capsys):


class TestPostTraining:
job_uuid = f"test-job{uuid.uuid4()}"
model = "ibm-granite/granite-3.3-2b-instruct"

def _validate_checkpoints(self, checkpoints):
assert len(checkpoints) == 1
assert checkpoints[0]["identifier"] == f"{self.model}-sft-1"
assert checkpoints[0]["epoch"] == 1
assert "/.llama/checkpoints/merged_model" in checkpoints[0]["path"]

@pytest.mark.integration
@pytest.mark.parametrize(
"purpose, source",
Expand Down Expand Up @@ -92,60 +104,63 @@ def test_supervised_fine_tune(self, llama_stack_client, purpose, source):
gradient_accumulation_steps=1,
)

job_uuid = f"test-job{uuid.uuid4()}"
logger.info(f"Starting training job with UUID: {job_uuid}")
logger.info(f"Starting training job with UUID: {self.job_uuid}")

# train with HF trl SFTTrainer as the default
checkpoint_dir = os.path.expanduser("/mnt/")
# os.makedirs(checkpoint_dir, exist_ok=True)

started = datetime.now(timezone.utc)
_ = llama_stack_client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
model="ibm-granite/granite-3.3-2b-instruct",
job_uuid=self.job_uuid,
model=self.model,
algorithm_config=algorithm_config,
training_config=training_config,
hyperparam_search_config={},
logger_config={},
checkpoint_dir=None,
checkpoint_dir=checkpoint_dir,
)

while True:
status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
status = llama_stack_client.post_training.job.status(job_uuid=self.job_uuid)
if not status:
logger.error("Job not found")
break

logger.info(f"Current status: {status}")
if status.status == "completed":
completed = datetime.now(timezone.utc)
assert status.completed_at is not None
assert status.completed_at >= started
assert status.completed_at <= completed
break

logger.info("Waiting for job to complete...")
time.sleep(10) # Increased sleep time to reduce polling frequency

artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
logger.info(f"Job artifacts: {artifacts}")

# TODO: Fix these tests to properly represent the Jobs API in training
# @pytest.mark.asyncio
# async def test_get_training_jobs(self, post_training_stack):
# post_training_impl = post_training_stack
# jobs_list = await post_training_impl.get_training_jobs()
# assert isinstance(jobs_list, list)
# assert jobs_list[0].job_uuid == "1234"

# @pytest.mark.asyncio
# async def test_get_training_job_status(self, post_training_stack):
# post_training_impl = post_training_stack
# job_status = await post_training_impl.get_training_job_status("1234")
# assert isinstance(job_status, PostTrainingJobStatusResponse)
# assert job_status.job_uuid == "1234"
# assert job_status.status == JobStatus.completed
# assert isinstance(job_status.checkpoints[0], Checkpoint)

# @pytest.mark.asyncio
# async def test_get_training_job_artifacts(self, post_training_stack):
# post_training_impl = post_training_stack
# job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
# assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
# assert job_artifacts.job_uuid == "1234"
# assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
# assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab"
# assert job_artifacts.checkpoints[0].epoch == 0
# assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
def test_get_training_jobs(self, client_with_models):
    """List training jobs and verify the SFT job submitted earlier is present.

    Relies on ``test_supervised_fine_tune`` having run first in this class:
    the single expected job is keyed by the class-level ``job_uuid``.
    """
    # NOTE: dropped @pytest.mark.asyncio — this is a plain (non-async) test
    # function, so the marker was a no-op at best and an error under
    # pytest-asyncio's strict mode.
    jobs_list = client_with_models.post_training.job.list()
    assert len(jobs_list) == 1
    assert jobs_list[0].job_uuid == self.job_uuid

def test_get_training_job_status(self, client_with_models):
    """Check the completed SFT job's status, resources, checkpoints, and timestamps.

    Relies on ``test_supervised_fine_tune`` having run first in this class.
    """
    # NOTE: dropped @pytest.mark.asyncio — this is a plain (non-async) test
    # function, so the marker was a no-op at best and an error under
    # pytest-asyncio's strict mode.
    job_status = client_with_models.post_training.job.status(job_uuid=self.job_uuid)
    assert job_status.job_uuid == self.job_uuid
    assert job_status.status == JobStatus.completed.value
    assert isinstance(job_status.resources_allocated, dict)
    self._validate_checkpoints(job_status.checkpoints)

    # All lifecycle timestamps must be recorded ...
    assert job_status.scheduled_at is not None
    assert job_status.started_at is not None
    assert job_status.completed_at is not None

    # ... and must be monotonically ordered: scheduled -> started -> completed.
    assert job_status.scheduled_at <= job_status.started_at
    assert job_status.started_at <= job_status.completed_at

def test_get_training_job_artifacts(self, client_with_models):
    """Fetch the completed SFT job's artifacts and validate its checkpoints.

    Relies on ``test_supervised_fine_tune`` having run first in this class.
    """
    # NOTE: dropped @pytest.mark.asyncio — this is a plain (non-async) test
    # function, so the marker was a no-op at best and an error under
    # pytest-asyncio's strict mode.
    job_artifacts = client_with_models.post_training.job.artifacts(job_uuid=self.job_uuid)
    assert job_artifacts.job_uuid == self.job_uuid
    self._validate_checkpoints(job_artifacts.checkpoints)
Loading