
Commit 3fba727

Use fake_mode when constructing ShardingOptimizer (#70)
This is particularly important for constructor nodes and for the flop estimation; otherwise they could materialize massive tensors in memory, leading to OOM. This shows up in DeepSeek.
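For context, a minimal sketch (not code from this repository) of why fake mode avoids the OOM: with a FakeTensorMode active, constructor ops such as torch.zeros produce only metadata (shape, dtype, device) and never allocate real storage, so even enormous tensors can be "created" for shape and flop analysis.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # Would need ~4 TB of real memory; as a fake tensor it carries only metadata.
    x = torch.zeros(1_000_000, 1_000_000)
    print(x.shape, x.dtype, x.device)  # metadata is still available for analysis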
1 parent aaf6f5f commit 3fba727

2 files changed (+16 -16 lines)

autoparallel/api.py

Lines changed: 12 additions & 15 deletions
@@ -191,21 +191,18 @@ def __enter__(self):

         self.build_model_graph()

-        from torch._subclasses.fake_tensor import unset_fake_temporarily
-
-        with unset_fake_temporarily():
-            rescale_grad_comm_cost_for_mp = 1.0
-            if self.mp_policy is not None:
-                param_size = self.mp_policy.param_dtype.itemsize
-                reduce_size = self.mp_policy.reduce_dtype.itemsize
-                if param_size != reduce_size:
-                    rescale_grad_comm_cost_for_mp = reduce_size / param_size
-                    # Tiebreak, favoring performing the comms in the largest
-                    # dtype
-                    rescale_grad_comm_cost_for_mp *= 1.1
-            sharding_optimizer = ShardingOptimizer(
-                self.gm, self.mesh, rescale_grad_comm_cost_for_mp
-            )
+        rescale_grad_comm_cost_for_mp = 1.0
+        if self.mp_policy is not None:
+            param_size = self.mp_policy.param_dtype.itemsize
+            reduce_size = self.mp_policy.reduce_dtype.itemsize
+            if param_size != reduce_size:
+                rescale_grad_comm_cost_for_mp = reduce_size / param_size
+                # Tiebreak, favoring performing the comms in the largest
+                # dtype
+                rescale_grad_comm_cost_for_mp *= 1.1
+        sharding_optimizer = ShardingOptimizer(
+            self.gm, self.mesh, rescale_grad_comm_cost_for_mp
+        )

         # makes sharding of params and gradients the same
         sharding_optimizer.add_grad_param_constraints()
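The wrapper removed above, unset_fake_temporarily, suspends an active fake-tensor mode for the duration of the with block; dropping it means ShardingOptimizer is now constructed while fake mode is still active, so tensors it creates stay metadata-only. A rough illustration of the context manager's effect (illustrative only, assuming an ambient FakeTensorMode):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

with FakeTensorMode():
    fake = torch.ones(4)           # fake tensor: no storage allocated
    with unset_fake_temporarily():
        real = torch.ones(4)       # fake mode suspended: real storage is allocated
print(type(fake).__name__, type(real).__name__)  # FakeTensor, Tensor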

autoparallel/propagation_rules.py

Lines changed: 4 additions & 1 deletion
@@ -700,7 +700,10 @@ def reshape_rule(mesh, op_schema):
 @register_opschema_rule(torch.ops.aten.expand.default)
 def expand_rule(mesh, op_schema_):
     op = torch.ops.aten.expand.default
-    op_schema = copy.deepcopy(op_schema_)
+    from torch._subclasses.fake_tensor import unset_fake_temporarily
+
+    with unset_fake_temporarily():
+        op_schema = copy.deepcopy(op_schema_)
     input_strat = op_schema.args_schema[0]
     orig_shape = input_strat.strategies[0].output_specs.tensor_meta.shape
     dest_shape = op_schema.args_schema[1]
