Skip to content

Accept strings for checkpoint type on download #308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions src/together/resources/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re
from pathlib import Path
from typing import List, Literal
from typing import List, Literal, Union

from rich import print as rprint

Expand Down Expand Up @@ -570,7 +570,9 @@ def download(
*,
output: Path | str | None = None,
checkpoint_step: int | None = None,
checkpoint_type: DownloadCheckpointType = DownloadCheckpointType.DEFAULT,
checkpoint_type: Union[
DownloadCheckpointType, str
] = DownloadCheckpointType.DEFAULT,
) -> FinetuneDownloadResult:
"""
Downloads compressed fine-tuned model or checkpoint to local disk.
Expand All @@ -583,7 +585,7 @@ def download(
Defaults to None.
checkpoint_step (int, optional): Specifies step number for checkpoint to download.
Defaults to -1 (download the final model)
checkpoint_type (CheckpointType, optional): Specifies which checkpoint to download.
checkpoint_type (Union[CheckpointType, str], optional): Specifies which checkpoint to download.
Defaults to CheckpointType.DEFAULT.

Returns:
Expand All @@ -607,6 +609,16 @@ def download(

ft_job = self.retrieve(id)

# convert str to DownloadCheckpointType
if isinstance(checkpoint_type, str):
try:
checkpoint_type = DownloadCheckpointType(checkpoint_type.lower())
except ValueError:
enum_strs = ", ".join([e.value for e in DownloadCheckpointType])
raise ValueError(
f"Invalid checkpoint type: {checkpoint_type}. Choose one of {{{enum_strs}}}."
)

if isinstance(ft_job.training_type, FullTrainingType):
if checkpoint_type != DownloadCheckpointType.DEFAULT:
raise ValueError(
Expand All @@ -617,10 +629,11 @@ def download(
if checkpoint_type == DownloadCheckpointType.DEFAULT:
checkpoint_type = DownloadCheckpointType.MERGED

if checkpoint_type == DownloadCheckpointType.MERGED:
url += f"&checkpoint={DownloadCheckpointType.MERGED.value}"
elif checkpoint_type == DownloadCheckpointType.ADAPTER:
url += f"&checkpoint={DownloadCheckpointType.ADAPTER.value}"
if checkpoint_type in {
DownloadCheckpointType.MERGED,
DownloadCheckpointType.ADAPTER,
}:
url += f"&checkpoint={checkpoint_type.value}"
else:
raise ValueError(
f"Invalid checkpoint type for LoRATrainingType: {checkpoint_type}"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this error will never be reached because of the check in lines 614:620, so we can delete this.

Expand Down