2 changes: 1 addition & 1 deletion .github/CODEOWNERS
@@ -10,4 +10,4 @@
/torchtitan/experiments/

# codeowners for experiments/forge
/torchtitan/experiments/forge/* @ebsmothers @pbontrager @joecummings @allenwang28 @tianyu-l @wwwjn
/torchtitan/experiments/forge/* @ebsmothers @pbontrager @joecummings @allenwang28 @tianyu-l @wwwjn @fegin
2 changes: 1 addition & 1 deletion benchmarks/README.md
@@ -9,7 +9,7 @@ A submission should be a file / files including the following information
3. The hardware setup, including the types of GPUs, interconnections, etc.
4. The actual performance report with training configs, e.g. via
- `.toml` files / commandline arguments
- complete configs, which can be found in the log with [`--print_args`](https://github.com/pytorch/torchtitan/blob/e7c0cae934df78d6e9c2835f42ff1f757dc3fddc/torchtitan/config_manager.py#L47) turned on (preferred as the default value not shown in `.toml` or specified in commandline could change from time to time)
- complete configs, which can be found in the log with [`--print_config`](https://github.com/pytorch/torchtitan/blob/e7c0cae934df78d6e9c2835f42ff1f757dc3fddc/torchtitan/config_manager.py#L47) turned on (preferred as the default value not shown in `.toml` or specified in commandline could change from time to time)
5. The versions and date/time of `torchtitan`, `torch`, `torchao`, or any relevant dependencies.
6. Other notes which could help reproduce the results.
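For item 4, a sketch of dumping the fully resolved config (illustrative only; the launch script and `.toml` path are examples, and the flag spelling assumes the `[job]` table introduced in this PR):

```bash
# Illustrative command: print the fully resolved config to the training log so it
# can be attached to a benchmark submission. Paths and flag spelling may differ.
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" \
  ./run_train.sh --job.print_config
```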

18 changes: 9 additions & 9 deletions docs/extension.md
@@ -6,7 +6,7 @@ The extension points and protocols mentioned in this note are subject to change.
### `TrainSpec`

[`TrainSpec`](../torchtitan/protocols/train_spec.py) supports configuring high-level components in model training, including
- definitions of model class and model args config
- definitions of model class and model args
- model parallelization functions
- loss functions
- factory methods for creating dataloader / tokenizer / optimizer / learning rate scheduler / metrics processor
@@ -36,7 +36,7 @@ This is an ongoing effort, and the level of grouping is subject to change.

### Extending `JobConfig`

[`JobConfig`](../torchtitan/config/job_config.py) supports custom extension through the `--experimental.custom_args_module` flag.
[`JobConfig`](../torchtitan/config/job_config.py) supports custom extension through the `--job.custom_config_module` flag.
This lets you define a custom module that extends `JobConfig` with additional fields.

When specified, your custom `JobConfig` is merged with the default:
@@ -45,14 +45,14 @@ When specified, your custom `JobConfig` is merged with the default:

#### Example

To add a custom `custom_args` section, define your own `JobConfig`:
To add a custom `custom_config` section, define your own `JobConfig`:

```python
# torchtitan/experiments/your_folder/custom_args.py
# torchtitan/experiments/your_folder/job_config.py
from dataclasses import dataclass, field

@dataclass
class CustomArgs:
class CustomConfig:
how_is_your_day: str = "good"
"""Just an example."""

@@ -68,19 +68,19 @@ class Training:

@dataclass
class JobConfig:
custom_args: CustomArgs = field(default_factory=CustomArgs)
custom_config: CustomConfig = field(default_factory=CustomConfig)
training: Training= field(default_factory=Training)
```

Then run your script with:

```bash
--experimental.custom_args_module=torchtitan.experiments.your_folder.custom_args
--job.custom_config_module=torchtitan.experiments.your_folder.job_config
```

Or specify it in your `.toml` config:

```toml
[experimental]
custom_args_module = "torchtitan.experiments.your_folder.custom_args"
[job]
custom_config_module = "torchtitan.experiments.your_folder.job_config"
```
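After parsing, the extended section is reachable like any built-in one. A minimal sketch (not part of this diff; the import path is assumed, and the module path matches the example above), mirroring the unit tests changed below:

```python
# Sketch: parse CLI args against the extended JobConfig and read the custom field.
# Assumes the CustomConfig extension module defined above exists and is importable.
from torchtitan.config.manager import ConfigManager

config = ConfigManager().parse_args(
    [
        "--job.custom_config_module=torchtitan.experiments.your_folder.job_config",
        "--custom_config.how-is-your-day",
        "great",
    ]
)
print(config.custom_config.how_is_your_day)  # -> "great"
```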
@@ -8,7 +8,7 @@


@dataclass
class CustomArgs:
class CustomConfig:
how_is_your_day: str = "good"
"""Just an example helptext"""

@@ -28,5 +28,5 @@ class JobConfig:
This is an example of how to extend the tyro parser with custom config classes.
"""

custom_args: CustomArgs = field(default_factory=CustomArgs)
custom_config: CustomConfig = field(default_factory=CustomConfig)
training: Training = field(default_factory=Training)
4 changes: 2 additions & 2 deletions tests/integration_tests/base_config.toml
@@ -1,7 +1,7 @@
[job]
dump_folder = "./outputs"
description = "model debug training for integration test"
print_args = false
description = "model debug training for integration tests"
print_config = false

[profiling]
enable_profiling = false
32 changes: 16 additions & 16 deletions tests/unit_tests/test_job_config.py
@@ -208,7 +208,7 @@ def test_print_help(self):
parser = get_parser(ConfigManager)
parser.print_help()

def test_extend_jobconfig_directly(self):
def test_extend_job_config_directly(self):
@dataclass
class CustomCheckpoint:
convert_path: str = "/custom/path"
@@ -233,19 +233,19 @@ class CustomJobConfig:
assert hasattr(config, "model")

def test_custom_parser(self):
path = "tests.assets.extend_jobconfig_example"
path = "tests.assets.extended_job_config_example"

config_manager = ConfigManager()
config = config_manager.parse_args(
[
f"--experimental.custom_args_module={path}",
"--custom_args.how-is-your-day",
f"--job.custom_config_module={path}",
"--custom_config.how-is-your-day",
"bad",
"--model.converters",
"float8,mxfp",
]
)
assert config.custom_args.how_is_your_day == "bad"
assert config.custom_config.how_is_your_day == "bad"
assert config.model.converters == ["float8", "mxfp"]
result = config.to_dict()
assert isinstance(result, dict)
@@ -254,8 +254,8 @@ def test_custom_parser(self):
with self.assertRaisesRegex(SystemExit, "2"):
config = config_manager.parse_args(
[
f"--experimental.custom_args_module={path}",
"--custom_args.how-is-your-day",
f"--job.custom_config_module={path}",
"--custom_config.how-is-your-day",
"bad",
"--model.converters",
"float8,mxfp",
@@ -266,8 +266,8 @@ def test_custom_parser(self):
with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as fp:
tomli_w.dump(
{
"experimental": {
"custom_args_module": path,
"job": {
"custom_config_module": path,
}
},
fp,
@@ -278,14 +278,14 @@ def test_custom_parser(self):
config = config_manager.parse_args(
[
f"--job.config_file={fp.name}",
f"--experimental.custom_args_module={path}",
"--custom_args.how-is-your-day",
f"--job.custom_config_module={path}",
"--custom_config.how-is-your-day",
"bad",
"--model.converters",
"float8,mxfp",
]
)
assert config.custom_args.how_is_your_day == "bad"
assert config.custom_config.how_is_your_day == "bad"
assert config.training.my_custom_steps == 32
assert config.model.converters == ["float8", "mxfp"]
result = config.to_dict()
@@ -294,10 +294,10 @@ def test_custom_parser(self):
with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as fp:
tomli_w.dump(
{
"experimental": {
"custom_args_module": path,
"job": {
"custom_config_module": path,
},
"custom_args": {"how_is_your_day": "really good"},
"custom_config": {"how_is_your_day": "really good"},
"model": {"converters": ["float8", "mxfp"]},
},
fp,
@@ -311,7 +311,7 @@ def test_custom_parser(self):
]
)

assert config.custom_args.how_is_your_day == "really good"
assert config.custom_config.how_is_your_day == "really good"
assert config.model.converters == ["float8", "mxfp"]
result = config.to_dict()
assert isinstance(result, dict)
12 changes: 10 additions & 2 deletions torchtitan/config/job_config.py
@@ -19,8 +19,14 @@ class Job:
description: str = "default job"
"""Description of the job"""

print_args: bool = False
"""Print the args to terminal"""
print_config: bool = False
"""Print the configs to terminal"""

custom_config_module: str = ""
"""
This option allows users to extend the existing JobConfig with a customized
JobConfig dataclass. Users need to ensure that the path can be imported.
"""


@dataclass
@@ -834,6 +840,8 @@ class Experimental:

custom_args_module: str = ""
"""
DEPRECATED (moved to Job.custom_config_module). Will be removed soon.

This option allows users to extend TorchTitan's existing JobConfig by extending
a user defined JobConfig dataclass. Similar to ``--experimental.custom_model_path``, the user
needs to ensure that the path can be imported.
28 changes: 20 additions & 8 deletions torchtitan/config/manager.py
@@ -45,7 +45,7 @@ def __init__(self, config_cls: Type[JobConfig] = JobConfig):

def parse_args(self, args: list[str] = sys.argv[1:]) -> JobConfig:
toml_values = self._maybe_load_toml(args)
config_cls = self._maybe_add_custom_args(args, toml_values)
config_cls = self._maybe_add_custom_config(args, toml_values)

base_config = (
self._dict_to_dataclass(config_cls, toml_values)
@@ -83,16 +83,19 @@ def _maybe_load_toml(self, args: list[str]) -> dict[str, Any] | None:
logger.exception(f"Error while loading config file: {file_path}")
raise e

def _maybe_add_custom_args(
def _maybe_add_custom_config(
self, args: list[str], toml_values: dict[str, Any] | None
) -> Type[JobConfig]: # noqa: B006
"""Find and merge custom arguments module with current JobConfig class"""
"""
Find and merge custom config module with current JobConfig class, if it is given.
The search order is first searching CLI args, then toml config file.
"""
module_path = None

# 1. Check CLI
valid_keys = {
"--experimental.custom_args_module",
"--experimental.custom-args-module",
"--job.custom_config_module",
"--job.custom-config-module",
}
for i, arg in enumerate(args):
key = arg.split("=")[0]
@@ -102,9 +105,9 @@ def _maybe_add_custom_args(

# 2. If not found in CLI, check TOML
if not module_path and toml_values:
experimental = toml_values.get("experimental", {})
if isinstance(experimental, dict):
module_path = experimental.get("custom_args_module")
job = toml_values.get("job", {})
if isinstance(job, dict):
module_path = job.get("custom_config_module")

if not module_path:
return self.config_cls
@@ -178,6 +181,15 @@ def _dict_to_dataclass(self, cls, data: dict[str, Any]) -> Any:
return cls(**result)

def _validate_config(self) -> None:
if self.config.experimental.custom_args_module:
logger.warning(
"This field is being moved to --job.custom_config_module and "
"will be deprecated soon. Setting job.custom_config_module to "
"experimental.custom_args_module temporarily."
)
self.config.job.custom_config_module = (
self.config.experimental.custom_args_module
)
# TODO: temporary mitigation of BC breaking change in hf_assets_path
# tokenizer default path, need to remove later
if self.config.model.tokenizer_path:
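In practice, migrating an existing `.toml` is a single key move; a sketch, using the module path from the Flux configs updated below (the deprecated key is only copied to the new field with a warning by the shim above):

```toml
# Before (deprecated): _validate_config copies this value into
# job.custom_config_module and logs a warning.
[experimental]
custom_args_module = "torchtitan.experiments.flux.job_config"

# After:
[job]
custom_config_module = "torchtitan.experiments.flux.job_config"
```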
2 changes: 1 addition & 1 deletion torchtitan/experiments/flux/README.md
@@ -29,7 +29,7 @@ Run the following command to train the model on a single GPU:

```

If you want to train with other model config, run the following command:
If you want to train with other model args, run the following command:
```bash
CONFIG_FILE="./torchtitan/experiments/flux/train_configs/flux_schnell_model.toml" ./torchtitan/experiments/flux/run_train.sh
```
7 changes: 2 additions & 5 deletions torchtitan/experiments/flux/__init__.py
@@ -3,9 +3,6 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
@@ -28,7 +25,7 @@
]


flux_configs = {
flux_args = {
"flux-dev": FluxModelArgs(
in_channels=64,
out_channels=64,
@@ -110,7 +107,7 @@
def get_train_spec() -> TrainSpec:
return TrainSpec(
model_cls=FluxModel,
model_args=flux_configs,
model_args=flux_args,
parallelize_fn=parallelize_flux,
pipelining_fn=None,
build_optimizers_fn=build_optimizers,
2 changes: 1 addition & 1 deletion torchtitan/experiments/flux/loss.py
@@ -19,7 +19,7 @@ def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.mse_loss(pred.float(), labels.float().detach())


def build_mse_loss(job_config: JobConfig):
def build_mse_loss(job_config: JobConfig, **kwargs):
loss_fn = mse_loss
if job_config.compile.enable and "loss" in job_config.compile.components:
logger.info("Compiling the loss function with torch.compile")
4 changes: 2 additions & 2 deletions torchtitan/experiments/flux/model/state_dict_adapter.py
@@ -14,14 +14,14 @@

import torch

from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter
from torchtitan.protocols.state_dict_adapter import StateDictAdapter

from .args import FluxModelArgs

logger = logging.getLogger()


class FluxStateDictAdapter(BaseStateDictAdapter):
class FluxStateDictAdapter(StateDictAdapter):
"""
State dict adapter for Flux model to convert between HuggingFace safetensors format
and torchtitan DCP format.
@@ -50,7 +50,7 @@ def test_load_dataset(self):
config_manager = ConfigManager()
config = config_manager.parse_args(
[
f"--experimental.custom_args_module={path}",
f"--job.custom_config_module={path}",
"--training.img_size",
str(256),
"--training.dataset",
7 changes: 2 additions & 5 deletions torchtitan/experiments/flux/train_configs/debug_model.toml
@@ -1,8 +1,8 @@

[job]
dump_folder = "./outputs"
description = "Flux debug model"
print_args = false
print_config = false
custom_config_module = "torchtitan.experiments.flux.job_config"

[profiling]
enable_profiling = false
@@ -49,9 +49,6 @@ autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensor
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1

[experimental]
custom_args_module = "torchtitan.experiments.flux.job_config"

[activation_checkpoint]
mode = "full"

@@ -1,8 +1,7 @@

[job]
dump_folder = "./outputs"
description = "Flux-dev model"
print_args = false
custom_config_module = "torchtitan.experiments.flux.job_config"

[profiling]
enable_profiling = false
@@ -49,9 +48,6 @@ autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensor
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1

[experimental]
custom_args_module = "torchtitan.experiments.flux.job_config"

[activation_checkpoint]
mode = "full"
