Minor cleanup in pointwise scheduler #1858

Merged: 1 commit, Jul 22, 2022
torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp (68 changes: 22 additions & 46 deletions)
@@ -27,6 +27,18 @@ namespace {
 // Unused at the moment, commenting for clang tidy
 constexpr int64_t kThreadX = 128;
 
+// Returns number of non-reduction/non-broadcast dims in rfactor domain
+size_t nRootDims(const TensorView* tv) {
+  auto root_dom = tv->getMaybeRFactorDomain();
+  size_t tv_n_dims = 0;
+  for (auto dim : root_dom) {
+    if (!dim->isReduction() && !dim->isBroadcast()) {
+      tv_n_dims++;
+    }
+  }
+  return tv_n_dims;
+}
+
 // DomainMap uses the ComputeAtMap to find a reference TensorView
 // that maps to all iterDomains in the fusion.
 class DomainMap {
@@ -38,15 +50,21 @@ class DomainMap {
   // The pointwise scheduler heuristics requires a minimum number of axes.
   // The output reference tensor should respect this requirement.
   TensorView* findReferenceTensorView(int minimum_num_axes = 0) const {
+    TensorView* result = nullptr;
+    int max_dims = -1;
     for (auto output_tv :
          ir_utils::filterByType<TensorView>(fusion_->outputs())) {
       if (isValidReference(output_tv) &&
           hasMinimumSize(output_tv, minimum_num_axes) &&
           !output_tv->isFusionInput()) {
-        return output_tv;
+        int n_dims = nRootDims(output_tv);
+        if (n_dims > max_dims) {
+          result = output_tv;
+          max_dims = n_dims;
+        }
       }
     }
-    return nullptr;
+    return result;
   }
 
   static bool hasReferenceTensorView(Fusion* fusion) {
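
Note: the new selection rule in findReferenceTensorView is easiest to see in isolation. The following is a self-contained sketch of the same max-by-dims scan; MockDim, MockTensorView, and the sample outputs are invented stand-ins for illustration, not the nvfuser API:

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for nvfuser's IterDomain/TensorView.
struct MockDim {
  bool is_reduction = false;
  bool is_broadcast = false;
};

struct MockTensorView {
  const char* name;
  std::vector<MockDim> root_domain;
};

// Mirrors nRootDims from the diff: count dims that are neither
// reduction nor broadcast.
std::size_t nRootDims(const MockTensorView& tv) {
  std::size_t n = 0;
  for (const auto& dim : tv.root_domain) {
    if (!dim.is_reduction && !dim.is_broadcast) {
      n++;
    }
  }
  return n;
}

int main() {
  // out0 has one broadcast dim, so only 1 "real" dim; out1 has 3.
  std::vector<MockTensorView> outputs = {
      {"out0", {{false, true}, {false, false}}},
      {"out1", {{false, false}, {false, false}, {false, false}}},
  };
  const MockTensorView* result = nullptr;
  int max_dims = -1;
  for (const auto& tv : outputs) {
    int n_dims = static_cast<int>(nRootDims(tv));
    // Strictly greater, as in the diff: earlier outputs win ties.
    if (n_dims > max_dims) {
      result = &tv;
      max_dims = n_dims;
    }
  }
  std::cout << "reference: " << result->name << " (" << max_dims
            << " dims)" << std::endl;  // prints "reference: out1 (3 dims)"
  return 0;
}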
@@ -187,35 +205,11 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
   // Incase any buffer is of type DataType::Index
   DataType index_type = indexModeToDtype(runtime_info.getIndexMode());
 
-  TensorView* largest_out = nullptr;
-  int max_dims = -1;
-
   auto in_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
-  // Will want to access this with direct indexing later, convert now.
-  std::vector<TensorView*> out_tvs;
   // Only use valid reference tensors during heuristics analysis
 
   DomainMap domain_map(fusion);
-  for (auto out_tv : ir_utils::filterByType<TensorView>(fusion->outputs())) {
-    if (domain_map.isValidReference(out_tv)) {
-      out_tvs.push_back(out_tv);
-    }
-  }
-  TORCH_INTERNAL_ASSERT(
-      !out_tvs.empty(), "No valid reference outputs were found!");
-
-  for (auto out_tv : out_tvs) {
-    int n_dims = 0;
-    for (auto id : out_tv->getMaybeRFactorDomain()) {
-      if (id->isReduction() || id->isBroadcast()) {
-        continue;
-      }
-      n_dims++;
-    }
-    if (n_dims > max_dims) {
-      largest_out = out_tv;
-      max_dims = n_dims;
-    }
-  }
+  TensorView* largest_out = domain_map.findReferenceTensorView();
 
   TORCH_INTERNAL_ASSERT(largest_out != nullptr);
 
@@ -224,15 +218,12 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
 
   // TODO: Set to 1?
   int64_t max_input_dtype_size = 2;
-  size_t n_tensors = 0;
 
   for (auto inp : in_tvs) {
     max_input_dtype_size = std::max(
         max_input_dtype_size,
         (int64_t)dataTypeSize(inp->getDataType().value(), index_type));
-    n_tensors++;
   }
-  n_tensors += std::distance(out_tvs.begin(), out_tvs.end());
 
   auto ref_root = largest_out->getMaybeRFactorDomain();
   std::vector<int64_t> elem_counts(ref_root.size(), 1);
@@ -533,7 +524,6 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
     std::cerr << "\n===== Pointwise Stats ========\n"
               << "num_elems: " << n_elems << "\n"
               << "elem_counts: " << elem_counts << "\n"
-              << "n_tensor_inputs: " << n_tensors << "\n"
               << "max_input_dtype_size: " << max_input_dtype_size << "\n"
               << "vectorize_factor: " << vectorize_factor << std::endl;
     std::cerr << "broadcast_byte_multiples: ";
@@ -563,20 +553,6 @@ LaunchParams schedulePointwise(
   return params.value().lparams;
 }
 
-namespace {
-// Returns number of non-reduction/non-broadcast dims in rfactor domain
-size_t nRootDims(const TensorView* tv) {
-  auto root_dom = tv->getMaybeRFactorDomain();
-  size_t tv_n_dims = 0;
-  for (auto dim : root_dom) {
-    if (!dim->isReduction() && !dim->isBroadcast()) {
-      tv_n_dims++;
-    }
-  }
-  return tv_n_dims;
-}
-} // namespace
-
 bool hasReferenceTensorView(Fusion* fusion) {
   return DomainMap::hasReferenceTensorView(fusion);
 }
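
For reference, reassembling the hunks above, the de-duplicated helper and the updated selection logic read as follows after this change. The helper now lives once in the anonymous namespace at the top of the file, and findReferenceTensorView returns the valid output with the most non-reduction/non-broadcast dims rather than the first valid one:

// Returns number of non-reduction/non-broadcast dims in rfactor domain
size_t nRootDims(const TensorView* tv) {
  auto root_dom = tv->getMaybeRFactorDomain();
  size_t tv_n_dims = 0;
  for (auto dim : root_dom) {
    if (!dim->isReduction() && !dim->isBroadcast()) {
      tv_n_dims++;
    }
  }
  return tv_n_dims;
}

TensorView* findReferenceTensorView(int minimum_num_axes = 0) const {
  TensorView* result = nullptr;
  int max_dims = -1;
  for (auto output_tv :
       ir_utils::filterByType<TensorView>(fusion_->outputs())) {
    if (isValidReference(output_tv) &&
        hasMinimumSize(output_tv, minimum_num_axes) &&
        !output_tv->isFusionInput()) {
      int n_dims = nRootDims(output_tv);
      if (n_dims > max_dims) {
        result = output_tv;
        max_dims = n_dims;
      }
    }
  }
  return result;
}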