introduce batch sharding strategy #50
Conversation
```python
def batch_shard_strategy(
    op_schema: OpSchema,
    input_shard_dim: list[Optional[int]],
    output_shard_dim: list[Optional[int]],
```
The terminology is confusing here. Is the shard dim the batch dim, or is it something else? If it is the batch dim, the comments below seem to imply there is only one batch dim, so why is this a list?
It's the batch dim. There can be multiple input tensors, and each element in input_shard_dim maps to one input tensor; the same applies to the output tensor(s).
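To make the mapping concrete, here is a small illustration of what the two dim lists could look like; the op shape and the specific dims below are made up for illustration and are not taken from the PR.

```python
# Illustrative only: a hypothetical binary op with two input tensors that are
# both batched on dim 0, producing a single output also batched on dim 0.
# One entry per tensor; a None entry would mean "no batch dim to shard" for
# that tensor.
input_shard_dim = [0, 0]   # batch dim of each input tensor
output_shard_dim = [0]     # batch dim of the single output tensor
```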
```python
# -------------define universal op strategy-------------
def batch_shard_strategy(
    op_schema: OpSchema,
```
Is there a reason this takes an OpSchema and not an OpStrategy?
The type definition of OpSchema is very unclear in DTensor. Here the OpSchema is a collection of OpStrategy objects, one per input arg, while for the output we usually have just one output tensor, so a single OpStrategy is sufficient.
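For context, a minimal sketch of what "a collection of OpStrategy" means here, assuming the DTensor internals I recall (an `OpSchema.args_schema` field that holds an `OpStrategy` per tensor arg during strategy-based propagation); the import path is an assumption about upstream PyTorch, not code from this PR.

```python
from torch.distributed.tensor._op_schema import OpSchema, OpStrategy  # assumed path

def input_strategies(op_schema: OpSchema) -> list[OpStrategy]:
    # Tensor args appear as OpStrategy entries in args_schema during
    # strategy-based sharding propagation; non-tensor args are skipped here.
    return [arg for arg in op_schema.args_schema if isinstance(arg, OpStrategy)]
```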
```python
    Note: It is the user's responsibility to make sure the sharded tensor for
    processing is correct in shape.
    """
    output_type = [str(ret.type) for ret in op_schema.op._schema.returns]
```
Running str on the type is very suspicious!
What is the recommended way to check the tensor output type?
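One possible alternative, sketched under the assumption that the torch._C JIT type helpers (`TensorType.get()` and `Type.isSubtypeOf`) behave as I recall; this is not an endorsed DTensor API, just a way to avoid string matching.

```python
import torch

def returns_are_tensors(op: torch._ops.OpOverload) -> list[bool]:
    # One flag per declared return: True if the schema says it is a Tensor.
    # Comparing against the JIT type object avoids matching on str(ret.type).
    tensor_type = torch._C.TensorType.get()
    return [ret.type.isSubtypeOf(tensor_type) for ret in op._schema.returns]
```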
Force-pushed from 01a8203 to b98c184
Check to see if there are any concerns for this PR. Or should we merge and give it a try?
This is pretty reversible so I don't mind landing it and deciding what to do with it later. @zpcore I do hope this gets obsoleted by whatever we end up deciding to do with DTensor sharding though!
wconstab left a comment:
Let's land this and get some experience using the API for DeepSeek enablement.
@zpcore have you rebased? I'd like to at least kick off a job on llama3 mast to make sure nothing got broken.
(Split out of the large PR #46.)
Introduce the batch sharding strategy.
For details, check the function description in autoparallel/dtensor_util/utils.py and the example usage in tests/test_dtensor.py; a rough usage sketch follows below.
Stack from ghstack (oldest at bottom):
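As a rough picture of how the strategy is meant to be plugged in (the real example lives in tests/test_dtensor.py), one could bind the batch dims with functools.partial; the dims and the idea of registering the resulting function for a specific op are assumptions for illustration, not the PR's actual test code.

```python
from functools import partial

# Hypothetical: shard both inputs and the single output of some custom op
# along batch dim 0. The actual registration mechanism is shown in
# tests/test_dtensor.py and is not reproduced here.
my_batch_strategy = partial(
    batch_shard_strategy,
    input_shard_dim=[0, 0],
    output_shard_dim=[0],
)
```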