Add factory_strategy to support empty, full, ... #44
Conversation
6870a7f to 304ebc6
fmassa left a comment
Had a first pass; I think it is looking pretty good and close to being merged after it has been rebased.
# TODO: there shouldn't actually be a row here, since there is no input to the op and the rows correspond
# to the inputs. However, the optimization code is not set up to tolerate input-less ops, so hack around it
# (see "/data/users/whc/autoparallel/autoparallel/optimize_sharding.py", line 226, in walk_over_options)
Yes, I agree with the comment and the current solution you implemented is what I would recommend doing as well.
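The workaround under discussion can be sketched in isolation with a toy model (the function name `rows_for_op` and the placement strings are invented here for illustration; this is not the real optimize_sharding code): the options walker expects at least one row per op, so an input-less factory op gets a single placeholder row.

```python
# Toy model of the hack: the solver walks one row of sharding options per op
# input, so an op with zero inputs would contribute zero rows and break the
# walker. Emit one dummy row for such ops instead.
def rows_for_op(input_options):
    if not input_options:  # input-less op (e.g. a tensor factory)
        return [["R"]]     # placeholder row so the walker has something to visit
    return input_options

print(rows_for_op([]))               # [['R']]  -- dummy row for a factory op
print(rows_for_op([["S(0)", "R"]]))  # unchanged for ops with real inputs
```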
autoparallel/utils.py
Outdated
# but index_put op insists on looking at 'input_specs' of its input, which seems absurd.
# so just copy it for now and fix later
strat.input_specs = (strat.output_specs,)
Is this still the case even after using the index_put from PyTorch main, given that we have removed our custom rule in #43?
Good point, I'll have to check; I may be able to remove this.
I am removing this for now, because it is definitely not needed for the tests/examples landed on main. If we run into it on the DS3 branch we can revisit a better fix.
autoparallel/propagation_rules.py
Outdated
assert isinstance(dtype, torch.dtype), dtype

# TODO: ensure the solver knows that it is more expensive to Replicate factory functions than shard
# for now, put replicate last since this might encourage sharding :?
> replicate last since this might encourage sharding

Hmm, I'm not sure we have guarantees w.r.t. that... in any case, let's see.
Yeah, I'll make the comment clearer: ordering did appear to change the outcome locally in my experiment, but I agree it's not a stable guarantee and we should deal with this in a proper way. I was mainly glad that I could ensure that sharding of factories was at least happening, so I could test that it worked.
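The tie-breaking behavior being relied on can be shown with a toy sketch (this is not autoparallel's actual solver; all names here are made up for illustration): a solver that picks the minimum-cost option with a stable `min()` returns the earliest option on a tie, so listing Replicate last nudges ties toward sharding.

```python
# Toy illustration, not autoparallel's real solver: when costs tie,
# Python's min() returns the earliest option, so list order encodes
# a tie-breaking preference.
def pick_strategy(options, cost):
    return min(options, key=cost)

options = ["Shard(0)", "Shard(1)", "Replicate"]  # Replicate deliberately last
flat_cost = lambda _: 0  # all options tie on cost

print(pick_strategy(options, flat_cost))  # prints "Shard(0)": ties go to the first entry
```

This is exactly why the ordering is fragile: any change to the solver's tie-breaking (or a non-zero cost model) silently changes the outcome.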
@register_opschema_rule(TENSOR_FACTORY_OPS)
def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
Maybe we could replace / unify this implementation with the implementation in _create_all_options https://github.com/pytorch-labs/autoparallel/blob/b53ad103b2054177db1c0ac50d0b0021a5b8bb57/autoparallel/propagation_rules.py#L119-L159 ?
Can be left for the future, just bringing this up as I think they are quite similar
I wasn't aware of these; agreed, we can refactor to remove the duplication. I can do that in a follow-up PR. A few questions:
- What is "_create_all_options_no_nested_sharding" intended for? It is currently unused, so I might delete it unless you have a use case in mind.
- "_create_all_options" is pretty close to factory_rule. I can probably make factory_rule delegate to _create_all_options.
- _create_all_options is a bit vague, and the TODO suggests adding Partial, but factories would not want Partial. I'll think a bit more about how to factor the function.
About your points:

> what is "_create_all_options_no_nested_sharding" intended for?

I think we can remove it. It was the first implementation I wrote, back when I was learning DTensor on the go and was using DTensorSpec.from_dim_map to generate the specs. But then I realized it couldn't generate nested shardings like S(0)S(0), so I wrote the second function. I didn't know at the time whether it would be useful, so I kept it just in case, but I think we can remove it now.

> "_create_all_options" is pretty close to 'factory rule'.

Yes, I think it would be a good thing to unify both.

> _create_all_options is a bit vague and the todo suggests adding partial.

That comment is legacy from _create_all_options_no_nested_sharding, which went into _create_all_options. I don't think we want Partial support anymore, so it's basically the same thing as factory_rule.
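For reference, the proposed unification could look roughly like the following pure-Python sketch (the helper name, signature, and placement strings are assumptions for illustration, not the real autoparallel or DTensor APIs): enumerate every per-mesh-dim choice of Shard(d) or Replicate, with Partial opt-in and disabled for factories.

```python
from itertools import product

# Hypothetical sketch of factory_rule delegating to a _create_all_options-style
# helper. Placements are modeled as strings: "S(d)" = Shard(d), "R" = Replicate,
# "P" = Partial. Real code would build DTensorSpec/OpStrategy objects instead.
def create_all_options(mesh_ndim, tensor_ndim, include_partial=False):
    per_dim = [f"S({d})" for d in range(tensor_ndim)] + ["R"]
    if include_partial:
        per_dim.append("P")
    # itertools.product naturally yields nested shardings like (S(0), S(0))
    return list(product(per_dim, repeat=mesh_ndim))

def factory_rule(mesh_ndim, tensor_ndim):
    # Factory ops have no inputs to reduce over, so Partial is never wanted.
    return create_all_options(mesh_ndim, tensor_ndim, include_partial=False)

opts = factory_rule(mesh_ndim=2, tensor_ndim=2)
# 3 choices per mesh dim (S(0), S(1), R) on a 2-D mesh -> 9 combinations
```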
from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs, TENSOR_FACTORY_OPS


def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
Nit: as follow-up work, it might be good to disable these functions to see if we are still missing some cases in DTensor sharding propagation.
Disable which, 'propagate_tensor_meta'? I can try that in another PR.
Moves from ad-hoc and incomplete support for a couple of these ops to supporting all of the standard factory ops with sharding support. Still needs further work to add memory constraints to encourage sharding.
fmassa left a comment
LGTM, thanks!
Let's think about unifying _create_all_options in another PR