diff --git a/autoparallel/api.py b/autoparallel/api.py index f59a727d..3ac7fc17 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -191,21 +191,18 @@ def __enter__(self): self.build_model_graph() - from torch._subclasses.fake_tensor import unset_fake_temporarily - - with unset_fake_temporarily(): - rescale_grad_comm_cost_for_mp = 1.0 - if self.mp_policy is not None: - param_size = self.mp_policy.param_dtype.itemsize - reduce_size = self.mp_policy.reduce_dtype.itemsize - if param_size != reduce_size: - rescale_grad_comm_cost_for_mp = reduce_size / param_size - # Tiebreak, favoring performing the comms in the largest - # dtype - rescale_grad_comm_cost_for_mp *= 1.1 - sharding_optimizer = ShardingOptimizer( - self.gm, self.mesh, rescale_grad_comm_cost_for_mp - ) + rescale_grad_comm_cost_for_mp = 1.0 + if self.mp_policy is not None: + param_size = self.mp_policy.param_dtype.itemsize + reduce_size = self.mp_policy.reduce_dtype.itemsize + if param_size != reduce_size: + rescale_grad_comm_cost_for_mp = reduce_size / param_size + # Tiebreak, favoring performing the comms in the largest + # dtype + rescale_grad_comm_cost_for_mp *= 1.1 + sharding_optimizer = ShardingOptimizer( + self.gm, self.mesh, rescale_grad_comm_cost_for_mp + ) # makes sharding of params and gradients the same sharding_optimizer.add_grad_param_constraints() diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py index abba8964..f445f03e 100644 --- a/autoparallel/propagation_rules.py +++ b/autoparallel/propagation_rules.py @@ -700,7 +700,10 @@ def reshape_rule(mesh, op_schema): @register_opschema_rule(torch.ops.aten.expand.default) def expand_rule(mesh, op_schema_): op = torch.ops.aten.expand.default - op_schema = copy.deepcopy(op_schema_) + from torch._subclasses.fake_tensor import unset_fake_temporarily + + with unset_fake_temporarily(): + op_schema = copy.deepcopy(op_schema_) input_strat = op_schema.args_schema[0] orig_shape = input_strat.strategies[0].output_specs.tensor_meta.shape dest_shape = op_schema.args_schema[1]