Dynamic shape latency improvement, step 1.5. Misc runtime allocation and cast cleanup #1114
Conversation
Force-pushed from 4434f95 to 39149a6
I left comments. The only major comment is whether we could use the suggested alternative. I'd be surprised if some of the changes really had a performance impact. For example, 27dff5e#r707737832: it would eliminate the call entirely.
Force-pushed from ba0a89e to fe2bf76
Thanks for updating the PR. LGTM.
Force-pushed from cc8f675 to c9759b4
//! This entry type is a simplified version of ParallelIterExtentMap.
//!
//! For launch parameter binding we only need the most concrete iterdomain
//! in each disjoint set stored in CaParallelMap. This entry stores the
@naoyam are we certain this doesn't have to be the index map?
I actually think the parallel map is fine here since concrete domains are selected.
This actually makes me think again about ParallelDimensionMap as well. For example:
auto tv0 = makeSymbolicTensor(1);
fusion.addInput(tv0);
auto tv1 = makeSymbolicTensor(2);
fusion.addInput(tv1);
auto tv2 = broadcast(tv0, {true, false});
auto tv3 = add(tv2, tv1);
fusion.addOutput(tv3);
tv3->merge(0, 1);
tv0->computeAt(tv3, -1);
tv1->computeAt(tv3, -1);
tv3->axis(0)->parallelize(ParallelType::BIDx);
Inputs:
T0_g[ iS0{i1} ], float
T1_g[ iS9{( i4 * i7 )} ], float
Outputs:
T3_g[ iblockIdx.x7{( i4 * i1 )} ] produce_pos( 1), float
%kernel_math {
T2_l[ iS8{( 1 * i1 )} ] ca_pos( 1 ) = broadcast( T0_g[ iS0{i1} ] )
T3_g[ iblockIdx.x7{( i4 * i1 )} ] produce_pos( 1)
= T2_l[ iS8{( 1 * i1 )} ] ca_pos( 1 )
+ T1_g[ iS9{( i4 * i7 )} ];
}
With this fusion, the parallel dimension map looks like:
blockIdx.x: gridDim.x, non-exact
blockIdx.y: unused
blockIdx.z: unused
threadIdx.x: unused
threadIdx.y: unused
threadIdx.z: unused
Note that BIDx is marked as non-exact, even though it is actually exact. This is because ParallelDimensionMap uses the index map, so the merged axis of T3 is not mapped with the axis of T0, but both of them are mapped with each other in the parallel map and are parallelized by BIDx.
I thought this should be non-exact since iS8{( 1 * i1 )} != iblockIdx.x7{( i4 * i1 )}
I thought that was the behavior we're interested in.
I think you meant iS0{i1} != iblockIdx.x7{( i4 * i1 )}. iS8{( 1 * i1 )} and iblockIdx.x7{( i4 * i1 )} are mapped in the index map. The difference comes from the forwarding, so iS0{i1} and iblockIdx.x7{( i4 * i1 )} are mapped in the parallel map but not in the index map.
Since i1 != i4 * i1, they appear to make BIDx non-exact. However, with the computeAt, indexing T0 is always done with i4 * i1. In this particular case, while iS0{i1} is marked as parallelized with BIDx, its indexing is blockIdx.x % i1, so it never goes beyond i1. So, in that sense, BIDx is still exact; in other words, we don't need a predicate like blockIdx.x < i1.
This seems related to the issue we discussed before about how indexing changes with and without computeAt.
I feel this is too intricate. I'm spending too much time thinking about which maps to use. It would be great if we could come up with a single map, but I'm not sure that's reasonably possible.
iS8{( 1 * i1 )} != iblockIdx.x7{( i4 * i1 )} should now be true in the index map. It's true that they used to map, but I explicitly changed that because they shouldn't.
This makes sense though. I forgot we don't map to the right of the compute-at point in the parallel map. So really this is about which parallel type is being used to index into the tensor. In that case then yes, the parallel map is enough to get the unique entries of parallelization bound into the problem.
All the changes should improve performance only and are not supposed to change any compile-time or runtime semantics.
Some effort was spent on making each commit a self-contained optimization, so commit-by-commit review might be easier.
Update:
I have removed the commits that made a significant difference only under a profiled environment (nsys) while Google Benchmark didn't show improvement (within 5us all together). They could be added back if needed.
Benchmark result at c9759b44e8aec7cd45990b22130589dbffcf26e8; comparison reference available at #1098.