Merged
27 changes: 12 additions & 15 deletions autoparallel/api.py
@@ -191,21 +191,18 @@ def __enter__(self):

        self.build_model_graph()

-        from torch._subclasses.fake_tensor import unset_fake_temporarily
-
-        with unset_fake_temporarily():

Review thread on this line:

Contributor:
Do you happen to know why this was here originally? Seems like it would have been logical to use fake mode all along. Maybe it was just to work around the one case with reshape below?

Contributor Author:
Yeah, it was there because in one of the shardings we were calling copy.deepcopy, which is not allowed under fake_mode.

-            rescale_grad_comm_cost_for_mp = 1.0
-            if self.mp_policy is not None:
-                param_size = self.mp_policy.param_dtype.itemsize
-                reduce_size = self.mp_policy.reduce_dtype.itemsize
-                if param_size != reduce_size:
-                    rescale_grad_comm_cost_for_mp = reduce_size / param_size
-                    # Tiebreak, favoring performing the comms in the largest
-                    # dtype
-                    rescale_grad_comm_cost_for_mp *= 1.1
-            sharding_optimizer = ShardingOptimizer(
-                self.gm, self.mesh, rescale_grad_comm_cost_for_mp
-            )
+        rescale_grad_comm_cost_for_mp = 1.0
+        if self.mp_policy is not None:
+            param_size = self.mp_policy.param_dtype.itemsize
+            reduce_size = self.mp_policy.reduce_dtype.itemsize
+            if param_size != reduce_size:
+                rescale_grad_comm_cost_for_mp = reduce_size / param_size
+                # Tiebreak, favoring performing the comms in the largest
+                # dtype
+                rescale_grad_comm_cost_for_mp *= 1.1
+        sharding_optimizer = ShardingOptimizer(
+            self.gm, self.mesh, rescale_grad_comm_cost_for_mp
+        )

        # makes sharding of params and gradients the same
        sharding_optimizer.add_grad_param_constraints()
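
For concreteness (illustrative values only, not part of the PR): with a hypothetical mixed-precision policy of bf16 parameters and fp32 gradient reduction, the rescaling above works out to 4 / 2 * 1.1 = 2.2:

    import torch

    # Hypothetical policy: bf16 params (2 bytes), fp32 gradient reduction (4 bytes).
    param_dtype, reduce_dtype = torch.bfloat16, torch.float32

    rescale_grad_comm_cost_for_mp = 1.0
    if param_dtype.itemsize != reduce_dtype.itemsize:
        # Gradient comm cost is scaled by the dtype size ratio: 4 / 2 = 2.0 ...
        rescale_grad_comm_cost_for_mp = reduce_dtype.itemsize / param_dtype.itemsize
        # ... times the 1.1 tiebreak favoring comms in the largest dtype.
        rescale_grad_comm_cost_for_mp *= 1.1

    print(rescale_grad_comm_cost_for_mp)  # 2.2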
5 changes: 4 additions & 1 deletion autoparallel/propagation_rules.py
@@ -700,7 +700,10 @@ def reshape_rule(mesh, op_schema):
@register_opschema_rule(torch.ops.aten.expand.default)
def expand_rule(mesh, op_schema_):
    op = torch.ops.aten.expand.default
-    op_schema = copy.deepcopy(op_schema_)
+    from torch._subclasses.fake_tensor import unset_fake_temporarily
+
+    with unset_fake_temporarily():
+        op_schema = copy.deepcopy(op_schema_)
    input_strat = op_schema.args_schema[0]
    orig_shape = input_strat.strategies[0].output_specs.tensor_meta.shape
    dest_shape = op_schema.args_schema[1]
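
For illustration only (not part of the PR): a minimal, runnable sketch of the pattern described in the review thread above and applied in expand_rule — suspend the active FakeTensorMode with unset_fake_temporarily while deepcopying. The schema_like dict is a hypothetical stand-in for the real OpSchema, which carries fake-tensor metadata:

    import copy

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

    with FakeTensorMode():
        # Under fake mode, factory calls return FakeTensors (shapes/dtypes, no data).
        fake = torch.empty(8, 16)
        # Hypothetical stand-in for the OpSchema that expand_rule receives.
        schema_like = {"shape": tuple(fake.shape), "dtype": fake.dtype}
        with unset_fake_temporarily():
            # Fake mode is temporarily popped off the dispatch stack, so the
            # deepcopy behaves as it would outside tracing; in the PR, the real
            # OpSchema is deepcopied inside this guard.
            op_schema_copy = copy.deepcopy(schema_like)
        assert op_schema_copy == schema_like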