Dynamic shape latency improvement, step 1.5. Misc runtime allocation and cast cleanup #1114
Changes from all commits
cf2faf8
43d9078
c9759b4
1628dbb
802eec5
832ce38
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,6 +80,7 @@ namespace caching { | |
enum class CompileTimeEntryType { | ||
PARALLEL_BINDING_ITERDOMAINS, | ||
PARALLEL_ITER_EXTENT_MAP, | ||
SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP, | ||
WARP_PADDED_PARALLEL_EXTENTS, | ||
VECTORIZED_TENSOR_VALIDATION, | ||
INPUT_ALIAS_INDICES, | ||
|
@@ -114,6 +115,27 @@ class ParallelIterExtentMap { | |
CompileTimeEntryType::PARALLEL_ITER_EXTENT_MAP; | ||
}; | ||
|
||
//! Compile-time info to be cached in each FusionExecutor: | ||
//! SimplifiedParallelIterExtentMap | ||
//! This entry type is a simplified version of ParallelIterExtentMap. | ||
//! | ||
//! For launch parameter binding we only need the most concrete iterdomain | ||
//! in each disjoint set stored in CaParallelMap. This entry stores the | ||
@naoyam are we certain this doesn't have to be index map?

I actually think the parallel map is fine here since concrete domains are selected.

This actually makes me think again about
With this fusion, the parallel dimension map looks like:
Note that

I thought this should be not exact since

I thought that was the behavior we're interested in.

I think you meant Since This seems related to the issue we discussed before about the indexing change with and without computeAt. I feel this is too intricate. I'm spending too much time thinking about which maps to use. It would be great if we could come up with a single map, but I'm not sure that's reasonably possible.

This makes sense though. I forgot we don't map to the right of the compute at point in the parallel map. So really this is what parallel type is being used to index into the tensor. In that instance then yes, the parallel map is enough to get the unique entries of parallelization bound into the problem. |
||
//! remaining list of extents for binding after this simplification. | ||
//! | ||
//! We still need ParallelIterExtentMap since we want to bind the concrete | ||
//! values to the extents of all parallelized iterdomains. We would be | ||
//! able to save these bindings if the integer machine had a notion of | ||
//! equality and could be configured at compile time. But that'd be a | ||
//! longer-term target. | ||
class SimplifiedParallelIterExtentMap { | ||
public: | ||
using DataType = | ||
std::unordered_map<ParallelType, std::vector<const kir::Val*>, TypeHash>; | ||
static const CompileTimeEntryType EntryType = | ||
CompileTimeEntryType::SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP; | ||
}; | ||
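The entry class above only carries a `DataType` alias and an `EntryType` tag. A minimal sketch of how such a tag class could drive a typed, compute-once cache lookup is shown here. `ToyCache`, `ToySimplifiedExtentMap`, `EntryTypeHash` and the `int` key are hypothetical stand-ins made up for illustration; this is not the executor's actual cache API.

```cpp
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical, reduced copy of the enum from the diff.
enum class CompileTimeEntryType {
  PARALLEL_ITER_EXTENT_MAP,
  SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP
};

struct EntryTypeHash {
  std::size_t operator()(CompileTimeEntryType t) const {
    return static_cast<std::size_t>(t);
  }
};

// Tag class in the style of SimplifiedParallelIterExtentMap: it holds no
// data, only the payload type and the enum key used to find it.
class ToySimplifiedExtentMap {
 public:
  // Assumption: an int stands in for ParallelType, a string for kir::Val*.
  using DataType = std::unordered_map<int, std::vector<std::string>>;
  static const CompileTimeEntryType EntryType =
      CompileTimeEntryType::SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP;
};

// Hypothetical cache: builds the payload on first request, reuses it after.
class ToyCache {
 public:
  template <typename EntryClass, typename MakerFn>
  const typename EntryClass::DataType& get(MakerFn maker) {
    using Data = typename EntryClass::DataType;
    // Copy the tag into a local so the static member is only read by value.
    const CompileTimeEntryType key = EntryClass::EntryType;
    auto it = entries_.find(key);
    if (it == entries_.end()) {
      std::shared_ptr<void> data = std::make_shared<Data>(maker());
      it = entries_.emplace(key, std::move(data)).first;
    }
    // The enum key identifies the payload type, so the cast is safe as long
    // as each key is only ever used with its own entry class.
    return *static_cast<const Data*>(it->second.get());
  }

 private:
  std::unordered_map<CompileTimeEntryType, std::shared_ptr<void>, EntryTypeHash>
      entries_;
};
```

With this shape, a second `cache.get<ToySimplifiedExtentMap>(maker)` call returns the stored payload without invoking `maker` again, which is the point of caching compile-time information per executor.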
|
||
//! WarpPaddedExtentsInfo: | ||
//! Auxiliary data type for entry class WarpPaddedParallelExtents | ||
struct WarpPaddedExtentsInfo { | ||
|
@@ -269,6 +291,12 @@ std::unique_ptr<ParallelExtentMap> getParallelIterExtents( | |
GpuLower& lower, | ||
std::vector<IterDomain*>& parallel_binding_ids); | ||
|
||
//! Returns the simplified set of extents necessary for launch parameter | ||
//! binding. | ||
std::unique_ptr<ParallelExtentMap> getSimplifiedParallelIterExtents( | ||
GpuLower& lower, | ||
std::vector<IterDomain*>& parallel_binding_ids); | ||
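The simplification described in the comments (keep only the most concrete iterdomain of each disjoint set) can be illustrated with a toy model. This is a sketch under stated assumptions, not the body of `getSimplifiedParallelIterExtents`: `ToyIterDomain`, its `concrete_id` field, the string extents, and the helper names are all hypothetical, and the disjoint-set lookup is assumed to be precomputed.

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-ins for the real types; only what the sketch needs.
enum class ParallelType { TIDx, TIDy, BIDx };

struct TypeHash {
  std::size_t operator()(ParallelType t) const {
    return static_cast<std::size_t>(t);
  }
};

struct ToyIterDomain {
  int id;              // unique id of this iterdomain
  int concrete_id;     // id of the most concrete iterdomain in its disjoint set
  ParallelType ptype;  // parallel type the iterdomain is bound to
  std::string extent;  // symbolic extent, standing in for kir::Val*
};

using ExtentMap =
    std::unordered_map<ParallelType, std::vector<std::string>, TypeHash>;

// Full map: one extent for every parallelized iterdomain.
ExtentMap allExtents(const std::vector<ToyIterDomain>& ids) {
  ExtentMap result;
  for (const auto& id : ids) {
    result[id.ptype].push_back(id.extent);
  }
  return result;
}

// Simplified map: one extent per disjoint set, taken from its concrete
// iterdomain. Assumes every concrete_id refers to an entry in `ids`.
ExtentMap simplifiedExtents(const std::vector<ToyIterDomain>& ids) {
  std::unordered_map<int, const ToyIterDomain*> by_id;
  for (const auto& id : ids) {
    by_id[id.id] = &id;
  }
  ExtentMap result;
  std::unordered_set<int> seen_sets;
  for (const auto& id : ids) {
    if (!seen_sets.insert(id.concrete_id).second) {
      continue;  // this disjoint set already contributed its extent
    }
    const ToyIterDomain* concrete = by_id.at(id.concrete_id);
    result[concrete->ptype].push_back(concrete->extent);
  }
  return result;
}

// Toy input: three TIDx iterdomains mapped together, one separate BIDx.
std::vector<ToyIterDomain> toyDomains() {
  return {
      {0, 0, ParallelType::TIDx, "T0.size[0]"},
      {1, 0, ParallelType::TIDx, "T1.size[0]"},
      {2, 0, ParallelType::TIDx, "T2.size[0]"},
      {3, 3, ParallelType::BIDx, "T0.size[1]"},
  };
}
```

In this toy input the full map holds three TIDx extents while the simplified map holds one, so launch parameter binding would evaluate fewer expressions per launch. The full map remains useful for binding concrete values to every parallelized extent, as the comment on `SimplifiedParallelIterExtentMap` notes.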
|
||
//! Returns the symbolic or constant extents of warp padded parallel | ||
//! iterdomains in the given vector. | ||
std::unique_ptr<caching::WarpPaddedExtentsInfo> getWarpPaddedExtentsInfo( | ||
|