Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion torch/csrc/jit/codegen/cuda/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,16 @@ LaunchParams FusionExecutor::computeLaunchParams(
});
auto& parallel_iter_extents = parallel_iter_extent_entry.get();

auto simplified_parallel_iter_extent_entry =
executor_utils::caching::ExecutorCompileTimeEntry<
executor_utils::caching::SimplifiedParallelIterExtentMap>(
data_cache, [&parallel_binding_ids, &lower]() {
return executor_utils::getSimplifiedParallelIterExtents(
lower, parallel_binding_ids);
});
auto& simplified_parallel_iter_extents =
simplified_parallel_iter_extent_entry.get();

auto warp_padded_parallel_entry =
executor_utils::caching::ExecutorCompileTimeEntry<
executor_utils::caching::WarpPaddedParallelExtents>(
Expand Down Expand Up @@ -409,7 +419,7 @@ LaunchParams FusionExecutor::computeLaunchParams(
}

// Run through the rest of the parallel IterDomains and infer their size
for (auto& entry : parallel_iter_extents) {
for (auto& entry : simplified_parallel_iter_extents) {
FUSER_PERF_SCOPE("FusionExecutor::ParallelBindingResolution");
auto p_type = entry.first;
auto parallel_extents = entry.second;
Expand Down
52 changes: 44 additions & 8 deletions torch/csrc/jit/codegen/cuda/executor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,7 @@ ExecutorCompileTimeEntry<EntryClass>::ExecutorCompileTimeEntry(
// Template instantiation
template class ExecutorCompileTimeEntry<ParallelBindingIterDomains>;
template class ExecutorCompileTimeEntry<ParallelIterExtentMap>;
template class ExecutorCompileTimeEntry<SimplifiedParallelIterExtentMap>;
template class ExecutorCompileTimeEntry<WarpPaddedParallelExtents>;
template class ExecutorCompileTimeEntry<VectorizedTensorValidation>;
template class ExecutorCompileTimeEntry<InputAliasIndices>;
Expand All @@ -986,20 +987,55 @@ std::vector<IterDomain*> getParallelBindingsIterDomains(
return parallel_ids;
}

//! Lowers the extent of `binding_id` to kernel IR and appends it to the
//! list of extents recorded for its parallel type in
//! `*parallel_iter_extents_ptr`.
//!
//! \param lower active GpuLower used to lower the fusion-IR extent.
//! \param binding_id parallelized IterDomain whose extent is recorded.
//! \param parallel_iter_extents_ptr map being populated; must be non-null.
void insertParallelExtent(
    GpuLower& lower,
    IterDomain* binding_id,
    const std::unique_ptr<ParallelExtentMap>& parallel_iter_extents_ptr) {
  auto kir_extent = lower.lowerValue(binding_id->extent());
  // operator[] value-initializes an empty vector for a first-seen parallel
  // type, so a single append replaces the previous explicit find/insert
  // branch with identical behavior.
  (*parallel_iter_extents_ptr)[binding_id->getParallelType()].push_back(
      kir_extent);
}

//! Collects the lowered extent of every parallelized IterDomain in
//! `parallel_binding_ids`, grouped by parallel type, so the executor can
//! later bind concrete launch dimensions for each parallel type.
//!
//! \param lower active GpuLower used to lower each extent.
//! \param parallel_binding_ids IterDomains bound to parallel types.
//! \return map from parallel type to the extents bound to it.
std::unique_ptr<ParallelExtentMap> getParallelIterExtents(
    GpuLower& lower,
    std::vector<IterDomain*>& parallel_binding_ids) {
  auto parallel_iter_extents_ptr = std::make_unique<ParallelExtentMap>();
  for (auto id : parallel_binding_ids) {
    // TODO(kir): we should rewrite this logic based on the Kernel object
    insertParallelExtent(lower, id, parallel_iter_extents_ptr);
  }

  return parallel_iter_extents_ptr;
}

std::unique_ptr<ParallelExtentMap> getSimplifiedParallelIterExtents(
GpuLower& lower,
std::vector<IterDomain*>& parallel_binding_ids) {
auto parallel_iter_extents_ptr = std::make_unique<ParallelExtentMap>();
auto& parallel_map = lower.caParallelMap();
std::vector<IterDomain*> mapped;
bool is_tidx_warp_padded = lower.getWarpPaddedParallelInfo().is_tidx_padded;

for (auto id : parallel_binding_ids) {
if (std::any_of(
mapped.begin(),
mapped.end(),
[id, &parallel_map](IterDomain* mapped_id) {
return parallel_map.areMapped(mapped_id, id);
})) {
if (id->getParallelType() != ParallelType::TIDx || !is_tidx_warp_padded) {
continue;
}
}

insertParallelExtent(
lower, parallel_map.getConcreteMappedID(id), parallel_iter_extents_ptr);
mapped.push_back(id);
}

return parallel_iter_extents_ptr;
Expand Down
28 changes: 28 additions & 0 deletions torch/csrc/jit/codegen/cuda/executor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ namespace caching {
enum class CompileTimeEntryType {
PARALLEL_BINDING_ITERDOMAINS,
PARALLEL_ITER_EXTENT_MAP,
SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP,
WARP_PADDED_PARALLEL_EXTENTS,
VECTORIZED_TENSOR_VALIDATION,
INPUT_ALIAS_INDICES,
Expand Down Expand Up @@ -114,6 +115,27 @@ class ParallelIterExtentMap {
CompileTimeEntryType::PARALLEL_ITER_EXTENT_MAP;
};

//! Compile-time info to be cached in each FusionExecutor:
//! SimplifiedParallelIterExtentMap
//! This entry type is a simplified version of ParallelIterExtentMap.
//!
//! For launch parameter binding we only need the most concrete iterdomain
//! in each disjoint set stored in CaParallelMap. This entry stores the
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@naoyam are we certain this doesn't have to be index map?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually think the parallel map is fine here since concrete domains are selected.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This actually makes me think again about ParallelDimensionMap as well. For example:

 auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);
  auto tv1 = makeSymbolicTensor(2);
  fusion.addInput(tv1);
  auto tv2 = broadcast(tv0, {true, false});
  auto tv3 = add(tv2, tv1);
  fusion.addOutput(tv4);

  tv3->merge(0, 1);
  tv0->computeAt(tv3, -1);
  tv1->computeAt(tv3, -1);

  tv3->axis(0)->parallelize(ParallelType::BIDx);
Inputs:
  T0_g[ iS0{i1} ], float
  T1_g[ iS9{( i4 * i7 )} ], float
Outputs:
  T3_g[ iblockIdx.x7{( i4 * i1 )} ] produce_pos( 1), float

%kernel_math {
T2_l[ iS8{( 1 * i1 )} ] ca_pos( 1 ) = broadcast( T0_g[ iS0{i1} ] )
T3_g[ iblockIdx.x7{( i4 * i1 )} ] produce_pos( 1)
   = T2_l[ iS8{( 1 * i1 )} ] ca_pos( 1 )
   + T1_g[ iS9{( i4 * i7 )} ];
}

With this fusion, the parallel dimension map looks like:

blockIdx.x: gridDim.x, non-exact
blockIdx.y: unused
blockIdx.z: unused
threadIdx.x: unused
threadIdx.y: unused
threadIdx.z: unused

Note that BIDx is marked as non-exact, even though it is actually exact. This is because ParallelDimensionMap uses the index map so the merged axis of T3 is not mapped with the axis of T0, but both of them are mapped with each other in the parallel map and are parallelized by BIDx.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this should be not exact since iS8{( 1 * i1 )} != iblockIdx.x7{( i4 * i1 )}

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought that was the behavior we're interested in.

Copy link
Collaborator

@naoyam naoyam Sep 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you meant iS0{i1} != iblockIdx.x7{( i4 * i1 )}. iS8{( 1 * i1 )} and iblockIdx.x7{( i4 * i1 )} are mapped in the index map. The difference comes from the forwarding, so iS0{i1} and iblockIdx.x7{( i4 * i1 )} are mapped in the parallel map but not in the index map.

Since i1 != i4 * i1, they appear to make BIDx non-exact. However, with the computeAt, indexing T0 is always done with i4 * i1. In this particular case, while iS0{i1} is marked as parallelized with BIDx, its indexing is blockDim.x % i1, so it never goes beyond i1. So, in that sense, BIDx is still exact, in other words, we don't need a predicate like blockDim.x < i1.

This seems related to the issue we discussed before about the indexing change with and without computeAt.

I feel this is too much intricated. I'm spending too much time to think about which maps to use. Would be great if we could come up with a single map, but not sure if that's reasonably possible.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iS8{( 1 * i1 )} != iblockIdx.x7{( i4 * i1 )} should now be true in the index map, it was true that it used to map, but I explicitly changed that because it shouldn't.

This makes sense though. I forgot we don't map to the right of the compute at point in the parallel map. So really this is what parallel type is being used to index into the tensor. In that instance then yes, the parallel map is enough to get the unique entries of parallelization bound into the problem.

//! remaining list of extents for binding after this simplification.
//!
//! We still need ParallelIterExtentMap since we want to bind the concrete
//! values to the extents of all parallelized iterdomains. We would be
//! able to save these bindings if the integer machine has a notion of
//! equality and could be configured compile time. But that'd be a longer
//! term target.
//! Cache-entry tag class: carries no data of its own, only the stored type
//! (DataType) and the enum key (EntryType) used by
//! ExecutorCompileTimeEntry to cache the simplified per-parallel-type
//! extent lists.
class SimplifiedParallelIterExtentMap {
public:
// One list of lowered kernel-IR extents per parallel type.
using DataType =
std::unordered_map<ParallelType, std::vector<const kir::Val*>, TypeHash>;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP;
};

//! WarpPaddedExtentsInfo:
//! Auxiliary data type for entry class WarpPaddedParallelExtents
struct WarpPaddedExtentsInfo {
Expand Down Expand Up @@ -269,6 +291,12 @@ std::unique_ptr<ParallelExtentMap> getParallelIterExtents(
GpuLower& lower,
std::vector<IterDomain*>& parallel_binding_ids);

//! Returns the simplified set of extents necessary for launch parameter
//! binding.
std::unique_ptr<ParallelExtentMap> getSimplifiedParallelIterExtents(
GpuLower& lower,
std::vector<IterDomain*>& parallel_binding_ids);

//! Returns the symbolic or constant extents of warp padded parallel
//! iterdomains in the given vector.
std::unique_ptr<caching::WarpPaddedExtentsInfo> getWarpPaddedExtentsInfo(
Expand Down
1 change: 0 additions & 1 deletion torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ c10::optional<Int::ScalarType> ExpressionEvaluator::evaluate(Val* value) {
if (evaluator_precomputed_integers_ != nullptr) {
return evaluator_precomputed_integers_->getMaybeValueFor(value);
} else {
FUSER_PERF_SCOPE("ExpressionEvaluator::evaluate");
auto maybe_concrete_value = getValue(value);
if (!maybe_concrete_value.has_value()) {
if (value->definition() != nullptr) {
Expand Down
Loading