[Distributed][auto-parallel] Sharding Specification and rule discovery #336
Conversation
Fix llama2 and add test for num_heads != num_key_value_heads. --------- Co-authored-by: Allan Lin <allan.lin@centml.ai>
Hi @soodoshll, can I ask about the main purpose of the z3 solver? We already have a bound analyzer as well as an arithmetic simplifier in hidet; is there any limitation of our internal analyzer and simplifier? We need to be cautious when introducing new dependencies, so more justification is needed when adding one.
Yes, we can replace it with the internal analyzer, except for dynamic shapes, where symbols might appear in the boundary expression. I'm not sure we need to support dynamic shapes for auto-sharding, since later, if we use ILP to find the optimal sharding strategy, all shape symbols need to be concretized anyway, and it is difficult to solve an optimization problem involving dynamic shapes. So the whole pipeline would only ever run on concretized shapes.
In short, if auto-partition only happens on concretized shapes, then we do not need Z3. @yaoyaoding @xinli-git, do you think the pipeline above makes sense?
Z3 has been removed.
Makes sense. For different sequence lengths, the best partition strategy might be different. Thus, we do need to optimize for some specific input size and then use the found strategy for all sequence lengths.
Hi @soodoshll, ping me when the PR is ready to be reviewed. Thanks!
Assume `U1` and `U2` are unary elementwise operators and `B` is a binary elementwise operator. This introduces a new fused operator called Composite Elementwise Operation, where `comp_elemwise(x, U1, U2, B) = B(U1(x), U2(x))`, and also adds the corresponding subgraph rewrite rule. This allows for more fusion opportunities.
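As a rough illustration of the semantics (a NumPy sketch, not hidet's actual operator definition; the SiLU-style choice of `U1`, `U2`, and `B` is just an example):

```python
import numpy as np

def comp_elemwise(x, u1, u2, b):
    # B(U1(x), U2(x)): one fused elementwise op instead of three separate ones
    return b(u1(x), u2(x))

x = np.array([-1.0, 0.5, 2.0])
# SiLU-like pattern x * sigmoid(x): U1 = identity, U2 = sigmoid, B = multiply
out = comp_elemwise(x, lambda t: t, lambda t: 1.0 / (1.0 + np.exp(-t)), np.multiply)
print(out)
```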
Hi @yaoyaoding, an update: I found that the engineering effort for the end-to-end partitioning pipeline is not as large as we thought, and it is hard to tell whether a design is good without seeing how it is used by downstream components and what is lacking. So I think I can first build a prototype of the whole auto-partition pipeline and then iteratively improve it. I spent some days building a prototype covering sharding rule discovery, partition scheme search (ILP), communication op injection, weight sharding, launch scripts, etc. (a rough sketch of this flow follows below). I'm working on this branch; the code is still messy and needs to be refactored. It's runnable except for gathering the final outputs from different GPUs. Once it's finished, we can have something like
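To make the stages concrete, here is a hypothetical, heavily stubbed sketch of how they might compose (all names are invented for illustration; this is not hidet's actual API):

```python
# Hypothetical stubs standing in for the real components; names are invented.
def discover_sharding_rules(graph):            # step 1: per-op valid shardings
    return {op: ["replicate", "shard_dim0"] for op in graph["ops"]}

def search_partition_scheme(graph, rules, n):  # step 2: ILP-based search (stubbed)
    return {op: opts[0] for op, opts in rules.items()}

def inject_communication_ops(graph, scheme):   # step 3: add all-reduce/all-gather
    return graph

def shard_weights(graph, scheme, n):           # step 4: split weights per scheme
    return [graph["weights"] for _ in range(n)]

def auto_partition(graph, num_gpus):
    rules = discover_sharding_rules(graph)
    scheme = search_partition_scheme(graph, rules, num_gpus)
    graph = inject_communication_ops(graph, scheme)
    return graph, shard_weights(graph, scheme, num_gpus)

graph = {"ops": ["matmul", "relu"], "weights": [0.0]}
print(auto_partition(graph, num_gpus=2))
```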
…-org#339) This PR clears the intermediate object files generated during tuning, which should reduce the cache size by about two-thirds without hurting any functionality. hidet-org#338
Hi @soodoshll, thanks for the update! After you finish the whole pipeline, you can split it into several PRs, and @xinli-git and I can help review.
I am still going through the code to understand the sharding rules better, but I am posting some comments first.
Feel free to merge without me once Yaoyao thinks it looks good :)
```python
class IndexRewriter(ExprRewriter, ComputeRewriter):
    def __init__(self):
        super().__init__()
```
I believe this is implicitly just calling `ExprRewriter.__init__`?
Maybe a nitpick, but relying on the MRO might be a bit error-prone (if, say, the inheritance order changes in the future, or a different method with the same name is added to the first base class); it might be easier to just explicitly call `Base.__init__(self)`? I guess the same goes for `super().visit`, etc.
Hi @xinli-git, the current writing looks fine. This function does not require any explicit ordering when calling its parent classes' `__init__` methods. The default `super().__init__()` will call them in MRO order, and it works as we expect. Explicitly calling one parent class's `__init__` might skip the other parent classes and is itself error-prone.
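A minimal standalone sketch of the point (toy classes, not hidet's real rewriters): with cooperative `super().__init__()`, every base's `__init__` runs exactly once in MRO order, whereas an explicit `ExprRewriter.__init__(self)` call would silently skip `ComputeRewriter`:

```python
class ExprRewriter:
    def __init__(self):
        print("ExprRewriter.__init__")
        super().__init__()          # continues along the MRO to ComputeRewriter

class ComputeRewriter:
    def __init__(self):
        print("ComputeRewriter.__init__")
        super().__init__()

class IndexRewriter(ExprRewriter, ComputeRewriter):
    def __init__(self):
        super().__init__()          # both parents initialized, in MRO order

IndexRewriter()                     # prints ExprRewriter, then ComputeRewriter
print([c.__name__ for c in IndexRewriter.__mro__])
```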
```diff
@@ -63,3 +63,61 @@ def test_llama2(device, opt):
     print(current_memory_pool("cuda"))
     print(current_memory_pool("cpu"))
     print(current_memory_pool("vcuda"))
```
Is this a GitHub bug? These changes are already merged in https://github.com/hidet-org/hidet/pull/333/files?
Oh, never mind. Can we rebase the auto-parallel branch onto main, so this PR only includes the related changes?
```python
found_rules = []

# We do not allow dynamic shapes
assert len(op.task.symbols) == 0
```
Can we raise `NotImplementedError` instead?
```python
    if len(new_inputs) < len(inputs):
        continue
    outputs = op.reforward(new_inputs)
except (ValueError, RuntimeError, AssertionError, IndexError):
```
How did we arrive at these exceptions? I assume through trial and error?
Yes. For example, if we have an op that adds two 4x4 matrices and we try to shard one of them along the first dimension, it becomes 2x4 + 4x4, which is illegal.
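A toy NumPy illustration of that failure mode (not hidet code): sharding one operand of an elementwise add along dim 0 yields mismatched shapes, raising one of the caught exception types:

```python
import numpy as np

a = np.ones((4, 4))
b = np.ones((4, 4))
a_shard = a[:2, :]        # shard only `a` along the first dimension -> (2, 4)
try:
    _ = a_shard + b       # (2, 4) + (4, 4) cannot broadcast
except ValueError as e:
    print("sharding rejected:", e)
```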
```python
# Only works for 1-D partition
slice_list = []
for i in range(len(tensor.shape)):
    if i != shard_dim:
```
Regarding -1 vs. None: I think the code would be cleaner with None meaning "not sharded", avoiding -1. For example, here we could just check whether `shard_dim is None` and be more explicit.
Yes, especially given that -1 can also stand for the last dimension :( I will switch to None.
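A small illustration of why the sentinel matters (hypothetical snippet, not hidet code): -1 is a legal last-dimension index in Python, so it is an ambiguous "not sharded" marker, while `None` is unambiguous:

```python
shape = (8, 16)

def shard_size(shape, shard_dim, num_shards):
    if shard_dim is None:            # explicit "not sharded" sentinel
        return shape                 # tensor is replicated as-is
    sizes = list(shape)
    sizes[shard_dim] //= num_shards
    return tuple(sizes)

print(shard_size(shape, None, 2))    # (8, 16): replicated
print(shard_size(shape, -1, 2))      # (8, 8): -1 silently means the LAST dim!
```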
It looks good to me now.
Feel free to merge after fixing the existing comments. Thanks @soodoshll !
```python
def __init__(self):
    super().__init__()
    self.var2idx: Dict[Var, Expr] = {}
    self.input_accessed_indices: Dict[TensorInput, List[Expr]] = {}
```
Should this be `Dict[TensorInput, List[List[Expr]]]`?
```python
index_rewriter = IndexRewriter()
for o in self.op.task.outputs:
    if not isinstance(o, GridCompute):
        self.valid = False
        return  # we don't support scalar output
    index_rewriter.visit(o)
```
Suggested change:

```python
index_rewriter = IndexRewriter()
index_rewriter.visit(self.op.task.outputs)
self.valid = all(isinstance(out, GridCompute) for out in self.op.task.outputs)
```
…discovery (#342) Please refer to #336 for the original discussion. I accidentally messed up the branch :( so I re-opened this PR. I have reset upstream/auto-parallel to the latest commit of the main branch. This PR only contains auto-partition-related modifications. --------- Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
This PR tries to implement steps 1) and 3) mentioned in #335 (comment).
It contains an example of automatically discovering the sharding rules of all operators in the ResNet model, which involves a dynamic batch size. You can run it with

```
python -m hidet.distributed.partition.rule
```

For now, we only support discovering 1-D partitions.
The rule discovery process basically enumerates all possible input-sharding specifications and checks their validity. The checking takes two steps:

1. Collect the data dependencies (accesses through `TensorElement`) from the `fcompute` of all outputs, to get the range of input indices that are accessed.
2. Use the Z3 SMT solver to check whether any accessed input index falls outside the shard boundary. If not, the sharding is valid (a minimal sketch of this check follows below).

Note that data dependency alone does not completely guarantee correctness (if the op plays tricks), but I find it can cover most of the ops I have met, though we still need to test it more thoroughly.
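A minimal sketch of the boundary check described above, using the `z3-solver` Python API (an illustration, not hidet's actual code; the access pattern and shard size are made up):

```python
from z3 import And, Int, Not, Solver, unsat

shard_rows = 4                      # each GPU holds rows [0, shard_rows)
i = Int('i')                        # an output index inside the local shard
accessed = i                        # elementwise op: output[i] reads input[i]

s = Solver()
s.add(And(i >= 0, i < shard_rows))                     # i ranges over the shard
s.add(Not(And(accessed >= 0, accessed < shard_rows)))  # ...but the access escapes?
# unsat: no accessed input index can leave the shard, so the sharding is valid
print("valid" if s.check() == unsat else "invalid")
```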
I can add more details about it to the RFC.
cc @yaoyaoding @xinli-git