Conversation

@zpcore (Contributor) commented Jul 24, 2025

(Split out from the larger PR #46)

Introduce the batch sharding strategy:

import functools

from torch.distributed.tensor._op_schema import RuntimeSchemaInfo
from autoparallel.dtensor_util.utils import batch_shard_strategy
from torch.distributed.tensor._ops.utils import register_op_strategy

# Create a strategy with input tensor 1 replicated and input tensor 2 sharded on dim 0;
# the output tensor is sharded on dim 0:
custom_shard_strategy = functools.partial(batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0])
# Register the strategy:
register_op_strategy(new_op)(custom_shard_strategy)

For details, see the function docstring in autoparallel/dtensor_util/utils.py and the example usage in tests/test_dtensor.py.
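A rough sketch of where the RuntimeSchemaInfo import above comes in (my_custom_op is a placeholder op handle, and the schema_info argument is private DTensor API that may change): an op whose non-tensor args affect sharding can pass a RuntimeSchemaInfo so those args are hashed into the sharding-propagation cache key.

import functools

from torch.distributed.tensor._op_schema import RuntimeSchemaInfo
from torch.distributed.tensor._ops.utils import register_op_strategy

from autoparallel.dtensor_util.utils import batch_shard_strategy

# Sketch only: `my_custom_op` stands in for the op overload being registered.
# RuntimeSchemaInfo(static_argnum=2) marks positional args from index 2 onward as
# non-tensor args that must be part of the sharding-propagation cache key
# (private DTensor API, subject to change).
batched = functools.partial(
    batch_shard_strategy, input_shard_dim=[0, 0], output_shard_dim=[0]
)
register_op_strategy(my_custom_op, schema_info=RuntimeSchemaInfo(static_argnum=2))(batched)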

Stack from ghstack (oldest at bottom):

zpcore added a commit that referenced this pull request Jul 24, 2025
ghstack-source-id: dfdc089
Pull Request resolved: #50
@facebook-github-bot added the CLA Signed label Jul 24, 2025
zpcore added a commit that referenced this pull request Jul 25, 2025
ghstack-source-id: 696aa5e
Pull Request resolved: #50
zpcore added a commit that referenced this pull request Jul 25, 2025
ghstack-source-id: 6753840
Pull Request resolved: #50
@zpcore requested review from XilunWu, fmassa and wconstab July 25, 2025 00:14
zpcore added a commit that referenced this pull request Jul 25, 2025
ghstack-source-id: cc08134
Pull Request resolved: #50
def batch_shard_strategy(
op_schema: OpSchema,
input_shard_dim: list[Optional[int]],
output_shard_dim: list[Optional[int]],
Contributor

The terminology is confusing here. Is the shard dim the batch dim, or is it something else? If it is the batch dim, the comments below seem to imply there is only one batch dim, so why is this a list?

Contributor Author

It's the batch dim. There can be multiple tensor inputs; each element in input_shard_dim maps to one input tensor. The same applies to output_shard_dim and the output tensor(s).
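To make that mapping concrete, a minimal sketch (the op and the dims are hypothetical; one list entry per tensor, None meaning replicate):

import functools

from autoparallel.dtensor_util.utils import batch_shard_strategy

# Hypothetical op with three tensor inputs and two tensor outputs:
# shard input 0 on its dim 0, replicate input 1, shard input 2 on its dim 1;
# shard both outputs on their dim 0.
multi_tensor_strategy = functools.partial(
    batch_shard_strategy,
    input_shard_dim=[0, None, 1],
    output_shard_dim=[0, 0],
)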


# -------------define universal op strategy-------------
def batch_shard_strategy(
op_schema: OpSchema,
Contributor

Is there a reason this takes an OpSchema and not an OpStrategy?

Contributor Author

The type definition of OpSchema is quite unclear in DTensor. Here the OpSchema is a collection of OpStrategy objects, one for each input arg. For the output, most of the time there is just one output tensor, so a single OpStrategy is sufficient.
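For illustration, a minimal sketch of what such a strategy callback receives, assuming the private DTensor fields args_schema and strategies keep their current names:

from torch.distributed.tensor._op_schema import OpSchema, OpStrategy

def inspect_strategy_inputs(op_schema: OpSchema) -> None:
    # During sharding propagation each tensor argument shows up in
    # op_schema.args_schema as an OpStrategy listing its candidate placements;
    # non-tensor args (ints, scalars, ...) are passed through unchanged.
    for i, arg in enumerate(op_schema.args_schema):
        if isinstance(arg, OpStrategy):
            print(f"tensor arg {i}: {len(arg.strategies)} candidate placements")
        else:
            print(f"non-tensor arg {i}: {arg!r}")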

Note: It is the user's responsibility to make sure the sharded tensor for
processing is correct in shape.
"""
output_type = [str(ret.type) for ret in op_schema.op._schema.returns]
Contributor

Running str on the type, very suspicious!

@zpcore (Contributor Author) Jul 29, 2025

What is the recommended way to check the tensor output type?
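One possible alternative, sketched here only for reference and not settled in the thread (it assumes the private torch._C JIT type classes can be relied on), is to compare against TensorType instead of string-matching:

import torch

def returns_are_tensors(op: torch._ops.OpOverload) -> list[bool]:
    # Sketch: True for each return slot whose declared schema type is exactly Tensor.
    # Note: List[Tensor] or Optional[Tensor] returns would need extra handling.
    return [isinstance(ret.type, torch._C.TensorType) for ret in op._schema.returns]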

zpcore added a commit that referenced this pull request Jul 29, 2025
ghstack-source-id: 604f1b6
Pull Request resolved: #50
@zpcore changed the base branch from gh/zpcore/2/base to main July 29, 2025 17:36
@zpcore force-pushed the gh/zpcore/2/head branch 4 times, most recently from 01a8203 to b98c184 July 30, 2025 20:07
@zpcore (Contributor Author) commented Jul 30, 2025

Checking to see if there are any concerns about this PR. Or should we merge it and give it a try?

@zpcore force-pushed the gh/zpcore/2/head branch from b98c184 to bd5f109 July 30, 2025 21:07
@ezyang (Contributor) commented Jul 31, 2025

This is pretty reversible so I don't mind landing it and deciding what to do with it later. @zpcore I do hope this gets obsoleted by whatever we end up deciding to do with DTensor sharding though!

@wconstab (Contributor) left a comment

Let's land this and get some experience using the API for DeepSeek enablement.

@zpcore have you rebased? I'd like to at least kick off a job on llama3 mast to make sure nothing got broken.

@zpcore force-pushed the gh/zpcore/2/head branch from bd5f109 to cc28a95 July 31, 2025 23:00
@zpcore merged commit 385d06e into main Jul 31, 2025
6 checks passed
@zpcore deleted the gh/zpcore/2/head branch July 31, 2025 23:12