
Commit dc2c211

ufmt fixes.

1 parent a45e7eb commit dc2c211

File tree

8 files changed: +17 −15 lines changed

scripts/generate/test_generate.py

Lines changed: 4 additions & 2 deletions
@@ -25,7 +25,7 @@
     RowwiseParallel,
 )
 from torchtitan.components.metrics import build_device_memory_monitor
-from torchtitan.config import ConfigManager
+from torchtitan.config import ConfigManager, Debug as DebugConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
 from torchtitan.protocols.train_spec import get_train_spec
 from torchtitan.tools import utils
@@ -133,7 +133,9 @@ def test_generate(
     # sequences would require https://github.com/pytorch/torchtitan/pull/686
     apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"])
 
-    dist_utils.set_determinism(world_mesh, device, seed, deterministic)
+    debug_config = DebugConfig()
+    debug_config.deterministic = deterministic
+    dist_utils.set_determinism(world_mesh, device, debug_config, seed)
 
     # materalize model
     model.to_empty(device=device_type)
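Note: after this change the test builds a Debug config object and passes it to set_determinism instead of a bare boolean. A minimal caller-side sketch, assuming DebugConfig exposes a deterministic field (as in the hunk above) and that world_mesh, device, and seed stand in for the objects built earlier in the test:

from torchtitan.config import Debug as DebugConfig
from torchtitan.distributed import utils as dist_utils

# world_mesh, device, and seed are placeholders for the test's earlier setup.
debug_config = DebugConfig()
debug_config.deterministic = True  # mirrors the `deterministic` test parameter
dist_utils.set_determinism(world_mesh, device, debug_config, seed)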

torchtitan/config/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -16,6 +16,7 @@
     ActivationCheckpoint,
     Checkpoint,
     Comm,
+    Debug,
     FaultTolerance,
     Job,
     JobConfig,
@@ -28,7 +29,6 @@
     Quantize,
     Training,
     Validation,
-    Debug
 )
 from .manager import ConfigManager
 
@@ -50,5 +50,5 @@
     "Profiling",
     "Training",
     "Validation",
-    "Debug"
+    "Debug",
 ]

torchtitan/config/job_config.py

Lines changed: 1 addition & 0 deletions
@@ -905,6 +905,7 @@ class Debug:
     moe_force_load_balance: bool = False
     """If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only."""
 
+
 @dataclass
 class JobConfig:
     """

torchtitan/distributed/activation_checkpoint.py

Lines changed: 3 additions & 5 deletions
@@ -17,7 +17,6 @@
 )
 
 from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
-from torchtitan.config.job_config import Debug as DebugConfig
 from torchtitan.tools.logging import logger, warn_once
 
 
@@ -43,7 +42,7 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
             preserve_rng_state=ac_config.preserve_rng_state,
             determinism_check=ac_config.determinism_check,
             early_stop=ac_config.early_stop,
-            debug=ac_config.debug
+            debug=ac_config.debug,
         )
     else:
         return module
@@ -133,7 +132,7 @@ def selective_checkpointing_context_fn():
         preserve_rng_state=ac_config.preserve_rng_state,
         determinism_check=ac_config.determinism_check,
         early_stop=ac_config.early_stop,
-        debug=ac_config.debug
+        debug=ac_config.debug,
     )
 
 
@@ -152,7 +151,7 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
         preserve_rng_state=ac_config.preserve_rng_state,
         determinism_check=ac_config.determinism_check,
         early_stop=ac_config.early_stop,
-        debug=ac_config.debug
+        debug=ac_config.debug,
     )
 
 
@@ -198,7 +197,6 @@ def _apply_op_sac_to_transformer_block_with_flex(
         ),
     )
 
-
     def wrap_submodule(name: str, full_ac: bool = False) -> None:
         submodule = getattr(module, name)
         if full_ac:

torchtitan/distributed/utils.py

Lines changed: 4 additions & 3 deletions
@@ -17,8 +17,7 @@
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
 
-from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP
-from torchtitan.config import Debug as DebugConfig
+from torchtitan.config import Comm as CommConfig, Debug as DebugConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed.parallel_dims import ParallelDims
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type
@@ -100,7 +99,9 @@ def set_determinism(
     if debug_config.deterministic:
         logger.info("Deterministic algorithm enabled (expect perf degradation).")
         torch.use_deterministic_algorithms(True)
-        torch.use_deterministic_algorithms(True, warn_only=debug_config.deterministic_warn_only)
+        torch.use_deterministic_algorithms(
+            True, warn_only=debug_config.deterministic_warn_only
+        )
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.benchmark = False
         # env var for deterministic CuBLAS
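The only behavioral knob here is warn_only; this commit merely reflows the call to satisfy the formatter. As a reminder of the underlying PyTorch semantics (plain PyTorch, independent of torchtitan):

import torch

# Strict mode: an op without a deterministic implementation raises a RuntimeError.
torch.use_deterministic_algorithms(True)

# Warn-only mode: the same op emits a UserWarning and falls back to its nondeterministic kernel.
torch.use_deterministic_algorithms(True, warn_only=True)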

torchtitan/experiments/forge/job_config.py

Lines changed: 1 addition & 1 deletion
@@ -12,6 +12,7 @@
     Checkpoint,
     Comm,
     Compile,
+    Debug,
     Job,
     LRScheduler,
     MemoryEstimation,
@@ -20,7 +21,6 @@
     Parallelism,
     Quantize,
     Training,
-    Debug,
 )
 
 
torchtitan/models/attention.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def __init__(self) -> None:
             SDPBackend.CUDNN_ATTENTION,
             SDPBackend.FLASH_ATTENTION,
             SDPBackend.EFFICIENT_ATTENTION,
-            SDPBackend.MATH
+            SDPBackend.MATH,
         ]
 
     def forward(

torchtitan/models/qwen3/model/args.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         self.max_seq_len = seq_len
 
         self.moe_args._debug_force_load_balance = (
-            job_config.training.debug_moe_force_load_balance
+            job_config.debug.moe_force_load_balance
        )
 
     def get_nparams_and_flops(
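This hunk tracks the flag's move from the Training section to the new Debug section of JobConfig. A small sketch of the new access path, assuming JobConfig can be default-constructed from its dataclass defaults:

from torchtitan.config import JobConfig

job_config = JobConfig()
# Old (removed): job_config.training.debug_moe_force_load_balance
# New: the flag lives on the Debug dataclass and defaults to False.
force_balance = job_config.debug.moe_force_load_balance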
