Support of implicit fallback #49
Conversation
[ghstack-poisoned]
Had an offline discussion regarding #46 (comment), since ...
(Split out of the large PR #46.)

Support the implicit replication fallback strategy.

How to use the implicit replication fallback:

```python
from autoparallel.dtensor_util import strategy_pool

with strategy_pool.replicate_for_unsupported_operators():
    ...  # missing ops will use the replicated strategy if possible
```

Note: StrategyPool now reuses the _op_dispatcher.sharding_propagator.op_strategy_funcs/op_to_rules/op_to_schema_info by reference.

[ghstack-poisoned]
autoparallel/dtensor_util/utils.py
Outdated
replicate_op_strategy = torch.distributed.tensor._ops.utils.replicate_op_strategy
...
class StrategyPool:
My question would be: if we have the context manager above, do we actually need a StrategyPool class that maintains copies of the dtensor registries? We should probably pick one approach or the other. If we use the context manager, then a way to keep track of it here could be to use an ExitStack, as I mentioned in #46.
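For reference, a rough sketch of the ExitStack idea; the helper names (`enable_replicate_fallback`, `disable_replicate_fallback`) and the `register`/`unregister` parameters are illustrative, not part of this PR:

```python
import contextlib

# Keep one ExitStack alive and push each temporary registration's cleanup onto
# it, so a single close() restores DTensor's registry to its original state.
_fallback_stack = contextlib.ExitStack()


def enable_replicate_fallback(op, register, unregister):
    """Register `op` with the replicate strategy and remember how to undo it.

    `register` / `unregister` stand in for the real registration helpers; they
    are parameters here only to keep the sketch self-contained.
    """
    register(op)
    _fallback_stack.callback(unregister, op)


def disable_replicate_fallback():
    # Pops every recorded callback, unregistering all temporary strategies.
    _fallback_stack.close()
```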
Good point! I removed the StrategyPool; the structure is simpler now.
tests/test_dtensor.py
Outdated
)
...
@contextmanager
If we have this in the above utility file, we can delete it from here, right?
Yes, if we can upstream the with_implicit_strategies helper.
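For context, a minimal sketch of what a `with_implicit_strategies` context manager could look like, assembled only from the pieces visible in this thread (`register_op_strategy`, `replicate_op_strategy`, and the sharding propagator's `op_strategy_funcs` registry); the actual helper in this PR appears to handle missing ops at dispatch time rather than taking an explicit op list, so treat this purely as an illustration of the temporary-registration idea:

```python
from contextlib import contextmanager

from torch.distributed.tensor import DTensor
from torch.distributed.tensor._ops.utils import (
    register_op_strategy,
    replicate_op_strategy,
)


@contextmanager
def with_implicit_strategies(ops):
    """Temporarily register the replicate strategy for any `ops` that lack one."""
    propagator = DTensor._op_dispatcher.sharding_propagator
    added = []
    for op in ops:
        if op not in propagator.op_strategy_funcs:
            register_op_strategy(op)(replicate_op_strategy)
            added.append(op)
    try:
        yield
    finally:
        # Undo only what we added; a full implementation would also need to
        # invalidate the propagator's sharding cache and any schema info.
        for op in added:
            propagator.op_strategy_funcs.pop(op, None)
```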
autoparallel/dtensor_util/utils.py
Outdated
)
else:
    # No stack available, just register permanently
    register_op_strategy(op)(replicate_op_strategy)
I'm confused. Won't this register the op into DTensor itself? But above we check whether the op is registered in our COPY of DTensor's registry, and I don't see us updating that copy. Should we just delete our copy and rely on this approach?
self.op_strategy_funcs in StrategyPool is a reference to the upstream op_strategy_funcs, not a copy. Let me remove the reference and use the upstream op_strategy_funcs directly to make this clear.
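A tiny illustration of the reference-vs-copy distinction being discussed here (nothing PR-specific, just dict aliasing on the propagator's registry; assumes a PyTorch version where `DTensor` is importable from `torch.distributed.tensor`):

```python
from torch.distributed.tensor import DTensor

propagator = DTensor._op_dispatcher.sharding_propagator

alias = propagator.op_strategy_funcs           # same dict object: new registrations show up immediately
snapshot = dict(propagator.op_strategy_funcs)  # a real copy would go stale as soon as anything registers

assert alias is propagator.op_strategy_funcs
assert snapshot is not propagator.op_strategy_funcs
```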
](
    op_schema
)
out_strat = get_op_strategy(op, op_schema)
Do this in its own refactor and land it asap?
Should I merge this PR first so that we can quickly play with the batch sharding strategy?
# replication strategy fallback.
class CustomShardingPropagator(
    torch.distributed.tensor._sharding_prop.ShardingPropagator
):
I'm generally down on out-of-core things like this that are very closely entwined with internal implementation details of another library we rely on: it's unlikely these APIs have any test coverage in pytorch, which means we're more likely to accidentally break autoparallel through otherwise safe refactoring changes. I haven't thought closely enough about what a good architecture looks like, but our default should be to make autoparallel rely only on public APIs and move anything that needs close coordination into pytorch core. (I'm OK with landing stuff in autoparallel on a temporary basis with a clear understanding that it needs to go into core.)
Hah, this sounds like we need a test for the tests. This is used to help run quick checks on strategy correctness in eager mode.
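For what it's worth, a hedged sketch of such an eager-mode check that avoids subclassing the propagator, assuming a single-rank gloo group and the `strategy_pool.replicate_for_unsupported_operators()` entry point from the PR description (the name may have changed after the StrategyPool removal, and `relu` is only a stand-in for an op without a registered strategy):

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor

from autoparallel.dtensor_util import strategy_pool  # hypothetical import path


def main():
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    mesh = init_device_mesh("cpu", (1,))

    x = distribute_tensor(torch.randn(4, 4), mesh, [Replicate()])
    with strategy_pool.replicate_for_unsupported_operators():
        # An op lacking a registered sharding strategy should fall back to
        # replication here instead of raising.
        y = torch.relu(x)
    assert y.placements == (Replicate(),)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```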
Damn, I merged a ghstack PR...