From 0ea183f09020d623a4113238a653b5f9cd07a411 Mon Sep 17 00:00:00 2001 From: connermanuel Date: Thu, 17 Apr 2025 11:18:05 -0700 Subject: [PATCH 1/3] migrate to sft_on_inputs, and change defaults to match --- src/together/cli/api/finetune.py | 20 ++++++++++---------- src/together/resources/finetune.py | 26 ++++++++++++++++++-------- src/together/types/finetune.py | 5 ++--- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/src/together/cli/api/finetune.py b/src/together/cli/api/finetune.py index 751fe6a..4d57b73 100644 --- a/src/together/cli/api/finetune.py +++ b/src/together/cli/api/finetune.py @@ -1,10 +1,10 @@ from __future__ import annotations import json +import re from datetime import datetime, timezone from textwrap import wrap from typing import Any, Literal -import re import click from click.core import ParameterSource # type: ignore[attr-defined] @@ -13,17 +13,17 @@ from together import Together from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX +from together.types.finetune import ( + DownloadCheckpointType, + FinetuneEventType, + FinetuneTrainingLimits, +) from together.utils import ( finetune_price_to_dollars, + format_timestamp, log_warn, log_warn_once, parse_timestamp, - format_timestamp, -) -from together.types.finetune import ( - DownloadCheckpointType, - FinetuneTrainingLimits, - FinetuneEventType, ) @@ -348,9 +348,9 @@ def list(ctx: click.Context) -> None: "Model Output Name": "\n".join(wrap(i.output_name or "", width=30)), "Status": i.status, "Created At": i.created_at, - "Price": f"""${finetune_price_to_dollars( - float(str(i.total_price)) - )}""", # convert to string for mypy typing + "Price": f"""${ + finetune_price_to_dollars(float(str(i.total_price))) + }""", # convert to string for mypy typing } ) table = tabulate(display_list, headers="keys", tablefmt="grid", showindex=True) diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py index 275d683..132d46e 100644 --- a/src/together/resources/finetune.py +++ b/src/together/resources/finetune.py @@ -77,7 +77,7 @@ def create_finetune_request( wandb_base_url: str | None = None, wandb_project_name: str | None = None, wandb_name: str | None = None, - train_on_inputs: bool | Literal["auto"] = "auto", + train_on_inputs: bool | Literal["auto"] | None = None, training_method: str = "sft", dpo_beta: float | None = None, from_checkpoint: str | None = None, @@ -174,6 +174,15 @@ def create_finetune_request( f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}" ) + if train_on_inputs is not None and training_method != "sft": + raise ValueError("train_on_inputs is only supported for SFT training") + + if train_on_inputs is None and training_method == "sft": + log_warn_once( + "train_on_inputs is not set for SFT training, it will be set to 'auto' automatically" + ) + train_on_inputs = "auto" + lr_scheduler: FinetuneLRScheduler if lr_scheduler_type == "cosine": if scheduler_num_cycles <= 0.0: @@ -191,7 +200,9 @@ def create_finetune_request( lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio), ) - training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT() + training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT( + train_on_inputs=train_on_inputs + ) if training_method == "dpo": training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta) @@ -214,7 +225,6 @@ def create_finetune_request( wandb_base_url=wandb_base_url, wandb_project_name=wandb_project_name, wandb_name=wandb_name, - train_on_inputs=train_on_inputs, training_method=training_method_cls, from_checkpoint=from_checkpoint, ) @@ -319,7 +329,7 @@ def create( wandb_name: str | None = None, verbose: bool = False, model_limits: FinetuneTrainingLimits | None = None, - train_on_inputs: bool | Literal["auto"] = "auto", + train_on_inputs: bool | Literal["auto"] | None = None, training_method: str = "sft", dpo_beta: float | None = None, from_checkpoint: str | None = None, @@ -364,12 +374,12 @@ def create( Defaults to False. model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning. Defaults to None. - train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data. + train_on_inputs (bool or "auto", optional): Whether to mask the user messages in conversational data or prompts in instruction data. "auto" will automatically determine whether to mask the inputs based on the data format. For datasets with the "text" field (general format), inputs will not be masked. For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields (Instruction format), inputs will be masked. - Defaults to "auto". + Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request). training_method (str, optional): Training method. Defaults to "sft". Supported methods: "sft", "dpo". dpo_beta (float, optional): DPO beta parameter. Defaults to None. @@ -707,7 +717,7 @@ async def create( wandb_name: str | None = None, verbose: bool = False, model_limits: FinetuneTrainingLimits | None = None, - train_on_inputs: bool | Literal["auto"] = "auto", + train_on_inputs: bool | Literal["auto"] | None = None, training_method: str = "sft", dpo_beta: float | None = None, from_checkpoint: str | None = None, @@ -757,7 +767,7 @@ async def create( For datasets with the "text" field (general format), inputs will not be masked. For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields (Instruction format), inputs will be masked. - Defaults to "auto". + Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request). training_method (str, optional): Training method. Defaults to "sft". Supported methods: "sft", "dpo". dpo_beta (float, optional): DPO beta parameter. Defaults to None. diff --git a/src/together/types/finetune.py b/src/together/types/finetune.py index 5c2c2c2..07ee65e 100644 --- a/src/together/types/finetune.py +++ b/src/together/types/finetune.py @@ -3,7 +3,7 @@ from enum import Enum from typing import List, Literal, Any -from pydantic import StrictBool, Field, field_validator +from pydantic import Field, StrictBool, field_validator from together.types.abstract import BaseModel from together.types.common import ( @@ -149,6 +149,7 @@ class TrainingMethodSFT(TrainingMethod): """ method: Literal["sft"] = "sft" + train_on_inputs: StrictBool | Literal["auto"] = "auto" class TrainingMethodDPO(TrainingMethod): @@ -201,8 +202,6 @@ class FinetuneRequest(BaseModel): wandb_name: str | None = None # training type training_type: FullTrainingType | LoRATrainingType | None = None - # train on inputs - train_on_inputs: StrictBool | Literal["auto"] = "auto" # training method training_method: TrainingMethodSFT | TrainingMethodDPO = Field( default_factory=TrainingMethodSFT From 2cbd4c025e073d32726df24353d1347c1850c4d3 Mon Sep 17 00:00:00 2001 From: connermanuel Date: Thu, 17 Apr 2025 11:23:05 -0700 Subject: [PATCH 2/3] add validation to dpo_beta --- src/together/resources/finetune.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py index 132d46e..f6e124f 100644 --- a/src/together/resources/finetune.py +++ b/src/together/resources/finetune.py @@ -183,6 +183,9 @@ def create_finetune_request( ) train_on_inputs = "auto" + if dpo_beta is not None and training_method != "dpo": + raise ValueError("dpo_beta is only supported for DPO training") + lr_scheduler: FinetuneLRScheduler if lr_scheduler_type == "cosine": if scheduler_num_cycles <= 0.0: From 1ce13d65fb3f0af94ae13f79e8960dd39c9db11d Mon Sep 17 00:00:00 2001 From: connermanuel Date: Mon, 21 Apr 2025 16:06:09 -0700 Subject: [PATCH 3/3] tests --- src/together/resources/finetune.py | 8 ++++---- tests/unit/test_finetune_resources.py | 29 +++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py index f6e124f..29f1005 100644 --- a/src/together/resources/finetune.py +++ b/src/together/resources/finetune.py @@ -203,10 +203,10 @@ def create_finetune_request( lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio), ) - training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT( - train_on_inputs=train_on_inputs - ) - if training_method == "dpo": + training_method_cls: TrainingMethodSFT | TrainingMethodDPO + if training_method == "sft": + training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs) + elif training_method == "dpo": training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta) finetune_request = FinetuneRequest( diff --git a/tests/unit/test_finetune_resources.py b/tests/unit/test_finetune_resources.py index 7ff6941..f354b2f 100644 --- a/tests/unit/test_finetune_resources.py +++ b/tests/unit/test_finetune_resources.py @@ -281,3 +281,32 @@ def test_bad_training_method(): training_file=_TRAINING_FILE, training_method="NON_SFT", ) + + +@pytest.mark.parametrize("train_on_inputs", [True, False, "auto", None]) +def test_train_on_inputs_for_sft(train_on_inputs): + request = create_finetune_request( + model_limits=_MODEL_LIMITS, + model=_MODEL_NAME, + training_file=_TRAINING_FILE, + training_method="sft", + train_on_inputs=train_on_inputs, + ) + assert request.training_method.method == "sft" + if isinstance(train_on_inputs, bool): + assert request.training_method.train_on_inputs is train_on_inputs + else: + assert request.training_method.train_on_inputs == "auto" + + +def test_train_on_inputs_not_supported_for_dpo(): + with pytest.raises( + ValueError, match="train_on_inputs is only supported for SFT training" + ): + _ = create_finetune_request( + model_limits=_MODEL_LIMITS, + model=_MODEL_NAME, + training_file=_TRAINING_FILE, + training_method="dpo", + train_on_inputs=True, + )