Skip to content

Commit

Permalink
[Pass] Profiling TVM compiler passes (apache#7500)
Browse files Browse the repository at this point in the history
* basic pass profiler prototype

* allow enable/disable of pass profiling

* lint

* add example pass profiler usage as test

* render pass profiles to String instead of stdout
  • Loading branch information
altanh authored and Trevor Morris committed May 6, 2021
1 parent b90b680 commit cd2518e
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 10 deletions.
13 changes: 3 additions & 10 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,8 @@ class Pass : public ObjectRef {
*
* \return The transformed module.
*/
IRModule operator()(IRModule mod) const {
const PassNode* node = operator->();
ICHECK(node != nullptr);
return node->operator()(std::move(mod));
}
IRModule operator()(IRModule mod) const;

/*!
* \brief Transform mod using a functor under a given pass context.
*
Expand All @@ -362,11 +359,7 @@ class Pass : public ObjectRef {
*
* \return The transformed module.
*/
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const {
const PassNode* node = operator->();
ICHECK(node != nullptr);
return node->operator()(std::move(mod), pass_ctx);
}
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const;

TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
};
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/ir/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,26 @@ def PrintIR(header="", show_meta_data=False):
The pass
"""
return _ffi_transform_api.PrintIR(header, show_meta_data)


def render_pass_profiles():
"""Returns a string render of the pass profiling data. The format of each output line is
`{name}: {time} [{time excluding sub-passes}] ({% of total}; {% of parent})`.
The indentation of each line corresponds to nesting of passes.
"""
return _ffi_transform_api.render_pass_profiles()


def clear_pass_profiles():
"""Clears all stored pass profiling data."""
_ffi_transform_api.clear_pass_profiles()


def enable_pass_profiling():
"""Enables pass profiling."""
_ffi_transform_api.enable_pass_profiling()


def disable_pass_profiling():
"""Disables pass profiling."""
_ffi_transform_api.disable_pass_profiling()
157 changes: 157 additions & 0 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include <chrono>
#include <iomanip>
#include <stack>
#include <unordered_set>

Expand Down Expand Up @@ -169,6 +171,161 @@ void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_be

class ModulePass;

/*! \brief PassProfile stores profiling information for a given pass and its sub-passes. */
struct PassProfile {
// TODO(@altanh): expose PassProfile through TVM Object API
using Clock = std::chrono::steady_clock;
using Duration = std::chrono::duration<double, std::micro>;
using Time = std::chrono::time_point<Clock>;

/*! \brief The name of the pass being profiled. */
String name;
/*! \brief The time when the pass was entered. */
Time start;
/*! \brief The time when the pass completed. */
Time end;
/*! \brief The total duration of the pass, i.e. end - start. */
Duration duration;
/*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */
std::vector<PassProfile> children;

explicit PassProfile(String name)
: name(name), start(Clock::now()), end(Clock::now()), children() {}

/*! \brief Gets the PassProfile of the currently executing pass. */
static PassProfile* Current();
/*! \brief Pushes a new PassProfile with the given pass name. */
static void EnterPass(String name);
/*! \brief Pops the current PassProfile. */
static void ExitPass();
};

struct PassProfileThreadLocalEntry {
/*! \brief The placeholder top-level PassProfile. */
PassProfile root;
/*! \brief The stack of PassProfiles for nested passes currently running. */
std::stack<PassProfile*> profile_stack;
/*! \brief Whether or not pass profiling is active. */
bool active;

PassProfileThreadLocalEntry() : root("root"), active(false) {}
};

/*! \brief Thread local store to hold the pass profiling data. */
typedef dmlc::ThreadLocalStore<PassProfileThreadLocalEntry> PassProfileThreadLocalStore;

void PassProfile::EnterPass(String name) {
if (!PassProfileThreadLocalStore::Get()->active) return;
PassProfile* cur = PassProfile::Current();
cur->children.emplace_back(name);
PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back());
}

void PassProfile::ExitPass() {
if (!PassProfileThreadLocalStore::Get()->active) return;
PassProfile* cur = PassProfile::Current();
ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling";
cur->end = std::move(PassProfile::Clock::now());
cur->duration = std::chrono::duration_cast<PassProfile::Duration>(cur->end - cur->start);
PassProfileThreadLocalStore::Get()->profile_stack.pop();
}

PassProfile* PassProfile::Current() {
PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
if (!entry->profile_stack.empty()) {
return entry->profile_stack.top();
} else {
return &entry->root;
}
}

IRModule Pass::operator()(IRModule mod) const {
const PassNode* node = operator->();
ICHECK(node != nullptr);
PassProfile::EnterPass(node->Info()->name);
auto ret = node->operator()(std::move(mod));
PassProfile::ExitPass();
return std::move(ret);
}

IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const {
const PassNode* node = operator->();
ICHECK(node != nullptr);
PassProfile::EnterPass(node->Info()->name);
auto ret = node->operator()(std::move(mod), pass_ctx);
PassProfile::ExitPass();
return std::move(ret);
}

String RenderPassProfiles() {
PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
CHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!";

if (entry->root.children.empty()) {
LOG(WARNING) << "no passes have been profiled, did you enable pass profiling?";
return String();
}

// (depth, parent_duration, pass)
std::stack<std::tuple<size_t, PassProfile::Duration, PassProfile*>> profiles;

// push top level passes
PassProfile::Duration top_dur(0);
for (auto it = entry->root.children.begin(); it != entry->root.children.end(); ++it) {
top_dur += it->duration;
}
for (auto it = entry->root.children.rbegin(); it != entry->root.children.rend(); ++it) {
profiles.push(std::make_tuple(0, top_dur, &*it));
}

std::ostringstream os;
os << std::fixed;

while (profiles.size() > 0) {
size_t depth;
PassProfile::Duration parent_duration;
PassProfile* profile;
std::tie(depth, parent_duration, profile) = profiles.top();
profiles.pop();

// indent depth
for (size_t i = 0; i < depth; ++i) {
os << "\t";
}

// calculate time spent in pass itself (excluding sub-passes), and push children
PassProfile::Duration self_duration = profile->duration;
for (auto it = profile->children.rbegin(); it != profile->children.rend(); ++it) {
self_duration -= it->duration;
profiles.push(std::make_tuple(depth + 1, profile->duration, &*it));
}

double parent_pct = profile->duration.count() / parent_duration.count() * 100.0;
double total_pct = profile->duration.count() / top_dur.count() * 100.0;

os << profile->name << ": ";
os << std::setprecision(0);
os << profile->duration.count() << "us [" << self_duration.count() << "us] ";
os << std::setprecision(2) << "(" << total_pct << "%; " << parent_pct << "%)\n";
}

return os.str();
}

TVM_REGISTER_GLOBAL("transform.render_pass_profiles").set_body_typed(RenderPassProfiles);

TVM_REGISTER_GLOBAL("transform.clear_pass_profiles").set_body_typed([]() {
PassProfileThreadLocalStore::Get()->root.children.clear();
});

TVM_REGISTER_GLOBAL("transform.enable_pass_profiling").set_body_typed([]() {
PassProfileThreadLocalStore::Get()->active = true;
});

TVM_REGISTER_GLOBAL("transform.disable_pass_profiling").set_body_typed([]() {
PassProfileThreadLocalStore::Get()->active = false;
});

/*!
* \brief Module-level passes are designed to implement global
* analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes
Expand Down
41 changes: 41 additions & 0 deletions tests/python/relay/test_pass_profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.relay
from tvm.relay import op


def test_pass_profiler():
x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"]
e1 = op.add(x, y)
e2 = op.subtract(x, z)
e3 = op.multiply(e1, e1 / e2)
mod = tvm.IRModule.from_expr(e3 + e2)

tvm.transform.enable_pass_profiling()

mod = tvm.relay.transform.AnnotateSpans()(mod)
mod = tvm.relay.transform.ToANormalForm()(mod)
mod = tvm.relay.transform.InferType()(mod)

profiles = tvm.transform.render_pass_profiles()
assert "AnnotateSpans" in profiles
assert "ToANormalForm" in profiles
assert "InferType" in profiles

tvm.transform.clear_pass_profiles()
tvm.transform.disable_pass_profiling()

0 comments on commit cd2518e

Please sign in to comment.