Add instruction and conversational data support (#211)
* Add format checks

* add tests

* add train on inputs flag

* style

* PR feedback

* style

* more tests

* enhance logic

* enhance logic

* pr feedback part 1

* style and fixed

* pr feedback

* style

* style

* fix typing

* change to strict boolean

* error out on train_on_inputs

* use "auto" directly

* add system message

* version bump

* Update src/together/cli/api/finetune.py

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>

---------

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Authored by artek0chumak and mryab on Nov 14, 2024
1 parent 296f2a5 · commit e157fcd
Showing 8 changed files with 509 additions and 47 deletions.
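
For context, here is a rough sketch of the two newly supported JSONL record shapes, using the field names introduced in src/together/constants.py below; the contents are illustrative, and a real training file sticks to a single format.

import json

# Instruction format: one prompt/completion pair per line.
instruction_record = {"prompt": "Translate to French: cheese", "completion": "fromage"}

# Conversational format: a list of role/content messages per line.
conversation_record = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]
}

print(json.dumps(instruction_record))
print(json.dumps(conversation_record))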
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"

[tool.poetry]
name = "together"
version = "1.3.3"
version = "1.3.4"
authors = [
"Together AI <support@together.ai>"
]
39 changes: 21 additions & 18 deletions src/together/cli/api/finetune.py
@@ -11,8 +11,13 @@
from tabulate import tabulate

from together import Together
from together.cli.api.utils import INT_WITH_MAX
from together.utils import finetune_price_to_dollars, log_warn, parse_timestamp
from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
from together.utils import (
finetune_price_to_dollars,
log_warn,
log_warn_once,
parse_timestamp,
)
from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits


@@ -93,6 +98,13 @@ def fine_tuning(ctx: click.Context) -> None:
default=False,
help="Whether to skip the launch confirmation message",
)
@click.option(
"--train-on-inputs",
type=BOOL_WITH_AUTO,
default="auto",
help="Whether to mask the user messages in conversational data or prompts in instruction data. "
"`auto` will automatically determine whether to mask the inputs based on the data format.",
)
def create(
ctx: click.Context,
training_file: str,
@@ -112,6 +124,7 @@ def create(
suffix: str,
wandb_api_key: str,
confirm: bool,
train_on_inputs: bool | Literal["auto"],
) -> None:
"""Start fine-tuning"""
client: Together = ctx.obj
@@ -133,6 +146,7 @@ def create(
lora_trainable_modules=lora_trainable_modules,
suffix=suffix,
wandb_api_key=wandb_api_key,
train_on_inputs=train_on_inputs,
)

model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
@@ -150,6 +164,10 @@ def create(
"batch_size": model_limits.lora_training.max_batch_size,
"learning_rate": 1e-3,
}
log_warn_once(
f"The default LoRA rank for {model} has been changed to {default_values['lora_r']} as the max available.\n"
f"Also, the default learning rate for LoRA fine-tuning has been changed to {default_values['learning_rate']}."
)
for arg in default_values:
arg_source = ctx.get_parameter_source(arg)  # type: ignore[attr-defined]
if arg_source == ParameterSource.DEFAULT:
@@ -186,22 +204,7 @@ def create(

if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
response = client.fine_tuning.create(
training_file=training_file,
model=model,
n_epochs=n_epochs,
validation_file=validation_file,
n_evals=n_evals,
n_checkpoints=n_checkpoints,
batch_size=batch_size,
learning_rate=learning_rate,
warmup_ratio=warmup_ratio,
lora=lora,
lora_r=lora_r,
lora_dropout=lora_dropout,
lora_alpha=lora_alpha,
lora_trainable_modules=lora_trainable_modules,
suffix=suffix,
wandb_api_key=wandb_api_key,
**training_args,
verbose=True,
)

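
The default-override logic in the hunk above relies on click's parameter-source tracking; below is a small standalone sketch of that pattern, with an illustrative option and defaults rather than the CLI's real ones.

import click
from click.core import ParameterSource


@click.command()
@click.option("--learning-rate", type=float, default=1e-5)
@click.pass_context
def train(ctx: click.Context, learning_rate: float) -> None:
    # ParameterSource.DEFAULT means the user did not pass the flag explicitly,
    # so the command may substitute a model-specific default instead.
    if ctx.get_parameter_source("learning_rate") == ParameterSource.DEFAULT:
        learning_rate = 1e-3  # e.g. a LoRA-specific default
    click.echo(f"learning_rate={learning_rate}")


if __name__ == "__main__":
    train()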
21 changes: 21 additions & 0 deletions src/together/cli/api/utils.py
@@ -27,4 +27,25 @@ def convert(
)


class BooleanWithAutoParamType(click.ParamType):
name = "boolean_or_auto"

def convert(
self, value: str, param: click.Parameter | None, ctx: click.Context | None
) -> bool | Literal["auto"] | None:
if value == "auto":
return "auto"
try:
return bool(value)
except ValueError:
self.fail(
_("{value!r} is not a valid {type}.").format(
value=value, type=self.name
),
param,
ctx,
)


INT_WITH_MAX = AutoIntParamType()
BOOL_WITH_AUTO = BooleanWithAutoParamType()
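
A minimal usage sketch of the new parameter type (the command and option names here are illustrative): "auto" passes through as the literal string, while any other value is coerced with bool() as in convert() above.

import click

from together.cli.api.utils import BOOL_WITH_AUTO


@click.command()
@click.option("--train-on-inputs", type=BOOL_WITH_AUTO, default="auto")
def show(train_on_inputs):
    # train_on_inputs arrives as either the string "auto" or a Python bool.
    click.echo(f"train_on_inputs={train_on_inputs!r}")


if __name__ == "__main__":
    show()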
19 changes: 19 additions & 0 deletions src/together/constants.py
@@ -1,3 +1,5 @@
import enum

# Session constants
TIMEOUT_SECS = 600
MAX_SESSION_LIFETIME_SECS = 180
@@ -29,3 +31,20 @@

# expected columns for Parquet files
PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]


class DatasetFormat(enum.Enum):
"""Dataset format enum."""

GENERAL = "general"
CONVERSATION = "conversation"
INSTRUCTION = "instruction"


JSONL_REQUIRED_COLUMNS_MAP = {
DatasetFormat.GENERAL: ["text"],
DatasetFormat.CONVERSATION: ["messages"],
DatasetFormat.INSTRUCTION: ["prompt", "completion"],
}
REQUIRED_COLUMNS_MESSAGE = ["role", "content"]
POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant"]
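
A hedged sketch of how these constants can drive a format check; detect_format is an illustrative helper, not a function shipped by the library.

from __future__ import annotations

import json

from together.constants import DatasetFormat, JSONL_REQUIRED_COLUMNS_MAP


def detect_format(jsonl_line: str) -> DatasetFormat | None:
    """Return the first format whose required columns are all present."""
    sample = json.loads(jsonl_line)
    for dataset_format, required_columns in JSONL_REQUIRED_COLUMNS_MAP.items():
        if all(column in sample for column in required_columns):
            return dataset_format
    return None


print(detect_format('{"prompt": "2 + 2 =", "completion": "4"}'))  # DatasetFormat.INSTRUCTION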
20 changes: 19 additions & 1 deletion src/together/resources/finetune.py
@@ -43,6 +43,7 @@ def createFinetuneRequest(
lora_trainable_modules: str | None = "all-linear",
suffix: str | None = None,
wandb_api_key: str | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
) -> FinetuneRequest:
if batch_size == "max":
log_warn_once(
@@ -95,6 +96,7 @@ def createFinetuneRequest(
training_type=training_type,
suffix=suffix,
wandb_key=wandb_api_key,
train_on_inputs=train_on_inputs,
)

return finetune_request
@@ -125,6 +127,7 @@ def create(
wandb_api_key: str | None = None,
verbose: bool = False,
model_limits: FinetuneTrainingLimits | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
) -> FinetuneResponse:
"""
Method to initiate a fine-tuning job
@@ -137,7 +140,7 @@
n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
Defaults to 1.
batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
batch_size (int or "max"): Batch size for fine-tuning. Defaults to max.
learning_rate (float, optional): Learning rate multiplier to use for training
Defaults to 0.00001.
warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
@@ -154,6 +157,12 @@
Defaults to False.
model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
Defaults to None.
train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
"auto" will automatically determine whether to mask the inputs based on the data format.
For datasets with the "text" field (general format), inputs will not be masked.
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
(Instruction format), inputs will be masked.
Defaults to "auto".
Returns:
FinetuneResponse: Object containing information about fine-tuning job.
@@ -184,6 +193,7 @@ def create(
lora_trainable_modules=lora_trainable_modules,
suffix=suffix,
wandb_api_key=wandb_api_key,
train_on_inputs=train_on_inputs,
)

if verbose:
@@ -436,6 +446,7 @@ async def create(
wandb_api_key: str | None = None,
verbose: bool = False,
model_limits: FinetuneTrainingLimits | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
) -> FinetuneResponse:
"""
Async method to initiate a fine-tuning job
@@ -465,6 +476,12 @@ async def create(
Defaults to False.
model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
Defaults to None.
train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
"auto" will automatically determine whether to mask the inputs based on the data format.
For datasets with the "text" field (general format), inputs will not be masked.
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
(Instruction format), inputs will be masked.
Defaults to "auto".
Returns:
FinetuneResponse: Object containing information about fine-tuning job.
@@ -495,6 +512,7 @@ async def create(
lora_trainable_modules=lora_trainable_modules,
suffix=suffix,
wandb_api_key=wandb_api_key,
train_on_inputs=train_on_inputs,
)

if verbose:
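
Putting the new argument together on the Python client, a minimal sketch; the file ID and model name are placeholders.

from together import Together

client = Together()  # reads TOGETHER_API_KEY from the environment

job = client.fine_tuning.create(
    training_file="file-abc123",  # placeholder: ID of an uploaded JSONL file
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder model name
    n_epochs=1,
    train_on_inputs="auto",  # or True / False instead of automatic detection
)
print(job.id, job.status)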
4 changes: 3 additions & 1 deletion src/together/types/finetune.py
@@ -3,7 +3,7 @@
from enum import Enum
from typing import List, Literal

from pydantic import Field, validator, field_validator
from pydantic import StrictBool, Field, validator, field_validator

from together.types.abstract import BaseModel
from together.types.common import (
@@ -163,6 +163,7 @@ class FinetuneRequest(BaseModel):
# weights & biases api key
wandb_key: str | None = None
training_type: FullTrainingType | LoRATrainingType | None = None
train_on_inputs: StrictBool | Literal["auto"] = "auto"


class FinetuneResponse(BaseModel):
@@ -230,6 +231,7 @@ class FinetuneResponse(BaseModel):
# training file metadata
training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
training_file_size: int | None = Field(None, alias="TrainingFileSize")
train_on_inputs: StrictBool | Literal["auto"] | None = "auto"

@field_validator("training_type")
@classmethod
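
A standalone illustration of what the switch to StrictBool buys (Example is not a library type): a plain bool field would coerce truthy strings such as "yes", while StrictBool accepts only real booleans and leaves "auto" to the Literal branch.

from typing import Literal, Union

from pydantic import BaseModel, StrictBool, ValidationError


class Example(BaseModel):
    train_on_inputs: Union[StrictBool, Literal["auto"]] = "auto"


print(Example(train_on_inputs=True).train_on_inputs)    # True
print(Example(train_on_inputs="auto").train_on_inputs)  # auto

try:
    Example(train_on_inputs="yes")  # a plain bool field would coerce this to True
except ValidationError:
    print("rejected: 'yes' is neither a strict bool nor the literal 'auto'")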
(The remaining two changed files are not rendered in this view.)
