Skip to content

Commit

Permalink
Minor cleanup in pointwise scheduler (#1858)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Jul 22, 2022
1 parent 9ee850c commit e842a9b
Showing 1 changed file with 22 additions and 46 deletions.
68 changes: 22 additions & 46 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ namespace {
// Unused at the moment, commenting for clang tidy
constexpr int64_t kThreadX = 128;

// Returns number of non-reduction/non-broadcast dims in rfactor domain
// Returns number of non-reduction/non-broadcast dims in rfactor domain
size_t nRootDims(const TensorView* tv) {
  size_t count = 0;
  // Skip trivial axes: reductions and broadcasts do not contribute to the
  // iteration space the pointwise scheduler cares about.
  for (const auto id : tv->getMaybeRFactorDomain()) {
    if (id->isReduction() || id->isBroadcast()) {
      continue;
    }
    count++;
  }
  return count;
}

// DomainMap uses the ComputeAtMap to find a reference TensorView
// that maps to all iterDomains in the fusion.
class DomainMap {
Expand All @@ -38,15 +50,21 @@ class DomainMap {
// The pointwise scheduler heuristics requires a minimum number of axes.
// The output reference tensor should respect this requirement.
// The pointwise scheduler heuristics requires a minimum number of axes.
// The output reference tensor should respect this requirement.
//
// Scans all fusion outputs and returns the valid reference output with the
// largest number of non-trivial (non-reduction/non-broadcast) rfactor dims,
// or nullptr if no output qualifies.
TensorView* findReferenceTensorView(int minimum_num_axes = 0) const {
  TensorView* result = nullptr;
  int max_dims = -1;
  for (auto output_tv :
       ir_utils::filterByType<TensorView>(fusion_->outputs())) {
    if (isValidReference(output_tv) &&
        hasMinimumSize(output_tv, minimum_num_axes) &&
        !output_tv->isFusionInput()) {
      // BUG FIX: the previous version returned the first valid output
      // immediately, leaving the max-dims selection below unreachable, and
      // ended with a dead `return nullptr;` shadowing `return result;`.
      // Cast: nRootDims returns size_t; avoid implicit narrowing to int.
      int n_dims = static_cast<int>(nRootDims(output_tv));
      if (n_dims > max_dims) {
        result = output_tv;
        max_dims = n_dims;
      }
    }
  }
  return result;
}

static bool hasReferenceTensorView(Fusion* fusion) {
Expand Down Expand Up @@ -187,35 +205,11 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
// In case any buffer is of type DataType::Index
DataType index_type = indexModeToDtype(runtime_info.getIndexMode());

TensorView* largest_out = nullptr;
int max_dims = -1;

auto in_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
// Will want to access this with direct indexing later, convert now.
std::vector<TensorView*> out_tvs;
// Only use valid reference tensors during heuristics analysis

DomainMap domain_map(fusion);
for (auto out_tv : ir_utils::filterByType<TensorView>(fusion->outputs())) {
if (domain_map.isValidReference(out_tv)) {
out_tvs.push_back(out_tv);
}
}
TORCH_INTERNAL_ASSERT(
!out_tvs.empty(), "No valid reference outputs were found!");

for (auto out_tv : out_tvs) {
int n_dims = 0;
for (auto id : out_tv->getMaybeRFactorDomain()) {
if (id->isReduction() || id->isBroadcast()) {
continue;
}
n_dims++;
}
if (n_dims > max_dims) {
largest_out = out_tv;
max_dims = n_dims;
}
}
TensorView* largest_out = domain_map.findReferenceTensorView();

TORCH_INTERNAL_ASSERT(largest_out != nullptr);

Expand All @@ -224,15 +218,12 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(

// TODO: Set to 1?
int64_t max_input_dtype_size = 2;
size_t n_tensors = 0;

for (auto inp : in_tvs) {
max_input_dtype_size = std::max(
max_input_dtype_size,
(int64_t)dataTypeSize(inp->getDataType().value(), index_type));
n_tensors++;
}
n_tensors += std::distance(out_tvs.begin(), out_tvs.end());

auto ref_root = largest_out->getMaybeRFactorDomain();
std::vector<int64_t> elem_counts(ref_root.size(), 1);
Expand Down Expand Up @@ -533,7 +524,6 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
std::cerr << "\n===== Pointwise Stats ========\n"
<< "num_elems: " << n_elems << "\n"
<< "elem_counts: " << elem_counts << "\n"
<< "n_tensor_inputs: " << n_tensors << "\n"
<< "max_input_dtype_size: " << max_input_dtype_size << "\n"
<< "vectorize_factor: " << vectorize_factor << std::endl;
std::cerr << "broadcast_byte_multiples: ";
Expand Down Expand Up @@ -563,20 +553,6 @@ LaunchParams schedulePointwise(
return params.value().lparams;
}

namespace {
// Returns number of non-reduction/non-broadcast dims in rfactor domain
size_t nRootDims(const TensorView* tv) {
  const auto domain = tv->getMaybeRFactorDomain();
  size_t n = 0;
  for (size_t i = 0; i < domain.size(); i++) {
    // Only count axes that actually index data for a pointwise kernel.
    if (!domain[i]->isReduction() && !domain[i]->isBroadcast()) {
      ++n;
    }
  }
  return n;
}
} // namespace

// Public entry point: true if the fusion has at least one output that can
// serve as the pointwise scheduler's reference tensor. Thin wrapper that
// delegates to the file-local DomainMap helper.
bool hasReferenceTensorView(Fusion* fusion) {
  return DomainMap::hasReferenceTensorView(fusion);
}
Expand Down

0 comments on commit e842a9b

Please sign in to comment.