Skip to content

Commit b53ad10

Browse files
authored
Fix upstream api breakage (default_strategy) (#41)
Renamed after changing the semantic upstream pytorch/pytorch#158490
1 parent 583e5fd commit b53ad10

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

autoparallel/propagation_rules.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -500,9 +500,11 @@ def native_layer_norm_backward_rule(mesh, op_schema):
500500

501501
@register_opschema_rule(torch.ops.prims.convert_element_type.default)
502502
def convert_element_type_rule(mesh, op_schema):
503-
from torch.distributed.tensor._ops._tensor_ops import default_strategy
503+
from torch.distributed.tensor._ops._tensor_ops import (
504+
propagate_single_input_strategy,
505+
)
504506

505-
out_strat = default_strategy(op_schema)
507+
out_strat = propagate_single_input_strategy(op_schema)
506508
return out_strat
507509

508510

0 commit comments

Comments
 (0)