
Commit 0751a96

feat: make training config fields optional (#1861)
# What does this PR do?

Today, `supervised_fine_tune` itself and the `TrainingConfig` class have a number of required fields that a provider implementation might not need. For example, if a provider handles hyperparameters in its own configuration, along with any dataset retrieval, optimizer, or LoRA config, a user still has to pass in a virtually empty `DataConfig`, `OptimizerConfig`, and `AlgorithmConfig` in some cases. Many of these fields are intended to work specifically with Llama models and are knobs for customizing the inline providers. Adding remote post_training providers requires loosening these arguments; otherwise users are forced to pass in empty objects just to satisfy the pydantic models.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
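As a sketch of the call pattern this loosening enables (the `post_training` handle and `start_job` wrapper are assumptions for illustration, not part of this commit), a caller can now submit a fine-tuning job without stubbing out `DataConfig` or `OptimizerConfig`:

```python
from llama_stack.apis.post_training import TrainingConfig

async def start_job(post_training):
    # `post_training` is any implementation of the PostTraining API (assumed handle).
    # After this change only n_epochs is strictly required on TrainingConfig; dataset
    # and optimizer settings can live entirely in the provider's own configuration.
    return await post_training.supervised_fine_tune(
        job_uuid="job-123",
        training_config=TrainingConfig(n_epochs=1),
        hyperparam_search_config={},
        logger_config={},
    )
```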
1 parent 70a7e4d commit 0751a96

File tree: 4 files changed (+29, -21 lines)

docs/_static/llama-stack-spec.html

Lines changed: 8 additions & 9 deletions
```diff
@@ -9778,13 +9778,16 @@
             "type": "integer"
           },
           "max_steps_per_epoch": {
-            "type": "integer"
+            "type": "integer",
+            "default": 1
           },
           "gradient_accumulation_steps": {
-            "type": "integer"
+            "type": "integer",
+            "default": 1
           },
           "max_validation_steps": {
-            "type": "integer"
+            "type": "integer",
+            "default": 1
           },
           "data_config": {
             "$ref": "#/components/schemas/DataConfig"
@@ -9804,10 +9807,7 @@
         "required": [
           "n_epochs",
           "max_steps_per_epoch",
-          "gradient_accumulation_steps",
-          "max_validation_steps",
-          "data_config",
-          "optimizer_config"
+          "gradient_accumulation_steps"
         ],
         "title": "TrainingConfig"
       },
@@ -10983,8 +10983,7 @@
           "job_uuid",
           "training_config",
           "hyperparam_search_config",
-          "logger_config",
-          "model"
+          "logger_config"
         ],
         "title": "SupervisedFineTuneRequest"
       },
```

docs/_static/llama-stack-spec.yaml

Lines changed: 3 additions & 4 deletions
```diff
@@ -6744,10 +6744,13 @@ components:
         type: integer
       max_steps_per_epoch:
         type: integer
+        default: 1
       gradient_accumulation_steps:
         type: integer
+        default: 1
       max_validation_steps:
         type: integer
+        default: 1
       data_config:
         $ref: '#/components/schemas/DataConfig'
       optimizer_config:
@@ -6762,9 +6765,6 @@ components:
         - n_epochs
         - max_steps_per_epoch
         - gradient_accumulation_steps
-        - max_validation_steps
-        - data_config
-        - optimizer_config
       title: TrainingConfig
     PreferenceOptimizeRequest:
       type: object
@@ -7498,7 +7498,6 @@ components:
         - training_config
         - hyperparam_search_config
         - logger_config
-        - model
       title: SupervisedFineTuneRequest
     SyntheticDataGenerateRequest:
       type: object
```
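The HTML and YAML spec files above carry the same schema change. As a minimal sketch of its effect (the trimmed schema fragment and the `jsonschema` check are illustrative assumptions, not part of this commit), a `TrainingConfig` payload now validates without `data_config` or `optimizer_config`:

```python
from jsonschema import validate  # pip install jsonschema

# Trimmed, hypothetical fragment mirroring the updated TrainingConfig schema.
training_config_schema = {
    "type": "object",
    "properties": {
        "n_epochs": {"type": "integer"},
        "max_steps_per_epoch": {"type": "integer", "default": 1},
        "gradient_accumulation_steps": {"type": "integer", "default": 1},
        "max_validation_steps": {"type": "integer", "default": 1},
    },
    "required": ["n_epochs", "max_steps_per_epoch", "gradient_accumulation_steps"],
}

# Before this commit, omitting data_config and optimizer_config would have
# failed validation because they were listed as required; now it passes.
validate(
    instance={"n_epochs": 1, "max_steps_per_epoch": 1, "gradient_accumulation_steps": 1},
    schema=training_config_schema,
)
```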

llama_stack/apis/post_training/post_training.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -60,11 +60,11 @@ class EfficiencyConfig(BaseModel):
 @json_schema_type
 class TrainingConfig(BaseModel):
     n_epochs: int
-    max_steps_per_epoch: int
-    gradient_accumulation_steps: int
-    max_validation_steps: int
-    data_config: DataConfig
-    optimizer_config: OptimizerConfig
+    max_steps_per_epoch: int = 1
+    gradient_accumulation_steps: int = 1
+    max_validation_steps: Optional[int] = 1
+    data_config: Optional[DataConfig] = None
+    optimizer_config: Optional[OptimizerConfig] = None
     efficiency_config: Optional[EfficiencyConfig] = None
     dtype: Optional[str] = "bf16"
 
@@ -177,9 +177,9 @@ async def supervised_fine_tune(
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
-        model: str = Field(
-            default="Llama3.2-3B-Instruct",
-            description="Model descriptor from `llama model list`",
+        model: Optional[str] = Field(
+            default=None,
+            description="Model descriptor for training if not in provider config`",
         ),
         checkpoint_dir: Optional[str] = None,
         algorithm_config: Optional[AlgorithmConfig] = None,
```
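As a quick check of what the relaxed pydantic model accepts (a minimal sketch, assuming `llama_stack` at this commit is importable), fields left out now fall back to the new defaults:

```python
from llama_stack.apis.post_training import TrainingConfig

# Previously this raised a ValidationError because data_config, optimizer_config,
# max_validation_steps, and the step-count fields were all required.
config = TrainingConfig(n_epochs=3)
assert config.max_steps_per_epoch == 1
assert config.gradient_accumulation_steps == 1
assert config.max_validation_steps == 1
assert config.data_config is None and config.optimizer_config is None
assert config.dtype == "bf16"
```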

llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -38,6 +38,8 @@
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
     Checkpoint,
+    DataConfig,
+    EfficiencyConfig,
     LoraFinetuningConfig,
     OptimizerConfig,
     QATFinetuningConfig,
@@ -89,6 +91,10 @@ def __init__(
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ) -> None:
+        assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized"
+
+        assert isinstance(training_config.efficiency_config, EfficiencyConfig), "EfficiencyConfig must be initialized"
+
         self.job_uuid = job_uuid
         self.training_config = training_config
         if not isinstance(algorithm_config, LoraFinetuningConfig):
@@ -188,13 +194,16 @@ async def setup(self) -> None:
         self._tokenizer = await self._setup_tokenizer()
         log.info("Tokenizer is initialized.")
 
+        assert isinstance(self.training_config.optimizer_config, OptimizerConfig), "OptimizerConfig must be initialized"
         self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
         log.info("Optimizer is initialized.")
 
         self._loss_fn = CEWithChunkedOutputLoss()
         self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
         log.info("Loss is initialized.")
 
+        assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
+
         self._training_sampler, self._training_dataloader = await self._setup_data(
             dataset_id=self.training_config.data_config.dataset_id,
             tokenizer=self._tokenizer,
@@ -452,6 +461,7 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
         """
         The core training loop.
         """
+        assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
         # Initialize tokens count and running loss (for grad accumulation)
         t0 = time.perf_counter()
         running_loss: float = 0.0
```
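Because `data_config` and `optimizer_config` are now `Optional` on `TrainingConfig`, the inline torchtune recipe narrows them with `isinstance` asserts before use, as shown above. A minimal sketch of that guard pattern in isolation (the `require_data_config` helper is hypothetical, added only for illustration):

```python
from llama_stack.apis.post_training import DataConfig, TrainingConfig

def require_data_config(training_config: TrainingConfig) -> DataConfig:
    # Fail fast with a clear message when a provider that needs an inline
    # dataset configuration is handed a TrainingConfig without one.
    assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized"
    return training_config.data_config
```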
