Conversation

@wconstab (Contributor):

Moves from ad-hoc, incomplete support for a couple of these ops to supporting all of the standard factory ops, with sharding.

Still needs further work to add memory constraints to encourage sharding.

@wconstab requested a review from fmassa July 19, 2025 14:21
@meta-cla bot added the CLA Signed label Jul 19, 2025
@wconstab force-pushed the whc/factory branch 2 times, most recently from 6870a7f to 304ebc6 on July 19, 2025 14:26
@fmassa left a comment:

Had a first pass. I think it is looking pretty good and close to being merged after it has been rebased.

Comment on lines +413 to +417
# TODO: there shouldn't actually be a row here, since there is no input to the op and the rows correspond
# to the inputs. However, the optimization code is not set up to tolerate input-less ops, so hack around it
# (see "/data/users/whc/autoparallel/autoparallel/optimize_sharding.py", line 226, in walk_over_options)
Contributor:

Yes, I agree with the comment and the current solution you implemented is what I would recommend doing as well.
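The workaround under discussion can be illustrated with a minimal sketch. All names and structures here are hypothetical stand-ins, not the actual autoparallel code: an optimizer-style walk that expects one row of sharding options per input breaks on input-less ops, so the hack fabricates a single dummy row mirroring the output's options.

```python
# Hypothetical sketch of the input-less-op workaround; names are
# illustrative, not the actual autoparallel implementation.

def walk_over_options(op_inputs, rows):
    # Optimizer-style walk: one row of sharding options per input.
    # With zero rows, downstream code that indexes into rows would fail.
    if not rows:
        raise ValueError("optimizer cannot handle input-less ops")
    # Toy "optimization": pick one option per row.
    return [max(row, key=len) for row in rows]

def rows_for_op(op_inputs, output_options):
    # The hack: for a factory op (no inputs), fabricate one row that
    # simply mirrors the output's sharding options, so the walk has
    # something to iterate over.
    if not op_inputs:
        return [output_options]
    return [["Shard(0)", "Replicate"] for _ in op_inputs]

rows = rows_for_op([], ["Shard(0)", "Shard(1)", "Replicate"])
```

With the fabricated row in place, `walk_over_options` runs without special-casing input-less ops.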

Comment on lines 53 to 55
# but index_put op insists on looking at 'input_specs' of its input, which seems absurd.
# so just copy it for now and fix later
strat.input_specs = (strat.output_specs,)
Contributor:

Is this still the case even after using the index_put from PyTorch main, given that we have removed our custom rule in #43?

Contributor Author:

Good point, I'll have to check; I may be able to remove this.

Contributor Author:

I am removing this for now, because it is definitely not needed for the tests/examples landed on main. If we run into it on the DS3 branch, we can revisit a better fix.

assert isinstance(dtype, torch.dtype), dtype

# TODO: ensure the solver knows that it is more expensive to Replicate factory functions than shard
# for now, put replicate last since this might encourage sharding :?
Contributor:

> replicate last since this might encourage sharding

Hmm, I'm not sure we have guarantees w.r.t. that... in any case, let's see.

Contributor Author:

Yeah, I'll make the comment clearer: ordering did appear to change the outcome locally in my experiment, but I agree it's not a stable guarantee and we should deal with this properly. I was mainly happy that I could ensure that sharding of factories was at least happening, so I could test that it worked.
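The ordering trick under discussion can be sketched as follows. This is an illustration only: placements are stand-in strings rather than real DTensor placement types, and the function name is hypothetical. The idea is that when enumerating placement options for a factory op, listing shardings first and Replicate last may nudge an order-sensitive solver toward sharding.

```python
# Illustrative sketch of "put replicate last": enumerate per-mesh-dim
# placement options with Shard(...) choices first and Replicate() last,
# then take the cross-product over mesh dims. Stand-in strings only.
import itertools

def factory_placement_options(tensor_ndim: int, mesh_ndim: int):
    per_mesh_dim = []
    for _ in range(mesh_ndim):
        # Shard over each tensor dim first, Replicate last.
        options = [f"Shard({d})" for d in range(tensor_ndim)]
        options.append("Replicate()")
        per_mesh_dim.append(options)
    # Cross-product over mesh dims; the fully-replicated option lands last.
    return list(itertools.product(*per_mesh_dim))

opts = factory_placement_options(tensor_ndim=2, mesh_ndim=1)
```

As the discussion notes, this ordering is not a stable guarantee of the solver's choice; it only biases which option is encountered first.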



@register_opschema_rule(TENSOR_FACTORY_OPS)
def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
Contributor:

Maybe we could replace / unify this implementation with the implementation in _create_all_options https://github.com/pytorch-labs/autoparallel/blob/b53ad103b2054177db1c0ac50d0b0021a5b8bb57/autoparallel/propagation_rules.py#L119-L159 ?

Can be left for the future, just bringing this up as I think they are quite similar

Contributor Author:

I wasn't aware of these; agreed, we can refactor to remove the duplication. I can do that in a follow-up PR. A few questions:

  1. What is "_create_all_options_no_nested_sharding" intended for? It is currently unused, so I might delete it unless you have a use case in mind.
  2. "_create_all_options" is pretty close to "factory_rule". I can probably make factory_rule delegate to _create_all_options.
  3. "_create_all_options" is a bit vague, and the TODO suggests adding Partial, but factories would not want Partial. I'll think a bit more about how to factor the function.

Contributor:

About your points:

> what is "_create_all_options_no_nested_sharding" intended for? It is currently unused so i might delete it unless you have a use case in mind

I think we can remove it. It was the first implementation that I did because I was learning DTensor on the go and was using DTensorSpec.from_dim_map to generate the specs. But then I realized that it couldn't generate nested shardings like S(0)S(0), so I wrote the second function. I didn't know back then if this would be useful or not so I kept it just in case, but I think we can remove it now

> "_create_all_options" is pretty close to 'factory rule'. I can probably make factory_rule delegate to _create_all_options.

Yes, I think it would be a good thing to unify both

> _create_all_options is a bit vague and the todo suggests adding partial. but factories would not want partial. I'll think a bit more about how to factor the function

That comment is legacy from _create_all_options_no_nested_sharding which went into _create_all_options -- I don't think we want Partial support anymore, so it's basically the same thing as factory_rule
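The limitation mentioned above, that `DTensorSpec.from_dim_map` cannot generate nested shardings like S(0)S(0), can be illustrated with a small sketch. This is a stand-in representation, not the actual DTensor implementation: a dim_map has one entry per tensor dim holding the mesh dim that shards it (or -1 for none), so two mesh dims sharding the same tensor dim cannot be encoded in a single slot.

```python
# Sketch of why a dim_map cannot express nested sharding like S(0)S(0).
# dim_map[tensor_dim] = mesh dim sharding that tensor dim, or -1.
# Stand-in strings, not real DTensor placement types.

def placements_from_dim_map(dim_map, mesh_ndim):
    # Invert the tensor-dim -> mesh-dim mapping into one placement
    # per mesh dim; unmapped mesh dims stay Replicate.
    placements = ["Replicate()"] * mesh_ndim
    for tensor_dim, mesh_dim in enumerate(dim_map):
        if mesh_dim >= 0:
            placements[mesh_dim] = f"Shard({tensor_dim})"
    return placements

# [0, -1]: tensor dim 0 sharded on mesh dim 0, tensor dim 1 unsharded.
print(placements_from_dim_map([0, -1], mesh_ndim=2))

# Nested sharding (Shard(0), Shard(0)) would need tensor dim 0 mapped
# to BOTH mesh dims, which one integer per tensor dim cannot encode;
# hence the second function that enumerates placements directly.
```

This is consistent with the history described above: the dim_map-based enumeration was replaced by direct placement enumeration precisely to cover nested shardings.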

from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs, TENSOR_FACTORY_OPS


def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
Contributor:

Nit: as follow-up work, it might be good to disable these functions to see if we are still missing some cases in DTensor sharding propagation.

Contributor Author:

Disable which, 'propagate_tensor_meta'? I can try that in another PR.

@fmassa left a comment:

LGTM, thanks!

Let's think about unifying _create_all_options in another PR

@fmassa merged commit 6be7804 into main Jul 25, 2025
5 checks passed
@fmassa deleted the whc/factory branch July 25, 2025 08:39