-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
More generic grouped grid reduction kernel #1740
Conversation
RefTuple<DataTypes...> out, | ||
const ConstRefTuple<DataTypes...>& inp, | ||
VolatilePtrTuple<DataTypes...> global_work_buffer, | ||
const LocalTuple<DataTypes...>& init_val, | ||
int64_t* global_sync_buffer, | ||
void* shared_mem, | ||
bool read_pred, // Prevent reading from out of bounds memory | ||
bool write_pred) { // Prevent from writing out of bounds | ||
const LocalTuple<BoolTypes...>& read_preds, | ||
const LocalTuple<BoolTypes...>& write_preds, | ||
Funcs... funcs) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the main change of this PR. Each tuple aggregates a parameter of each grid operations. The number of operations can be as large as 8.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Didn't do a super in-depth review of all the functions, but overall LGTM. Would like to see an additional test to make sure the tuple's support multiple dtypes across the reductions. Would also just like to see some more comments in the helper functions like the for each functions in the runtime files. Simply reiterating the necessity of the different versions of the functions would be helpful for folks that may need to go through these files in the future. Thanks!
auto tv0 = makeSymbolicTensor(1); | ||
fusion.addInput(tv0); | ||
|
||
auto tv1 = sum(tv0, {0}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add some mixed types in the reductions (like one on float vals, one n double vals, and one on int vals)? It looked like that should be supported, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Added one more test.
0754259
to
4398e31
Compare
Thanks for the review. Added more comments. |
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/ Code changes includes: - TransformPropagator refactor: switched to Dijkstra instead of exhaustive enumeration on all possible paths to reduce compilation time on transform propagation; - Indexing refactor: remove reference tensor creation in all tensor indexing logic (csarofeen#1690) - (more) generic grouped grid reduction kernel; - Minor parser/fuser patches: 1. zero-dim tensor reduction support 3. no-op binary removal within fused graph 4. expand supported in fusion Squashed commits to WAR github API Commits that's actually in this PR from the devel branch: ``` a054b3e Refactor TransormPropagator to allow specifying a position and propagating to part of the DAG (csarofeen#1775) d67e1cd Indexing refactor stage 1: remove reference tensor creation in all tensor indexing logic (csarofeen#1690) 1b65299 Issue 1770 (csarofeen#1774) 35b0427 Avoid compilation errors like below: (csarofeen#1773) 452c773 Ignore reductions of zero-dim tensors per PyTorch conventions (csarofeen#1771) 31d6c56 TransformPropagator refactor (csarofeen#1769) 570c5a8 Merge pull request csarofeen#1767 from csarofeen/upstream_merge_0621 9d6c3d8 merging upstream 61305cd 0ed815f New TransformPropagator algorithm (csarofeen#1763) 6c19520 no-op binary removal (csarofeen#1764) ec7fa41 Proper propagation of IterType (csarofeen#1762) b263562 Fix dimensionality check (csarofeen#1759) 2d6343f More generic grouped grid reduction kernel (csarofeen#1740) 64e2b56 [nvfuser] prevent spamming warning message (pytorch#77777) (csarofeen#1758) 0c43162 [nvFuser] Improving bitwise ops support (pytorch#77158) (csarofeen#1757) b93a147 Parser expand (csarofeen#1754) ``` RUN_TORCHBENCH: nvfuser Pull Request resolved: pytorch#80355 Approved by: https://github.com/davidberard98
This PR generalizes the grouped grid reduction kernel with respect to the number of grouped reductions. The new kernel itself should work with an arbitrary number of inputs, but the underlying data structure, Tuple, still explicitly needs to be specialized for the number of values, which is currently limited to 8. Previously, there's only two-way grouped kernel.
See
FusionGroupAllreduce4
, which groups 8 grid reductions into a single grouped grid reduction.This PR is meant to allow more aggressive grouping of grid reductions, e.g., grouping across iterations.
There's still no support for Welford. Fusions with multiple Welford reductions would be unlikely, so horizontal grouping wouldn't be important, but there would be opportunities to group across iterations.