
Conversation

@zpcore (Contributor) commented Jul 22, 2025

This PR adds implicit replication fallback and a batch sharding strategy. These DTensor features are added here for faster iteration, since there is a concern (pytorch/pytorch#158476 (comment)) that they may not be a good fit for DTensor upstream.

Please use strategy_pool.get_op_strategy(op, op_schema) instead of torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[op](op_schema) to retrieve strategies, since the former supports the implicit replication fallback.
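
For reference, a minimal sketch of the two lookup paths (the helper names are just for illustration; op and op_schema are assumed to be an OpOverload and its OpSchema from the caller):

```python
from torch.distributed.tensor import DTensor
from autoparallel.dtensor_util import strategy_pool

def get_strategy(op, op_schema):
    # Preferred path: may fall back to an implicit replicated strategy for ops
    # that have no registered sharding strategy.
    return strategy_pool.get_op_strategy(op, op_schema)

def get_strategy_direct(op, op_schema):
    # Direct DTensor lookup: raises KeyError for ops without a registered strategy.
    propagator = DTensor._op_dispatcher.sharding_propagator
    return propagator.op_strategy_funcs[op](op_schema)
```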

New supported strategies:

1. Implicit replication fallback:

```python
from autoparallel.dtensor_util import strategy_pool

with strategy_pool.replicate_for_unsupported_operators():
    ...  # missing ops will use the replicated strategy if possible
```

2. Batch sharding strategy support:

```python
import functools

from torch.distributed.tensor._op_schema import RuntimeSchemaInfo
from autoparallel.dtensor_util.utils import batch_shard_strategy
from autoparallel.dtensor_util import strategy_pool

# Create a strategy with input tensor 1 replicated, input tensor 2 sharded on dim 0,
# and the output tensor sharded on dim 0:
custom_shard_strategy = functools.partial(
    batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0]
)
# Register the strategy for the new op:
strategy_pool.register_op_strategy(new_op)(custom_shard_strategy)
```

For details, see the function docstrings in autoparallel/dtensor_util/utils.py and the example usage in tests/test_dtensor.py.

@facebook-github-bot added the CLA Signed label Jul 22, 2025
@zpcore requested review from XilunWu, fmassa and wconstab July 22, 2025 20:37
@fmassa (Contributor) left a comment

Did an initial pass, this is looking pretty good, thanks!

```python
def propagate(self, op_info: OpInfo) -> None:
    op_info.output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)

def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding:
```
Contributor:

It's a bit hard to see what has been changed in this method vs the original one. Can you maybe list the main changes, so that if the original ShardingPropagator changes we can more easily adapt the changes here?

Contributor Author (zpcore):

I added a comment about the changes to the class. The class is only useful in the test; AP won't use DTensor's _op_dispatcher.sharding_propagator.propagate function.

Comment on lines 200 to 207:

```python
] = (
    torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs
)
self.op_to_schema_info: dict[
    torch._ops.OpOverload, RuntimeSchemaInfo
] = (
    torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_to_schema_info
)
```
Contributor:

What do you think about making a copy of those dicts before assigning them to StrategyPool?

This way, we could register our custom strategies directly into StrategyPool as well, instead of handling them separately.

This doesn't need to be done now though, and I'm happy to get this PR merged as is
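
A rough sketch of what that could look like (illustrative only; StrategyPoolSketch is not the actual class in this PR):

```python
from torch.distributed.tensor import DTensor

propagator = DTensor._op_dispatcher.sharding_propagator

class StrategyPoolSketch:
    def __init__(self):
        # Shallow copies decouple registration: adding/removing ops here no longer
        # touches DTensor's global registries, and vice versa.
        self.op_strategy_funcs = dict(propagator.op_strategy_funcs)
        self.op_to_schema_info = dict(propagator.op_to_schema_info)

    def register_op_strategy_func(self, op, strategy_func):
        # Custom strategies land only in the pool's copy.
        self.op_strategy_funcs[op] = strategy_func
```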

Contributor Author (zpcore):

Yes, I created a copy of DTensor's existing rules/strategies to make them independent. This reminded me that I had forgotten to copy the rules.

```python
return expand_to_full_mesh_op_strategy(mesh, op_schema, single_dim_placement)


def batch_shard_strategy(
```
Contributor:

I didn't check this function's implementation yet.



```python
class ImplicitRegistrationTest(DTensorTestBase):
    @with_comms
```
Contributor:

These tests really don't need comms right? They can be done entirely single process with fake pg

Comment:

I agree that in this simple test the comms can be replaced with a tensor copy_ (which is what FakePG does). To use FakePG for testing, should Pei just inherit TestFakeDistributed?

Contributor Author (zpcore):

IIRC, FakePG doesn't guarantee numerical correctness, so we can't compare DTensor values as in self._test_op_on_dtensor.

Contributor:

So, how do y'all feel about building a testing harness that lets us do DTensor on a single process and single device? One way to do this is to trace out the SPMD graph, including collectives, using fake tensors, and then run it with a custom interpreter that runs each operation on N copies of the tensor (the replicas on each node) and has custom implementations of each collective that are a single device kernel on the N inputs.

Contributor Author (zpcore):

Hah, every word makes sense to me, but the details are pretty hardcore and I'll need to spend time looking into them. This should be really useful. Putting it on my radar!

@XilunWu (Jul 23, 2025):

I like the idea of serializing DTensor tests from @ezyang, but I wonder what role it should play. Shall we keep only the serialized tests in regular CI and move the current multi-threaded, multi-device DTensor tests to a periodic job? I believe this would largely reduce the CI time spent on DTensor tests.

Contributor:

We already have a multi-threaded process group test base. It fails for some DTensor code because of RNG differences across ranks, which is a pretty subtle footgun. If we do the single-process thing, we'd need to figure out how to update the RNG states properly for each simulated rank.

Hold on, I'm missing something. If the proposal is to use this new thing for tests where we don't care about numerics, then I don't see why we need a new thing: we already use FakePG + DTensor successfully in this repo; take a look at the examples folder. If we do want numerics, then I don't see how this proposal can work, because we'd actually have to run the collectives.
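
For context, a minimal sketch of that FakePG + DTensor setup (illustrative, not the repo's actual example code): a single process pretends to be rank 0 of an 8-rank world, and the collectives are no-ops, so only metadata is meaningful, not numerics.

```python
import torch
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor

dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=8)
mesh = DeviceMesh("cpu", torch.arange(8))

# Sharding propagation and placement bookkeeping all work; the "communication"
# behind them is fake, so only shapes/placements are trustworthy.
dt = distribute_tensor(torch.randn(16, 4), mesh, [Shard(0)])
print(dt.placements, dt.to_local().shape)  # (Shard(dim=0),) torch.Size([2, 4])
```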

Contributor:

I can expand on this in chat if it will help, but essentially I think a version of DTensor that is single-process, single-threaded, single-device, and matches numerics would be helpful for testing. For example, we could run DTensor tests on cheap 1-GPU nodes in a single process. (I used to think a multi-threaded process group was a good idea, but I tried it out a bit and it's too easy to deadlock the collectives, so I don't think it's worth the squeeze since it's just for testing.)

How can you do this? Well, you work backwards from the constraints. If you're on one device, then all the data that logically lives on different ranks is actually all on the same device. How do you do collectives? Well, you can't run a real collective, but why do you need a real collective? They're all on the same device. Just do the equivalent regular Tensor operation (for example, allreduce is probably a stack and then sum).
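
To make that concrete, here is a toy single-device "collective" (purely illustrative): every virtual rank's tensor lives in one Python list on the same device, and an all-reduce is just a stack + sum handed back to each rank.

```python
import torch

def fake_all_reduce_sum(per_rank):
    # per_rank: one tensor per virtual rank, all on the same device
    reduced = torch.stack(per_rank).sum(dim=0)
    return [reduced.clone() for _ in per_rank]

shards = [torch.full((2,), float(r)) for r in range(4)]
print(fake_all_reduce_sum(shards)[0])  # tensor([6., 6.]) on every virtual rank
```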

I don't know what exactly the RNG problem is, but if you just need the RNGs to evolve lockstep over each virtual device, this can be done by maintaining N generators and switching into them appropriately as necessary.
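
And a sketch of that lockstep-RNG idea (again illustrative, with arbitrary placeholder seeds): keep one torch.Generator per virtual rank and hand the right one to each simulated rank's random op.

```python
import torch

world_size = 4
# One generator per virtual rank; the seeds here are arbitrary placeholders.
rank_generators = [torch.Generator().manual_seed(1234 + r) for r in range(world_size)]

def randn_on_rank(rank, *shape):
    # Each virtual rank's RNG state evolves independently, as it would on real devices.
    return torch.randn(*shape, generator=rank_generators[rank])
```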

Contributor:

Also, absent this harness, I think doing tests where you run some DTensor ops under a fake PG and make_fx and then do an expect test on the traced operations would work in a pinch.

Contributor:

Yea, those mechanics sound right to me.

@zpcore changed the title from "Support of explicit fallback and batch sharding strategy" to "Support of implicit fallback and batch sharding strategy" Jul 23, 2025
```python
self.enable_implicit_replication: bool = False
self.implicit_strategy_op_tracker: list[torch._ops.OpOverload] = []

def get_op_strategy(
```
Contributor:

I'm trying to think about the simplest way to write this PR, because it seems like it adds a lot of code that might not all be needed.

We already have a well-defined entry point from autoparallel into DTensor sharding prop via this function.

#158046 adds op_strategy_context, which should take care of the machinery of pushing and later clearing the op strategy registration.

Could we have get_op_strategy do something simpler, like:

```python
if op not in torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs:
    op_ctxs.enter_context(op_strategy_context(op, replicate_op_strategy))
    logger.warning(f"implicitly registering {op}")

return torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[op]
```

op_ctxs could be a contextlib.ExitStack that we store globally. I'd rather store it on the AutoParallel instance, but the truth is it would be more correct for it to be a global/singleton, since it matches up with global state inside DTensor.
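
A minimal sketch of that global ExitStack (illustrative; op_strategy_context and replicate_op_strategy are assumed to exist as described above):

```python
import contextlib

# Module-level singleton: mirrors the fact that the registrations it guards
# live in DTensor's global sharding propagator.
_op_ctxs = contextlib.ExitStack()

def clear_implicit_registrations():
    # Pops every op_strategy_context entered so far, restoring DTensor's registry.
    _op_ctxs.close()
```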

Contributor Author (zpcore):

• I can't use op_strategy_context from upstream since it is only available in the DTensor tests. If this function is useful, I can try to move it out of the tests.
• We need to decide whether to use torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs directly or make a copy for AP to use (@Francescaaa seems to prefer a copy in #46 (comment)).

zpcore added a commit that referenced this pull request Jul 25, 2025

(Split out from the large PR #46)

Introduce the batch sharding strategy:
```python
import functools

from torch.distributed.tensor._op_schema import RuntimeSchemaInfo
from autoparallel.dtensor_util.utils import batch_shard_strategy
from autoparallel.dtensor_util import strategy_pool
# create a strategy with input tensor 1 replicated, input tensor 2 sharded on dim 0, and the output tensor sharded on dim 0:
custom_shard_strategy = functools.partial(batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0])
# register the strategy:
strategy_pool.register_op_strategy(new_op)(custom_shard_strategy)
```
For details, check the function descriptions in autoparallel/dtensor_util/utils.py and example usage in tests/test_dtensor.py.

[ghstack-poisoned]
zpcore added a commit that referenced this pull request Jul 25, 2025

(Split out from the large PR #46)

Support the implicit replication fallback strategy.

How to use the implicit replication fallback:
```python
from autoparallel.dtensor_util import strategy_pool
with strategy_pool.replicate_for_unsupported_operators():
    ...  # missing ops will use the replicated strategy if possible
```

Note: StrategyPool reuses the _op_dispatcher.sharding_propagator.op_strategy_funcs/op_to_rules/op_to_schema_info by reference now.

[ghstack-poisoned]
@wconstab (Contributor):

@zpcore we can close this right? (since you split it into smaller PRs)

@zpcore (Contributor Author) commented Jul 28, 2025

> @zpcore we can close this right? (since you split it into smaller PRs)

Yes, let's close this one and use the #49 stack of PRs.

@zpcore closed this Jul 28, 2025
@fmassa deleted the piz/dtensor_util branch July 28, 2025 20:47
fmassa pushed a commit that referenced this pull request Jul 29, 2025

* Support of explicit fallback

* Update on "Support of implicit fallback"

[ghstack-poisoned]