Conversation

@naoyam naoyam commented Mar 16, 2021

Fixes #757

@naoyam naoyam requested a review from csarofeen March 16, 2021 23:43
@naoyam naoyam mentioned this pull request Mar 17, 2021
@naoyam naoyam force-pushed the fix-tv-parallelization branch from a415093 to 0ddf3e5 on March 17, 2021 15:47

naoyam commented Mar 17, 2021

It turned out this is actually not the right thing to do. The parallel map contains mappings created with forward-bcast-mismatch enabled, so, for example, it would create a mapping between I2*I1 of t4 and I1 of t2:

t0 = makeSymbolicTensor(1); // t0: [I1]
t1 = makeSymbolicTensor(2); // t1: [I2, I1]
t2 = t0 + 1; // t2: [I1]
t3 = broadcast(t2, {true, false}); // t3: [B1, I1]
t4 = t1 + t3; // t4: [I2, I1]
t4->merge(0, 1); // t4: [I2*I1]
t2->computeAt(t4, -1);
t4->axis(0)->parallelize(TIDx);

This PR would then do t2->axis(0)->parallelize(TIDx), since that axis is mapped to the t4 axis. However, this is a problem: it would mean both I2*I1 == blockDim.x and I1 == blockDim.x.

@naoyam naoyam force-pushed the fix-tv-parallelization branch from 0ddf3e5 to 9f9f902 on March 17, 2021 23:25

naoyam commented Mar 17, 2021

Ended up adding a ParallelTypeBitmap to kir::BroadcastOp. See 9f9f902. This is the only way I can think of to properly determine the parallelism of a kir::BroadcastOp.

More broadly, I think a bigger problem is that we don't have a straightforward way to know the parallelism of kir::TensorView and kir::IterDomain. The only source of truth is the ComputeAt parallel map, which exists only at lowering time. Referring to kir::IterDomain::parallelType() is not robust, as it may not match the real parallel type. We use it for codegen of kir::ReductionOp, where it is fortunately safe since there should be no reduction in the CA axes.
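
For illustration, here is a minimal, self-contained sketch of the idea. ParallelTypeBitmap and BroadcastOp are the real class names in the codebase, but the interface below (the enum values, the std::bitset-backed bitmap, and the parallelTypes() accessor) is an assumption made for this sketch, not the actual kir definitions:

#include <bitset>
#include <cstddef>

// Stand-ins for nvfuser's ParallelType / ParallelTypeBitmap (assumed shapes).
enum class ParallelType : size_t { TIDx, TIDy, TIDz, BIDx, BIDy, BIDz, COUNT };

class ParallelTypeBitmap {
 public:
  void set(ParallelType pt, bool value = true) {
    bits_.set(static_cast<size_t>(pt), value);
  }
  bool get(ParallelType pt) const {
    return bits_.test(static_cast<size_t>(pt));
  }

 private:
  std::bitset<static_cast<size_t>(ParallelType::COUNT)> bits_;
};

// Sketch of a broadcast node that records which parallel types it broadcasts
// across, so codegen does not have to reconstruct this from the ComputeAt
// parallel map, which is gone by codegen time.
class BroadcastOp {
 public:
  explicit BroadcastOp(ParallelTypeBitmap parallel_types)
      : parallel_types_(parallel_types) {}

  const ParallelTypeBitmap& parallelTypes() const { return parallel_types_; }

 private:
  ParallelTypeBitmap parallel_types_;
};

In this scheme the bitmap would be populated during lowering, while the ComputeAt parallel map still exists, and codegen would then read parallelTypes() to decide which block/grid broadcasts to emit.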

@naoyam naoyam requested a review from csarofeen March 17, 2021 23:43
@csarofeen

I2*I1 == blockDim.x and I1 == blockDim.x actually seems fine to me in this case.

blockDim.x should be set to the maximum size. I think the issue is that we substitute values in the IR for blockDim.x where we shouldn't. The original issue stems from the fact that a tensor size gets substituted for this value, and thread bindings can be ambiguous in the presence of broadcast. I'm not certain we can resolve all cases.

@csarofeen

I thought that was how we started down this path: #622

@csarofeen csarofeen left a comment

Clearing approval, as it looks like we should discuss/think about this more. Broadcast parallelism is definitely a complex topic.


naoyam commented Mar 18, 2021

The substitution is done here:

https://github.com/csarofeen/pytorch/blob/20_12_3_devel/torch/csrc/jit/codegen/cuda/kernel_ir.cpp#L100-L109

If we allow I2*I1 == blockDim.x and I1 == blockDim.x, the substitution is no longer valid.
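
To make the concern concrete, here is a hypothetical, simplified sketch of the kind of substitution being discussed; the names and members below are stand-ins, and the kernel_ir.cpp lines linked above are the authoritative version:

#include <string>

// Hypothetical stand-in for a lowered iteration domain.
enum class ParallelType { Serial, TIDx };

struct Extent {
  std::string symbol;  // e.g. "I1", "I2*I1", or "blockDim.x"
};

struct IterDomain {
  Extent raw_extent;
  ParallelType parallel_type = ParallelType::Serial;

  // The extent as written in the fusion, e.g. I1 or I2*I1.
  const Extent& rawExtent() const { return raw_extent; }

  // The substituted extent: a thread-parallelized domain reports blockDim.x
  // instead of its raw extent. This is only sound if every TIDx-bound extent
  // really equals blockDim.x, which no longer holds once both I2*I1 and I1
  // are bound to TIDx as in the example above.
  Extent extent() const {
    if (parallel_type == ParallelType::TIDx) {
      return Extent{"blockDim.x"};
    }
    return raw_extent;
  }
};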

I see two options:

  1. Parallelize every kir::IterDomain whose parallel type can be inferred from the computeAt map, and stop the substitution. (ab83e3e)
  2. Annotate kir::BroadcastOp (9f9f902)

I'd say the second option is more conservative: it only solves this particular problem with BroadcastOp. The first option may be preferable as it conveys the parallelism information to KIR, but it may have side effects such as the substitution problem.

The first option seems better to me from a long-term perspective.

@csarofeen

Can you try option 1 and run the reduction benchmark suite before and after to see if you find any serious perf regressions? We use a lot of tensor size information in our kernels anyway; I think removing the substitution shouldn't be significant, though I could definitely be wrong.

naoyam added 7 commits March 19, 2021 09:45
…onships"

This reverts commit ab83e3e6367ab186498b2d0ab81ca09dcb52f434.
There is no easy way to know which parallel types are used for kir::TensorView after lowering, as the ComputeAt parallel map is not maintained. Adds that information to kir::BroadcastOp as it is needed for codegen.
@naoyam naoyam force-pushed the fix-tv-parallelization branch from 30d5976 to c0fa0c9 on March 19, 2021 17:06

naoyam commented Mar 19, 2021

@csarofeen Done. I changed kir::IterDomain::extent() to simply return extent_, which is the same as rawExtent(). All tests are working fine.
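
A simplified sketch of what that change amounts to, assuming hypothetical member names (the commit has the real code): extent() no longer performs the parallel-dimension substitution and simply returns the stored extent, matching rawExtent().

// Stand-in for kir::Val; the member names below are assumptions.
struct Val {};

class IterDomain {
 public:
  explicit IterDomain(const Val* extent) : extent_(extent) {}

  // Previously extent() could return the launch dimension (e.g. blockDim.x)
  // for a thread-parallelized domain; now it just returns the stored extent.
  const Val* extent() const { return extent_; }
  const Val* rawExtent() const { return extent_; }

 private:
  const Val* extent_;
};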

@naoyam naoyam requested a review from csarofeen March 19, 2021 17:08
@csarofeen csarofeen merged commit 4df7a6a into 20_12_3_devel Mar 19, 2021
@csarofeen csarofeen deleted the fix-tv-parallelization branch June 9, 2021 13:51
Successfully merging this pull request may close these issues: Missing parallel broadcast.