
[Distributed][auto-parallel] Sharding Specification and rule discovery #336

Merged: 54 commits into hidet-org:auto-parallel, Aug 4, 2023

Conversation

soodoshll (Collaborator)

This PR implements steps 1) and 3) mentioned in #335 (comment)

It contains an example of automatically discovering the sharding rules of all operators in a ResNet model, which involves a dynamic batch size. You can run it with

python -m hidet.distributed.partition.rule

For now we only support discovering 1-D partitions.

The rule discovery process basically enumerates all possible input-sharding specifications and checks their validity. The check takes two steps:

  1. Shape check: generate a sharded version of the input tensors and re-forward the op to get new output tensors. Then try to align the new outputs with the original whole output tensors, in order to find which output dimensions are sharded. For example, if the input $A$ is an 8x8 matrix, the op is $\sin$, and we want to shard the first dimension of $A$ into 2 partitions, then it generates an input shard of shape 4x8 and gets an output of shape 4x8, so we know that the first dimension of the output is also sharded.
  2. Data dependency check: a valid shape check does not guarantee correctness. For example, if we have an 8x8 matrix $A$ and the op is $softmax$, then sharding along the second dimension gives the correct shape but wrong results, because the denominator of $softmax$ depends on the complete second dimension. So we analyze the range of data each output depends on; the sharding is valid only if the output shard depends only on the corresponding input shard. For example, any shard of $A+1$ depends only on the same shard of $A$, so it is valid. We trace the indexing expressions (TensorElement) from the fcompute of all outputs to get the range of input indices that are accessed, and then use the Z3 SMT solver to check whether any accessed input index falls outside the shard boundary. If not, the sharding is valid (a small sketch of this check follows the list).
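
For step 2, a minimal standalone sketch of the boundary check, assuming softmax along axis 1 of an 8x8 input. The PR derives the accessed-index expressions from fcompute; this sketch hard-codes them, so it only shows the shape of the Z3 query (the function name and constraints are illustrative, not code from this PR):

from z3 import Int, Solver, Or, sat

def row_softmax_shard_valid(shard_dim, shard_size, full_size=8):
    # Input and output are sharded on the same dimension; reason about the
    # first shard without loss of generality.
    i, j = Int('i'), Int('j')   # indices of one element of the output shard
    k = Int('k')                # column index of an input element it reads

    out_hi0 = shard_size if shard_dim == 0 else full_size
    out_hi1 = shard_size if shard_dim == 1 else full_size
    in_hi0 = shard_size if shard_dim == 0 else full_size
    in_hi1 = shard_size if shard_dim == 1 else full_size

    s = Solver()
    s.add(i >= 0, i < out_hi0, j >= 0, j < out_hi1)   # output element lies inside the shard
    s.add(k >= 0, k < full_size)                      # softmax(axis=1) reads input[i, k] for every k
    # Satisfiable <=> some accessed input element (i, k) falls outside the input shard.
    s.add(Or(i >= in_hi0, k >= in_hi1))
    return s.check() != sat

print(row_softmax_shard_valid(shard_dim=0, shard_size=4))  # True: sharding rows is valid
print(row_softmax_shard_valid(shard_dim=1, shard_size=4))  # False: sharding columns is invalid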

Note that even the data dependency check does not completely guarantee correctness (if the op plays tricks), but I find it covers most of the ops I have met, though we still need to test it more thoroughly.

I can add more details about it to the RFC.

cc @yaoyaoding @xinli-git

yaoyaoding and others added 13 commits July 28, 2023 21:00
Fix llama2 and add test for num_heads != num_key_value_heads.

---------

Co-authored-by: Allan Lin <allan.lin@centml.ai>
@yaoyaoding (Member)

Hi @soodoshll,

Can I know the main purpose of the Z3 solver?

We also have a bound analyzer as well as an arithmetic simplifier in hidet; is there any limitation of our internal analyzer and simplifier?

We need to be cautious when introducing new dependencies, so more justification is needed for adding a new one.

@soodoshll (Collaborator, Author) commented Jul 31, 2023

Yes, we can replace it with the internal analyzer, except for dynamic shapes, where symbols might appear in the boundary expression.

I'm not sure whether we need to support dynamic shapes for auto-sharding, since later, if we use ILP to find the optimal sharding strategy, all shape symbols need to be concretized anyway. It is difficult to solve an optimization problem involving dynamic shapes.

Therefore, if our ultimate goal is a function like hidet.distributed.partition(graph), the shape symbols need to be concretized at some point, and we partition this specially concretized version. Since the ops of the concretized and dynamic graphs are paired one-to-one, we can project the partitioning found on the concretized graph back to the dynamic-shaped graph, which is runnable but might have suboptimal performance.

So the whole pipeline might be like:

build flowgraph(dynamic) ---> concretization (concrete) ---> auto parallelization --> partition plan (concrete)
          |                                                                                      |
          +<-------------------------------------------------------------------------------------+
          | (projecting)
          v
partitioned graph(dynamic) -> optimize/compile

In short, if auto-partitioning only happens for concretized shapes, then we do not need Z3.
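
For concreteness, a hypothetical sketch of that pipeline; none of these functions are existing hidet APIs (concretize, auto_partition, and project_plan are stand-ins for the stages in the diagram above):

def concretize(dynamic_graph, symbol_values):
    """Bind shape symbols (e.g. batch size) to concrete values."""
    raise NotImplementedError

def auto_partition(concrete_graph, num_gpus):
    """Rule discovery + ILP search on the concrete graph; returns a per-op plan."""
    raise NotImplementedError

def project_plan(dynamic_graph, plan):
    """Ops are paired one-to-one, so the concrete plan can be re-applied to the dynamic graph."""
    raise NotImplementedError

# concrete = concretize(graph, {"batch_size": 16})
# plan = auto_partition(concrete, num_gpus=2)
# partitioned = project_plan(graph, plan)   # runnable, possibly suboptimal for other sizes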

@yaoyaoding @xinli-git Do you think the pipeline above makes sense?

@soodoshll (Collaborator, Author)

z3 has been removed

@yaoyaoding (Member)

Makes sense. For different sequence lengths, the best partition strategy might be different. Thus, we do need to optimize for some specific input size and then use the found strategy for all sequence lengths.

@yaoyaoding (Member)

Hi @soodoshll, ping me when the PR is ready to be reviewed. Thanks!

soodoshll and others added 2 commits August 1, 2023 00:55
Assuming `U1` and `U2` are unary elementwise operators and `B` is a binary elementwise operator, introduce a new fused operator called Composite Elementwise Operation, where `comp_elemwise(x, U1, U2, B) = B(U1(x), U2(x))`. Also add a subgraph rewrite rule, which allows for more fusion opportunities.
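
As a quick illustration of the pattern from the commit message above, a minimal numpy sketch (the function name and the SiLU example are illustrative, not hidet code): SiLU, x * sigmoid(x), matches the pattern with U1 = identity, U2 = sigmoid, B = multiply.

import numpy as np

def comp_elemwise(x, u1, u2, b):
    # B(U1(x), U2(x)): one fused elementwise op instead of three separate ones
    return b(u1(x), u2(x))

x = np.linspace(-2.0, 2.0, 5)
silu = comp_elemwise(x, lambda t: t, lambda t: 1.0 / (1.0 + np.exp(-t)), np.multiply)
print(silu)
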
@soodoshll (Collaborator, Author) commented Aug 1, 2023

Hi @yaoyaoding, update: I found that the engineering effort for the end-to-end partitioning pipeline is not as much as we thought, and it is hard to tell whether a design is good without seeing how it is used by downstream components and what is lacking. So I think I can first build a prototype of the whole auto-partition pipeline and then iteratively improve it.

I spent some days building a prototype covering sharding rule discovery, partition scheme search (ILP), communication op injection, weight sharding, the launching script, etc.

I'm working on this branch; the code is still messy and needs to be refactored. It's runnable except for gathering the final outputs from different GPUs.

https://github.com/soodoshll/hidet/blob/auto-parallel-connect/examples/distributed/resnet.py

Once it's finished, we can have something like `python -m hidet.distributed.launch [num_gpus] resnet.py [out_dir]` to auto-partition the graph and run inference.

…-org#339)

This PR clears the intermediate object files generated during tuning.

This should reduce the cache size by 2/3 without hurting any functionality.

hidet-org#338
@yaoyaoding (Member)

Hi @soodoshll, thanks for the update!

After you finish the whole pipeline, you can split it into several PRs and @xinli-git and I can help to review.

@xinli-git (Collaborator) left a comment

I am still going through the code to understand the sharding rules better but posting some comments first.

Feel free to merge without me once Yaoyao thinks it looks good :)


class IndexRewriter(ExprRewriter, ComputeRewriter):
    def __init__(self):
        super().__init__()
Collaborator:

I believe this is implicitly just calling `ExprRewriter.__init__`?

Maybe a nitpick, but relying on the MRO might be a bit error-prone (if, say, the inheritance order changes in the future, or a different method with the same name is added to the first base class); it might be easier to just explicitly call `Base.__init__(self)`? I guess the same goes for `super().visit`, etc.

Member:

Hi @xinli-git, the current writing is fine. This function does not require any explicit ordering when calling its parent classes' `__init__` methods. The default `super().__init__()` will call them in MRO order, which works as we expect. Explicitly calling one parent class's `__init__` might skip the other parent classes and is error-prone (see the small sketch below).
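
A minimal standalone sketch of that point (illustrative class names, not hidet's actual rewriter classes): with multiple inheritance, the cooperative super().__init__() initializes every base class in MRO order, while an explicit Base.__init__(self) would silently skip the other parent.

class ExprRewriterLike:
    def __init__(self):
        super().__init__()
        self.memo = {}           # state owned by the first base

class ComputeRewriterLike:
    def __init__(self):
        super().__init__()
        self.compute_cache = {}  # state owned by the second base

class IndexRewriterLike(ExprRewriterLike, ComputeRewriterLike):
    def __init__(self):
        super().__init__()       # initializes both bases, in MRO order

r = IndexRewriterLike()
print(hasattr(r, 'memo'), hasattr(r, 'compute_cache'))  # True True

# By contrast, calling ExprRewriterLike.__init__(r) alone would never run
# ComputeRewriterLike.__init__, leaving compute_cache undefined.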

@@ -63,3 +63,61 @@ def test_llama2(device, opt):
    print(current_memory_pool("cuda"))
    print(current_memory_pool("cpu"))
    print(current_memory_pool("vcuda"))

Collaborator:

Is this a GitHub bug? These changes are already merged in https://github.com/hidet-org/hidet/pull/333/files?

Collaborator:

Oh, never mind. Can we rebase auto-parallel onto main, so this PR only includes the related changes?

found_rules = []

# We do not allow dynamic shapes
assert len(op.task.symbols) == 0
Collaborator:

Can we raise NotImplementedError instead?

python/hidet/distributed/partition/rule.py (outdated comment, resolved)
    if len(new_inputs) < len(inputs):
        continue
    outputs = op.reforward(new_inputs)
except (ValueError, RuntimeError, AssertionError, IndexError):
Collaborator:

How do we arrive at these exceptions? I assume through trial and error?

Collaborator (Author):

Yes. For example, if we have an op that adds up two 4x4 matrices, and we try to shard one of them along the first dimension, then it becomes 2x4 + 4x4, which is illegal.
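
A tiny numpy illustration of the kind of failure being caught (not hidet's actual reforward path): sharding only one operand of an elementwise add gives mismatched shapes, and the raised exception tells rule discovery that this sharding is invalid.

import numpy as np

a_shard = np.zeros((2, 4))  # first operand sharded along dim 0
b = np.zeros((4, 4))        # second operand left whole
try:
    _ = a_shard + b
except ValueError as e:
    print("invalid sharding:", e)  # operands could not be broadcast together ...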

# Only works for 1-D partition
slice_list = []
for i in range(len(tensor.shape)):
    if i != shard_dim:
Collaborator:

Regarding -1 vs None, I think the code would be cleaner with None meaning "not sharded", avoiding -1. For example, here we can just check whether shard_dim is None and be more explicit.

Collaborator (Author):

Yes, especially given that -1 can stand for the last dimension :( I will switch to None (see the small sketch below).
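
A small illustrative helper (not the PR's actual code) showing the None sentinel discussed here: None means "not sharded", avoiding the ambiguity of -1, which Python indexing would read as the last dimension.

def shard_slices(shape, shard_dim, num_shards, rank):
    slices = []
    for i, extent in enumerate(shape):
        if shard_dim is None or i != shard_dim:
            slices.append(slice(None))  # keep the whole dimension
        else:
            width = extent // num_shards
            slices.append(slice(rank * width, (rank + 1) * width))
    return tuple(slices)

print(shard_slices((8, 8), shard_dim=0, num_shards=2, rank=1))    # rows 4..8, all columns
print(shard_slices((8, 8), shard_dim=None, num_shards=2, rank=1)) # whole tensor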

@yaoyaoding (Member) left a comment

It looks good to me now.

Feel free to merge after fixing the existing comments. Thanks @soodoshll !

def __init__(self):
    super().__init__()
    self.var2idx: Dict[Var, Expr] = {}
    self.input_accessed_indices: Dict[TensorInput, List[Expr]] = {}
Member:

Dict[TensorInput, List[List[Expr]]] ?

Comment on lines 70 to 75
index_rewriter = IndexRewriter()
for o in self.op.task.outputs:
    if not isinstance(o, GridCompute):
        self.valid = False
        return  # we don't support scalar output
    index_rewriter.visit(o)
Member:

Suggested change

Before:
    index_rewriter = IndexRewriter()
    for o in self.op.task.outputs:
        if not isinstance(o, GridCompute):
            self.valid = False
            return  # we don't support scalar output
        index_rewriter.visit(o)

After:
    index_rewriter = IndexRewriter()
    index_rewriter.visit(self.op.task.outputs)
    self.valid = all(isinstance(out, GridCompute) for out in self.op.task.outputs)

@soodoshll merged commit 0e860be into hidet-org:auto-parallel on Aug 4, 2023
soodoshll added a commit that referenced this pull request Aug 4, 2023
…discovery (#342)

Please refer to #336 for original
discussion.

I accidentally messed up the branch :( so I re-opened this PR.

I have reset upstream/auto-parallel to the latest commit of the main branch. This PR only contains auto-partition-related modifications.

---------

Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>