Skip to content
Merged
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"

[tool.poetry]
name = "together"
version = "1.5.8"
version = "1.5.9"
authors = ["Together AI <support@together.ai>"]
description = "Python client for Together's Cloud Platform!"
readme = "README.md"
Expand Down
25 changes: 18 additions & 7 deletions src/together/resources/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re
from pathlib import Path
from typing import Dict, List, Literal
from typing import List, Dict, Literal

from rich import print as rprint

Expand Down Expand Up @@ -545,7 +545,7 @@ def download(
*,
output: Path | str | None = None,
checkpoint_step: int | None = None,
checkpoint_type: DownloadCheckpointType = DownloadCheckpointType.DEFAULT,
checkpoint_type: DownloadCheckpointType | str = DownloadCheckpointType.DEFAULT,
) -> FinetuneDownloadResult:
"""
Downloads compressed fine-tuned model or checkpoint to local disk.
Expand All @@ -558,7 +558,7 @@ def download(
Defaults to None.
checkpoint_step (int, optional): Specifies step number for checkpoint to download.
Defaults to -1 (download the final model)
checkpoint_type (CheckpointType, optional): Specifies which checkpoint to download.
checkpoint_type (CheckpointType | str, optional): Specifies which checkpoint to download.
Defaults to CheckpointType.DEFAULT.

Returns:
Expand All @@ -582,6 +582,16 @@ def download(

ft_job = self.retrieve(id)

# convert str to DownloadCheckpointType
if isinstance(checkpoint_type, str):
try:
checkpoint_type = DownloadCheckpointType(checkpoint_type.lower())
except ValueError:
enum_strs = ", ".join(e.value for e in DownloadCheckpointType)
raise ValueError(
f"Invalid checkpoint type: {checkpoint_type}. Choose one of {{{enum_strs}}}."
)

if isinstance(ft_job.training_type, FullTrainingType):
if checkpoint_type != DownloadCheckpointType.DEFAULT:
raise ValueError(
Expand All @@ -592,10 +602,11 @@ def download(
if checkpoint_type == DownloadCheckpointType.DEFAULT:
checkpoint_type = DownloadCheckpointType.MERGED

if checkpoint_type == DownloadCheckpointType.MERGED:
url += f"&checkpoint={DownloadCheckpointType.MERGED.value}"
elif checkpoint_type == DownloadCheckpointType.ADAPTER:
url += f"&checkpoint={DownloadCheckpointType.ADAPTER.value}"
if checkpoint_type in {
DownloadCheckpointType.MERGED,
DownloadCheckpointType.ADAPTER,
}:
url += f"&checkpoint={checkpoint_type.value}"
else:
raise ValueError(
f"Invalid checkpoint type for LoRATrainingType: {checkpoint_type}"
Expand Down