Skip to content

Commit

Permalink
Transpose scheduler, step 1 (#1854)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm committed Aug 11, 2022
1 parent 8a45dbf commit b7435af
Show file tree
Hide file tree
Showing 14 changed files with 1,674 additions and 414 deletions.
1 change: 1 addition & 0 deletions build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/root_domain_map.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp",
Expand Down
1 change: 1 addition & 0 deletions test/cpp/jit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ if(USE_CUDA)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_view.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_rng.cu)
endif()

Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/normalization.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/reduction.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/transpose.h>

namespace torch {
namespace jit {
Expand All @@ -12,7 +13,8 @@ enum class TORCH_CUDA_CU_API ScheduleHeuristic {
None,
PointWise,
Reduction,
Persistent
Persistent,
Transpose
};
}
} // namespace fuser
Expand Down
10 changes: 10 additions & 0 deletions torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ enum class CompileTimeEntryType {
DOMAIN_MAP,
REFERENCE_TENSORS,
VECTORIZABLE_INPUTS_AND_OUTPUTS,
INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS,
UNROLLABLE_INPUTS_AND_OUTPUTS,
REDUCTION_TVS,
PERSISTENT_BUFFER_INFO,
Expand Down Expand Up @@ -62,6 +63,15 @@ class VectorizableInputsAndOutputs {
CompileTimeEntryType::VECTORIZABLE_INPUTS_AND_OUTPUTS;
};

//! Entry type definition class for `INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS`,
//! stores the fusion's inputs and outputs grouped by inner most dimension.
class InputsOutputsInnerDimGroups {
 public:
  //! One inner vector per group; presumably each group collects the I/O
  //! TensorViews sharing the same inner most dimension — TODO confirm
  //! against the transpose scheduler that populates this entry.
  using DataType = std::vector<std::vector<TensorView*>>;
  //! Tag used to index this cached entry in the compile-time info cache.
  static const CompileTimeEntryType EntryType =
      CompileTimeEntryType::INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS;
};

//! Entry type definition class for `UNROLLABLE_INPUTS_AND_OUTPUTS`,
//! stores the unrollable TensorViews on a fusion's inputs and outputs.
class UnrollableInputsAndOutputs {
Expand Down
24 changes: 1 addition & 23 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.h>
Expand Down Expand Up @@ -57,29 +58,6 @@ class DomainMap : public pointwise_utils::DomainMap {
return domain_map.findReferenceTensorView() != nullptr;
}

// Determine if output TensorView is a valid reference tensor for this fusion.
// The reference tensor must map to all the iterDomains in each input.
bool isValidReference(TensorView* output_tv) const {
if (output_tv->isFusionInput()) {
return false;
}
for (auto input_tv :
ir_utils::filterByType<TensorView>(fusion_->inputs())) {
if (input_tv->uses().empty()) {
continue;
}

if (fusion_->getOutputAlias(output_tv) == input_tv) {
continue;
}

if (!areAllInputIdsMappedToOutput(input_tv, output_tv)) {
return false;
}
}
return true;
}

private:
bool hasMinimumSize(TensorView* tv, int num_axes) const {
TORCH_INTERNAL_ASSERT(tv != nullptr);
Expand Down
46 changes: 24 additions & 22 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,9 @@ namespace fuser {
namespace cuda {
namespace pointwise_utils {

DomainMap::DomainMap(Fusion* fusion)
: fusion_(fusion), ca_map_(ComputeAtMap(fusion)) {
view_tvs_ = scheduler_utils::getViewTVs(fusion);
}

bool DomainMap::areExactMapped(IterDomain* id1, IterDomain* id2) {
return ca_map_.areMapped(id1, id2, IdMappingMode::EXACT);
}

// Determine if all IterDomains in input are mapped to output
bool DomainMap::areAllInputIdsMappedToOutput(
TensorView* input_tv,
TensorView* output_tv) const {
// Determine if all IterDomains in input are mapped to the given tensor
bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv)
const {
// Get concrete IDs for input root or rfactor domain
std::unordered_set<IterDomain*> in_concrete_ids;
for (auto in_id : input_tv->getMaybeRFactorDomain()) {
Expand All @@ -30,11 +20,9 @@ bool DomainMap::areAllInputIdsMappedToOutput(

// Erase all input concrete IDs mapped to the output domain
// Ignore unresolved broadcast dimensions
for (auto out_id : output_tv->getMaybeRFactorDomain()) {
if (!out_id->isBroadcast()) {
if (!eraseIfMapped(in_concrete_ids, out_id)) {
eraseIfInputMappedThroughViewToOutput(in_concrete_ids, out_id);
}
for (auto id : tv->getMaybeRFactorDomain()) {
if (!eraseIfMapped(in_concrete_ids, id)) {
eraseIfInputMappedThroughViewTo(in_concrete_ids, id);
}
}
return in_concrete_ids.empty();
Expand All @@ -45,7 +33,7 @@ bool DomainMap::eraseIfMapped(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const {
auto out_concrete_id =
ca_map_.getConcreteMappedID(out_id, IdMappingMode::EXACT);
ca_map_.getConcreteMappedID(out_id, IdMappingMode::PERMISSIVE);
auto in_concrete_id_iter = in_concrete_ids.find(out_concrete_id);
bool found_match = in_concrete_id_iter != in_concrete_ids.end();
if (found_match) {
Expand All @@ -58,12 +46,12 @@ bool DomainMap::eraseIfMapped(
// Currently this function only allow having one view on the path from input to
// output. If there are multiple views, then likely the pointwise scheduler will
// reject the fusion because we can not correctly find a reference tensor.
void DomainMap::eraseIfInputMappedThroughViewToOutput(
void DomainMap::eraseIfInputMappedThroughViewTo(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const {
IterDomain* id) const {
for (auto view : view_tvs_) {
// Find any ID in view rfactor domain that is mapped to output ID
auto view_rfactor_id = anyMapped(view->getRFactorDomain(), out_id);
auto view_rfactor_id = anyMapped(view->getRFactorDomain(), id);
if (view_rfactor_id == nullptr) {
continue;
}
Expand Down Expand Up @@ -94,6 +82,20 @@ IterDomain* DomainMap::anyMapped(
return nullptr;
}

// Check whether `tv` can serve as the reference tensor for this fusion:
// every fusion input that is actually used must have all of its
// IterDomains mapped to `tv`.
bool DomainMap::isValidReference(TensorView* tv) const {
  for (auto in_tv : ir_utils::filterByType<TensorView>(fusion_->inputs())) {
    // An input with no uses places no constraint on the reference.
    const bool input_is_used = !in_tv->uses().empty();
    if (input_is_used && !areAllInputIdsMappedTo(in_tv, tv)) {
      return false;
    }
  }
  return true;
}

} // namespace pointwise_utils
} // namespace cuda
} // namespace fuser
Expand Down
22 changes: 15 additions & 7 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,37 @@ namespace pointwise_utils {
// that maps to all IterDomains in the fusion.
class DomainMap {
public:
DomainMap(Fusion* fusion);
// Build the ComputeAtMap for `fusion` eagerly and cache the fusion's view
// TensorViews, which are later needed to map IDs through view rfactor
// domains (see eraseIfInputMappedThroughViewTo).
DomainMap(Fusion* fusion) : fusion_(fusion), ca_map_(fusion) {
  view_tvs_ = scheduler_utils::getViewTVs(fusion);
}
virtual ~DomainMap() = default;

bool areExactMapped(IterDomain* id1, IterDomain* id2);
// True when id1 and id2 are mapped under IdMappingMode::EXACT in the
// cached ComputeAtMap.
bool areExactMapped(IterDomain* id1, IterDomain* id2) const {
  return ca_map_.areMapped(id1, id2, IdMappingMode::EXACT);
}

// Read-only access to the ComputeAtMap built for this fusion.
const ComputeAtMap& getComputeAtMap() const {
  return ca_map_;
}

// Determine if a TensorView is a valid reference tensor for this fusion.
// The reference tensor must map to all the iterDomains in each input.
bool isValidReference(TensorView* tv) const;

protected:
// Determine if all iterDomains are mapped between input and output tvs
bool areAllInputIdsMappedToOutput(TensorView* input_tv, TensorView* output_tv)
// Determine if all IterDomains are mapped between input and the given tvs
bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv)
const;

// Erase input concrete ID if it is mapped to output ID
bool eraseIfMapped(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const;

// Check if in_id is mapped to out_id through any view rfactor domain
void eraseIfInputMappedThroughViewToOutput(
// Check if in_id is mapped to id through any view rfactor domain
void eraseIfInputMappedThroughViewTo(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const;
IterDomain* id) const;

// Find any id in domain that maps with target id
IterDomain* anyMapped(
Expand Down
86 changes: 86 additions & 0 deletions torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/debug_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/transpose.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>

#include <limits>
Expand Down Expand Up @@ -1244,10 +1245,75 @@ class PersistentKernelScheduler : public SchedulerEntry {
}
};

//! Scheduler entry for the transpose heuristic. Computes transpose
//! heuristics at construction time and applies them in schedule().
//! NOTE(review): compile-time acceptance is intentionally disabled below
//! ("step 1" of the transpose scheduler), so this entry is never selected
//! through the normal canSchedule path yet.
class TransposeScheduler : public SchedulerEntry {
 public:
  explicit TransposeScheduler(
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicSummary* data_cache = nullptr)
      : SchedulerEntry(ScheduleHeuristic::Transpose) {
    computeHeuristics(fusion, runtime_info, data_cache);
  }

  //! Compile-time check. Currently always rejects; the intended checks are
  //! preserved under `#if 0` for the follow-up that enables the scheduler.
  static bool canScheduleCompileTime(Fusion* fusion) {
    // Not enabling this yet. Needs more validation.
    return false;
#if 0
    if (!hasAtLeastTwoValidGroups(fusion)) {
      scheduler_debug_utils::canScheduleRejectReason(
          ScheduleHeuristic::Transpose,
          "cannot find two mismatching inner most dimensions");
      return false;
    }

    // TODO: add support for trivial reduction
    auto reduction_ops =
        ir_utils::getReductionOps(fusion, false /* ignore_trivial */);

    if (!reduction_ops.empty()) {
      scheduler_debug_utils::canScheduleRejectReason(
          ScheduleHeuristic::Transpose, "no support for reduction ops");
      return false;
    }

    if (hasNonUniqueBcast(fusion)) {
      scheduler_debug_utils::canScheduleRejectReason(
          ScheduleHeuristic::Transpose,
          "Broadcasting dimension might be broadcasting to multiple sizes.");
      return false;
    }

    return true;
#endif
  }

  //! Runtime check. No runtime-only rejection criteria yet, so any fusion
  //! that passed the compile-time check is accepted.
  static bool canScheduleRunTime(
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicSummary* data_cache = nullptr) {
    return true;
  }

  //! Apply the precomputed transpose heuristics to `fusion`.
  void schedule(Fusion* fusion) override {
    FUSER_PERF_SCOPE("Schedule Transpose Fusion");
    scheduleTranspose(fusion, transposeParams());
  }

 private:
  //! Compute and store the transpose heuristic parameters; asserts that
  //! heuristic generation succeeded.
  void computeHeuristics(
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicSummary* data_cache = nullptr) {
    params_ = getTransposeHeuristics(fusion, runtime_info, data_cache);
    TORCH_INTERNAL_ASSERT(params_ != nullptr);
  }
};

// Schedule Table
const std::vector<ScheduleHeuristic>& all_heuristics() {
static const std::vector<ScheduleHeuristic> hlist = {
ScheduleHeuristic::Reduction,
ScheduleHeuristic::Transpose,
ScheduleHeuristic::PointWise,
ScheduleHeuristic::Persistent};
return hlist;
Expand Down Expand Up @@ -1294,6 +1360,9 @@ bool SchedulerEntry::canSchedule(
case ScheduleHeuristic::Persistent:
return checkCanSchedule<PersistentKernelScheduler>(
fusion, runtime_info, data_cache);
case ScheduleHeuristic::Transpose:
return checkCanSchedule<TransposeScheduler>(
fusion, runtime_info, data_cache);
default:
TORCH_INTERNAL_ASSERT(false, "unreachable");
return false;
Expand All @@ -1320,6 +1389,10 @@ std::unique_ptr<SchedulerEntry> SchedulerEntry::makeEntry(
scheduler_entry = std::make_unique<PersistentKernelScheduler>(
fusion, runtime_info, data_cache);
break;
case ScheduleHeuristic::Transpose:
scheduler_entry = std::make_unique<TransposeScheduler>(
fusion, runtime_info, data_cache);
break;
default:
TORCH_INTERNAL_ASSERT(false, "unreachable");
}
Expand Down Expand Up @@ -1353,6 +1426,8 @@ std::string toString(ScheduleHeuristic sh) {
return "reduction";
case ScheduleHeuristic::Persistent:
return "persistent";
case ScheduleHeuristic::Transpose:
return "transpose";
default:
TORCH_INTERNAL_ASSERT(false, "undefined schedule");
}
Expand Down Expand Up @@ -1405,6 +1480,10 @@ HeuristicSummary::HeuristicSummary(
getPersistentHeuristics(fusion, runtime_info, this);
PersistentKernelScheduler::canScheduleRunTime(fusion, runtime_info, this);
break;
case ScheduleHeuristic::Transpose:
getTransposeHeuristics(fusion, runtime_info, this);
TransposeScheduler::canScheduleRunTime(fusion, runtime_info, this);
break;
default:
TORCH_INTERNAL_ASSERT(false, "unknown heuristic");
}
Expand Down Expand Up @@ -1451,6 +1530,11 @@ void HeuristicSummary::validate() const {
entry_type_map_.count(EntryType::SCOPE_PERSISTENT_FACTOR_INFO));
break;
}
case ScheduleHeuristic::Transpose: {
TORCH_INTERNAL_ASSERT(entry_type_map_.count(
EntryType::INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS));
break;
}
default:
TORCH_INTERNAL_ASSERT(false, "unknown heuristic");
}
Expand Down Expand Up @@ -1490,6 +1574,8 @@ template class HeuristicSummaryEntry<HeuristicCompileTime::DomainMap>;
template class HeuristicSummaryEntry<HeuristicCompileTime::ReferenceTensors>;
template class HeuristicSummaryEntry<
HeuristicCompileTime::VectorizableInputsAndOutputs>;
template class HeuristicSummaryEntry<
HeuristicCompileTime::InputsOutputsInnerDimGroups>;
template class HeuristicSummaryEntry<
HeuristicCompileTime::UnrollableInputsAndOutputs>;
template class HeuristicSummaryEntry<HeuristicCompileTime::ReductionTVs>;
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/codegen/cuda/scheduler/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,13 @@ class TORCH_CUDA_CU_API SchedulerEntry {
return *pparams;
}

//! Downcast the stored heuristic parameters to TransposeParams.
//! Asserts when this entry does not hold transpose heuristics.
const TransposeParams& transposeParams() const {
  const auto transpose_params =
      std::dynamic_pointer_cast<TransposeParams>(params_);
  TORCH_INTERNAL_ASSERT(
      transpose_params != nullptr,
      "Heuristic parameter is not a transpose parameter");
  return *transpose_params;
}

void updateLaunchConstraint(const LaunchParams& launch_params) {
params_->lparams = launch_params;
}
Expand Down
Loading

0 comments on commit b7435af

Please sign in to comment.