diff --git a/comps/cores/proto/api_protocol.py b/comps/cores/proto/api_protocol.py
index 2b2481067..54a99cfdb 100644
--- a/comps/cores/proto/api_protocol.py
+++ b/comps/cores/proto/api_protocol.py
@@ -797,3 +797,42 @@ class FileObject(BaseModel):
 
     Supported values are assistants, assistants_output, batch, batch_output, fine-tune, fine-tune-results and vision.
     """
+
+
+class Metrics(BaseModel):
+    full_valid_loss: Optional[float] = None
+
+    full_valid_mean_token_accuracy: Optional[float] = None
+
+    step: Optional[float] = None
+
+    train_loss: Optional[float] = None
+
+    train_mean_token_accuracy: Optional[float] = None
+
+    valid_loss: Optional[float] = None
+
+    valid_mean_token_accuracy: Optional[float] = None
+
+
+class FineTuningJobCheckpoint(BaseModel):
+    id: str
+    """The checkpoint identifier, which can be referenced in the API endpoints."""
+
+    created_at: int
+    """The Unix timestamp (in seconds) for when the checkpoint was created."""
+
+    fine_tuned_model_checkpoint: str
+    """The name of the fine-tuned checkpoint model that is created."""
+
+    fine_tuning_job_id: str
+    """The name of the fine-tuning job that this checkpoint was created from."""
+
+    metrics: Optional[Metrics] = None
+    """Metrics at the step number during the fine-tuning job."""
+
+    object: Literal["fine_tuning.job.checkpoint"]
+    """The object type, which is always "fine_tuning.job.checkpoint"."""
+
+    step_number: Optional[int] = None
+    """The step number that the checkpoint was created at."""
diff --git a/comps/finetuning/finetuning_service.py b/comps/finetuning/finetuning_service.py
index 958b08acb..64097c720 100644
--- a/comps/finetuning/finetuning_service.py
+++ b/comps/finetuning/finetuning_service.py
@@ -60,7 +60,7 @@ async def upload_training_files(request: UploadFileRequest = Depends(upload_file
 )
 def list_checkpoints(request: FineTuningJobIDRequest):
     checkpoints = handle_list_finetuning_checkpoints(request)
-    return {"status": 200, "checkpoints": str(checkpoints)}
+    return checkpoints
 
 
 if __name__ == "__main__":
diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py
index b97231485..ddae0726f 100644
--- a/comps/finetuning/handlers.py
+++ b/comps/finetuning/handlers.py
@@ -17,6 +17,7 @@ from comps.cores.proto.api_protocol import (
     FileObject,
     FineTuningJob,
+    FineTuningJobCheckpoint,
     FineTuningJobIDRequest,
     FineTuningJobList,
     FineTuningJobsRequest,
@@ -38,6 +39,9 @@ os.mkdir(OUTPUT_DIR)
 
 
 FineTuningJobID = str
+CheckpointID = str
+CheckpointPath = str
+
 CHECK_JOB_STATUS_INTERVAL = 5  # Check every 5 secs
 
 global ray_client
@@ -45,6 +49,7 @@
 running_finetuning_jobs: Dict[FineTuningJobID, FineTuningJob] = {}
 finetuning_job_to_ray_job: Dict[FineTuningJobID, str] = {}
+checkpoint_id_to_checkpoint_path: Dict[CheckpointID, CheckpointPath] = {}
 
 
 # Add a background task to periodicly update job status
 
@@ -117,8 +122,6 @@ def handle_create_finetuning_jobs(request: FineTuningParams, background_tasks: B
     ray_job_id = ray_client.submit_job(
         # Entrypoint shell command to execute
         entrypoint=f"python finetune_runner.py --config_file {finetune_config_file}",
-        # Path to the local directory that contains the script.py file
-        runtime_env={"working_dir": "./", "excludes": [f"{OUTPUT_DIR}"]},
     )
     logger.info(f"Submitted Ray job: {ray_job_id} ...")
 
@@ -183,10 +186,21 @@ def handle_list_finetuning_checkpoints(request: FineTuningJobIDRequest):
     fine_tuning_job_id = request.fine_tuning_job_id
 
     job = running_finetuning_jobs.get(fine_tuning_job_id)
     if job is None:
         raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
-    output_dir = os.path.join(JOBS_PATH, job.id)
+    output_dir = os.path.join(OUTPUT_DIR, job.id)
     checkpoints = []
     if os.path.exists(output_dir):
-        checkpoints = os.listdir(output_dir)
+        # Iterate over the contents of the directory and add an entry for each
+        for _ in os.listdir(output_dir):  # Loop over directory contents
+            checkpointsResponse = FineTuningJobCheckpoint(
+                id=f"ftckpt-{uuid.uuid4()}",  # Generate a unique ID
+                created_at=int(time.time()),  # Use the current timestamp
+                fine_tuned_model_checkpoint=output_dir,  # Directory path itself
+                fine_tuning_job_id=fine_tuning_job_id,
+                object="fine_tuning.job.checkpoint",
+            )
+            checkpoints.append(checkpointsResponse)
+            checkpoint_id_to_checkpoint_path[checkpointsResponse.id] = checkpointsResponse.fine_tuned_model_checkpoint
+
     return checkpoints