Support of implicit fallback and batch sharding strategy #46
Conversation
fmassa left a comment
Did an initial pass, this is looking pretty good, thanks!

    def propagate(self, op_info: OpInfo) -> None:
        op_info.output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)

    def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding:
It's a bit hard to see what has been changed in this method vs the original one. Can you maybe list the main changes, so that if the original ShardingPropagator changes we can more easily adapt the changes here?
I added a comment to the class describing the changes. The class is only used in the tests; AP won't use DTensor's _op_dispatcher.sharding_propagator.propagate function.
autoparallel/dtensor_util/utils.py (outdated)

    ] = (
        torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs
    )
    self.op_to_schema_info: dict[
        torch._ops.OpOverload, RuntimeSchemaInfo
    ] = (
        torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_to_schema_info
    )
What do you think about making a copy of those dicts before assigning them to StrategyPool?
This way, we could register our custom strategies directly into StrategyPool as well, instead of handling them separately.
This doesn't need to be done now though, and I'm happy to get this PR merged as is
Yes, I created a copy of the existing DTensor rules/strategies so they are independent. This also reminded me that I had forgotten to copy the rules.
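For illustration, here is a minimal sketch of the "copy instead of alias" idea being discussed; the class layout, attribute names, and the use of shallow dict copies are assumptions for the sketch, not necessarily what the PR ends up implementing:

```python
from torch.distributed.tensor import DTensor

_dtensor_prop = DTensor._op_dispatcher.sharding_propagator


class StrategyPool:
    def __init__(self) -> None:
        # Shallow copies: AP can register or override strategies on the pool
        # without mutating DTensor's global sharding propagator, and later
        # registrations inside DTensor won't silently leak into the pool.
        self.op_strategy_funcs = dict(_dtensor_prop.op_strategy_funcs)
        self.op_to_rules = dict(_dtensor_prop.op_to_rules)
        self.op_to_schema_info = dict(_dtensor_prop.op_to_schema_info)
```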

        return expand_to_full_mesh_op_strategy(mesh, op_schema, single_dim_placement)

    def batch_shard_strategy(
I haven't checked this function's implementation yet.

    class ImplicitRegistrationTest(DTensorTestBase):
        @with_comms
These tests really don't need comms, right? They could be done entirely in a single process with a fake PG.
I agree that in this simple test the comms can be replaced with tensor copy_ (which is what FakePG does). To use FakePG for testing, should Pei just inherit from TestFakeDistributed?
IIRC, FakePG doesn't guarantee numerical correctness, so we can't compare DTensor values the way we do in self._test_op_on_dtensor.
So, how do y'all feel about building a testing harness that lets us do DTensor on a single process and single device? One way to do this is to trace out the SPMD graph, including collectives, using fake tensors, and then run it with a custom interpreter that runs each operation on N copies of the tensor (the replicas on each node) and has custom implementations of each collective that are a single device kernel on the N inputs.
Hah, every word makes sense to me, but the details are pretty hardcore and I'll need to spend time looking into them. This should be really useful. It's on my radar!
I like @ezyang's idea of serializing the DTensor tests, but I wonder about its role. Should we keep only the serialized tests in regular CI and move the current multi-threaded, multi-device DTensor tests to periodic runs? I believe this would significantly reduce CI time for the DTensor tests.
We already have a multi-threaded process group test base. It fails for some DTensor code because of RNG differences across ranks, which is a pretty subtle footgun. If we do the single-process thing, we'd need to figure out how to update the RNG state properly for each simulated rank.
Hold on, I'm missing something. If the proposal is to use this new thing for tests where we don't care about numerics, then I don't see why we need a new thing: we already use FakePG + DTensor successfully in this repo, take a look at the examples folder. If we do want numerics, then I don't see how this proposal can work, because we'd actually have to run the collectives.
I can expand on this in chat if it will help, but essentially, I think a version of DTensor that is single process, single threaded, single device and matches numerics would be helpful for testing. For example, we can run DTensor tests on cheap 1gpu nodes in a single process when this happens. (I used to think multi-threaded process group was a good idea but I did try it out a bit and it's too easy to deadlock the collectives, so I don't think it's worth the squeeze since it's just for testing.)
How can you do this? Well, you work backwards from the constraints. If you're on one device, then all the data that logically lives on different ranks is actually all on the same device. How do you do collectives? Well, you can't run a real collective, but why do you need a real collective? They're all on the same device. Just do the equivalent regular Tensor operation (for example, allreduce is probably a stack and then sum).
I don't know what exactly the RNG problem is, but if you just need the RNGs to evolve lockstep over each virtual device, this can be done by maintaining N generators and switching into them appropriately as necessary.
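To make the single-device idea concrete, here is a small hypothetical sketch (none of these helpers exist in DTensor or this repo): the "world" is a list of per-rank tensors on one device, an allreduce becomes stack + sum, and per-rank RNG is emulated with one generator per virtual rank.

```python
import torch


def local_allreduce(shards: list[torch.Tensor]) -> list[torch.Tensor]:
    # allreduce over N virtual ranks == stack + sum, then hand the same
    # result back to every virtual rank
    total = torch.stack(shards).sum(dim=0)
    return [total.clone() for _ in shards]


class VirtualRankRNG:
    # one generator per virtual rank, switched in as needed so each rank's
    # RNG state evolves just as it would on a real multi-rank run
    def __init__(self, world_size: int, seed: int = 0) -> None:
        self.gens = [torch.Generator().manual_seed(seed + r) for r in range(world_size)]

    def randn(self, rank: int, *shape: int) -> torch.Tensor:
        return torch.randn(*shape, generator=self.gens[rank])


rng = VirtualRankRNG(world_size=4)
shards = [rng.randn(r, 2, 3) for r in range(4)]
reduced = local_allreduce(shards)
assert all(torch.equal(t, reduced[0]) for t in reduced)
```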
ALSO, absent this harness existing, I think doing tests where you run some DTensor ops under fake pg and make_fx and then do an expect test on the operations would work in a pinch.
Yea, those mechanics sound right to me.

    self.enable_implicit_replication: bool = False
    self.implicit_strategy_op_tracker: list[torch._ops.OpOverload] = []

    def get_op_strategy(
i'm trying to think about the simplest way to write this PR, because it seems like it adds a lot of code that might not all be needed.
we already have a well-defined entrypoint from autoparallel into DTensor sharding prop via this function.
#158046 adds op_strategy_context which should take care of the machinery of pushing and later clearing the op strategy registration.
Could we have get_op_strategy do something simpler, like:

    if op not in torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs:
        op_ctxs.enter_context(op_strategy_context(op, replicate_op_strategy))
        logger.warning(f"implicitly registering {op}")
    return torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[op]
op_ctxs could be a contextlib.ExitStack that we store globally. I'd rather store it on the AutoParallel instance, but truth is it would be more correct for it to be a global/singleton since it matches up with global state inside DTensor.
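Spelling that suggestion out as a sketch: op_strategy_context is assumed to come from pytorch/pytorch#158046 and replicate_op_strategy from this PR, so neither is imported here and both are placeholders rather than settled APIs.

```python
import contextlib
import logging

from torch.distributed.tensor import DTensor

logger = logging.getLogger(__name__)

# a global/singleton on purpose: it mirrors the global registration state
# inside DTensor, and op_ctxs.close() undoes every implicit registration at once
op_ctxs = contextlib.ExitStack()


def get_op_strategy(op):
    prop = DTensor._op_dispatcher.sharding_propagator
    if op not in prop.op_strategy_funcs:
        # register the replicate fallback lazily, the first time the op is seen
        op_ctxs.enter_context(op_strategy_context(op, replicate_op_strategy))
        logger.warning("implicitly registering %s", op)
    return prop.op_strategy_funcs[op]
```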
- I can't use `op_strategy_context` from upstream since it is only available in the DTensor tests. If this function is useful, I can try to move it out of the tests.
- We need to decide whether to use `torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs` directly or make a copy for AP to use (@Francescaaa seems to prefer a copy in Support of implicit fallback and batch sharding strategy #46 (comment)).
(Split out from the large PR #46) Introduce the batch sharding strategy:

```python
import functools

from torch.distributed.tensor._op_schema import RuntimeSchemaInfo
from autoparallel.dtensor_util.utils import batch_shard_strategy
from autoparallel.dtensor_util import strategy_pool

# create a strategy with input tensor 1 replicated and input tensor 2 sharded
# on dim 0; the output tensor is sharded on dim 0:
custom_shard_strategy = functools.partial(
    batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0]
)
# register the strategy:
strategy_pool.register_op_strategy(new_op)(custom_shard_strategy)
```

For details, check the function description in autoparallel/dtensor_util/utils.py and the example usage in tests/test_dtensor.py.

[ghstack-poisoned]
(Split out from the large PR #46) Support the implicit replication fallback strategy.

How to use the implicit replication fallback:

```python
from autoparallel.dtensor_util import strategy_pool

with strategy_pool.replicate_for_unsupported_operators():
    ...  # missing ops will use the replicated strategy if possible
```

Note: StrategyPool reuses _op_dispatcher.sharding_propagator.op_strategy_funcs/op_to_rules/op_to_schema_info by reference for now.

[ghstack-poisoned]
@zpcore we can close this right? (since you split it into smaller PRs)
* Support of explicit fallback [ghstack-poisoned]
* Update on "Support of implicit fallback" [ghstack-poisoned]
* Update on "Support of implicit fallback" [ghstack-poisoned]
This PR adds the implicit replication fallback and the batch sharding strategy. These DTensor features are added here for faster iteration, since there is a concern (pytorch/pytorch#158476 (comment)) that they may not be a good fit for DTensor upstream.
Please use `strategy_pool.get_op_strategy(op, op_schema)` instead of `torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[op](op_schema)` to retrieve strategies, as it supports the implicit replication fallback.

New supported strategies:
- implicit replication fallback
- batch sharding strategy
For details, check the function descriptions in autoparallel/dtensor_util/utils.py and the example usage in tests/test_dtensor.py.
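As a minimal usage sketch of that recommendation (assuming an op and a valid OpSchema are already in hand; see tests/test_dtensor.py for complete examples):

```python
from autoparallel.dtensor_util import strategy_pool


def lookup_strategy(op, op_schema):
    # inside this context, ops without a registered strategy fall back to the
    # implicit replicate strategy instead of failing sharding propagation
    with strategy_pool.replicate_for_unsupported_operators():
        return strategy_pool.get_op_strategy(op, op_schema)
```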