Commit ee9bcdd

fegin authored and jquesnelle committed
Remove the unused compiled_autograd option (pytorch#1939)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):

* pytorch#1857
* __->__ pytorch#1939

TorchTitan doesn't need compiled_autograd, which exists to support compiled DDP, and TorchTitan will adopt fully_shard-based replicate instead. Let's remove it.
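For context, the removed option tied torchtitan's training step to PyTorch's compiled autograd, the mechanism compiled DDP relies on. Below is a minimal sketch of the old behavior, pieced together from the hunks in this commit; the helper name is illustrative and not part of the codebase.

import contextlib

import torch


@contextlib.contextmanager
def _with_compiled_autograd(enabled: bool):
    # Old behavior being removed: switch DDP to the python reducer and capture
    # the backward pass with compiled autograd; otherwise do nothing special.
    if enabled:
        torch._dynamo.config.optimize_ddp = "python_reducer_without_compiled_forward"
    with contextlib.ExitStack() as stack:
        if enabled:
            stack.enter_context(
                torch._dynamo.utils.maybe_enable_compiled_autograd(True)
            )
        yield
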
1 parent 3013e2c commit ee9bcdd

File tree: 11 files changed (+7, -38 lines)


scripts/estimate/estimation.py
Lines changed: 2 additions & 6 deletions

@@ -33,10 +33,9 @@ def estimate_memory(job_config: JobConfig):
     # Get the world size
     world_size = int(os.environ["WORLD_SIZE"])

-    if job_config.compile.enable or job_config.parallelism.enable_compiled_autograd:
+    if job_config.compile.enable:
         logger.info("Compile mode is not supported yet. Switching to eager mode.")
         job_config.compile.enable = False
-        job_config.parallelism.enable_compiled_autograd = False

     # init fake pg
     store = FakeStore()
@@ -80,10 +79,7 @@ def estimate_memory(job_config: JobConfig):
     loss_parallel_enabled = (
         parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
     )
-    train_context = dist_utils.get_train_context(
-        loss_parallel_enabled,
-        job_config.parallelism.enable_compiled_autograd,
-    )
+    train_context = dist_utils.get_train_context(loss_parallel_enabled)

     # build model (using meta init)
     model_args = train_spec.model_args[job_config.model.flavor]

torchtitan/config/job_config.py
Lines changed: 0 additions & 3 deletions

@@ -301,9 +301,6 @@ class Parallelism:
     1 means disabled.
     """

-    enable_compiled_autograd: bool = False
-    """Enable CompiledAutograd to compile the backward."""
-
     data_parallel_shard_degree: int = -1
     """
     The `data_parallel_shard_degree` argument specifies the degree of data

torchtitan/distributed/utils.py
Lines changed: 1 addition & 8 deletions

@@ -193,20 +193,13 @@ def create_context_parallel_ctx(
     )


-def get_train_context(
-    enable_loss_parallel: bool, enable_compiled_autograd: bool
-) -> Generator[None, None, None]:
+def get_train_context(enable_loss_parallel: bool) -> Generator[None, None, None]:
     @contextlib.contextmanager
     def context(cp_context: Generator[None, None, None] | None = None):
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())

-            if enable_compiled_autograd:
-                stack.enter_context(
-                    torch._dynamo.utils.maybe_enable_compiled_autograd(True)
-                )
-
             if cp_context:
                 stack.enter_context(cp_context)

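With the flag gone, the returned context only composes loss parallel and an optional context-parallel context. A minimal usage sketch under simplifying assumptions: single process, loss parallel and context parallel both disabled, and the import path assumed to be torchtitan.distributed.utils.

import torch
import torch.nn as nn

from torchtitan.distributed import utils as dist_utils

model = nn.Linear(8, 8)
train_context = dist_utils.get_train_context(enable_loss_parallel=False)

# No context parallelism is set up here, so cp_context stays None.
with train_context(cp_context=None):
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
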
torchtitan/experiments/forge/engine.py
Lines changed: 1 addition & 4 deletions

@@ -233,10 +233,7 @@ def __init__(self, job_config: ForgeJobConfig):
         loss_parallel_enabled = (
             parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
         )
-        self.train_context = dist_utils.get_train_context(
-            loss_parallel_enabled,
-            parallelism_config.enable_compiled_autograd,
-        )
+        self.train_context = dist_utils.get_train_context(loss_parallel_enabled)
         self.maybe_enable_amp = dist_utils.maybe_enable_amp(
             parallel_dims,
             job_config.training.mixed_precision_param,

torchtitan/experiments/gpt_oss/infra/parallelize.py
Lines changed: 1 addition & 1 deletion

@@ -45,6 +45,7 @@
     torch._higher_order_ops.flex_attention,
 }

+
 # Adapted from llama4/infra/parallelize.py
 def parallelize_gptoss(
     model: nn.Module,
@@ -168,7 +169,6 @@ def parallelize_gptoss(
             model,
             dp_mesh,
             enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )

     return model

torchtitan/experiments/vlm/infra/parallelize.py
Lines changed: 0 additions & 1 deletion

@@ -107,7 +107,6 @@ def parallelize_vlm(
             model,
             world_mesh,
             enable_compile=job_config.compile.enable,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )

     return model

torchtitan/models/deepseek_v3/infra/parallelize.py
Lines changed: 0 additions & 1 deletion

@@ -171,7 +171,6 @@ def parallelize_deepseekv3(
             model,
             dp_mesh,
             enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )

     return model

torchtitan/models/llama3/infra/parallelize.py
Lines changed: 1 addition & 8 deletions

@@ -143,7 +143,6 @@ def parallelize_llama(
             model,
             world_mesh,
             enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )

     return model
@@ -324,15 +323,9 @@ def apply_ddp(
     model: nn.Module,
     dp_mesh: DeviceMesh,
     enable_compile: bool,
-    enable_compiled_autograd: bool,
 ):
     if enable_compile:
-        if enable_compiled_autograd:
-            torch._dynamo.config.optimize_ddp = (
-                "python_reducer_without_compiled_forward"
-            )
-        else:
-            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+        torch._dynamo.config.optimize_ddp = "ddp_optimizer"

     replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

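For reference, apply_ddp after this change reduces to the version below, reconstructed from the hunk above. The import lines are assumptions about the module's existing imports, not part of the diff.

import torch
import torch.nn as nn
from torch.distributed._composable.replicate import replicate
from torch.distributed.device_mesh import DeviceMesh


def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
):
    if enable_compile:
        # Only the graph-splitting DDP optimizer path remains now that the
        # compiled-autograd python reducer path is gone.
        torch._dynamo.config.optimize_ddp = "ddp_optimizer"

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
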
torchtitan/models/llama4/infra/parallelize.py
Lines changed: 0 additions & 1 deletion

@@ -191,7 +191,6 @@ def parallelize_llama(
             model,
             dp_mesh,
             enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )

     return model

torchtitan/models/qwen3/infra/parallelize.py
Lines changed: 0 additions & 1 deletion

@@ -170,7 +170,6 @@ def parallelize_qwen3(
             model,
             world_mesh,
             enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )

     # Enable weight tying after applying parallelisms

0 commit comments