1717)
1818
1919from torchtitan .config .job_config import ActivationCheckpoint as ACConfig
20- from torchtitan .config .job_config import Debug as DebugConfig
2120from torchtitan .tools .logging import logger , warn_once
2221
2322
@@ -43,7 +42,7 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
4342 preserve_rng_state = ac_config .preserve_rng_state ,
4443 determinism_check = ac_config .determinism_check ,
4544 early_stop = ac_config .early_stop ,
46- debug = ac_config .debug
45+ debug = ac_config .debug ,
4746 )
4847 else :
4948 return module
@@ -133,7 +132,7 @@ def selective_checkpointing_context_fn():
133132 preserve_rng_state = ac_config .preserve_rng_state ,
134133 determinism_check = ac_config .determinism_check ,
135134 early_stop = ac_config .early_stop ,
136- debug = ac_config .debug
135+ debug = ac_config .debug ,
137136 )
138137
139138
@@ -152,7 +151,7 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
152151 preserve_rng_state = ac_config .preserve_rng_state ,
153152 determinism_check = ac_config .determinism_check ,
154153 early_stop = ac_config .early_stop ,
155- debug = ac_config .debug
154+ debug = ac_config .debug ,
156155 )
157156
158157
@@ -198,7 +197,6 @@ def _apply_op_sac_to_transformer_block_with_flex(
198197 ),
199198 )
200199
201-
202200 def wrap_submodule (name : str , full_ac : bool = False ) -> None :
203201 submodule = getattr (module , name )
204202 if full_ac :
0 commit comments