diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index 56905ded5201..6557bbe31b8e 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -349,11 +349,8 @@ class Pass : public ObjectRef {
    *
    * \return The transformed module.
    */
-  IRModule operator()(IRModule mod) const {
-    const PassNode* node = operator->();
-    ICHECK(node != nullptr);
-    return node->operator()(std::move(mod));
-  }
+  IRModule operator()(IRModule mod) const;
+
   /*!
    * \brief Transform mod using a functor under a given pass context.
    *
@@ -362,11 +359,7 @@ class Pass : public ObjectRef {
    *
    * \return The transformed module.
    */
-  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const {
-    const PassNode* node = operator->();
-    ICHECK(node != nullptr);
-    return node->operator()(std::move(mod), pass_ctx);
-  }
+  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const;
 
   TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
 };
diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py
index bb230cad0c9c..36e06eeb8b23 100644
--- a/python/tvm/ir/transform.py
+++ b/python/tvm/ir/transform.py
@@ -330,3 +330,31 @@ def PrintIR(header="", show_meta_data=False):
         The pass
     """
     return _ffi_transform_api.PrintIR(header, show_meta_data)
+
+
+def render_pass_profiles():
+    """Returns a string render of the pass profiling data. The format of each output line is
+    `{name}: {time} [{time excluding sub-passes}] ({% of total}; {% of parent})`.
+    The indentation of each line corresponds to nesting of passes.
+
+    Returns
+    -------
+    profiles : str
+        The rendered profiling data; empty if no passes have been profiled.
+    """
+    return _ffi_transform_api.render_pass_profiles()
+
+
+def clear_pass_profiles():
+    """Clears all stored pass profiling data."""
+    _ffi_transform_api.clear_pass_profiles()
+
+
+def enable_pass_profiling():
+    """Enables pass profiling."""
+    _ffi_transform_api.enable_pass_profiling()
+
+
+def disable_pass_profiling():
+    """Disables pass profiling."""
+    _ffi_transform_api.disable_pass_profiling()
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index f4516d5e57c5..48f13bc81df4 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -28,6 +28,8 @@
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
 
+#include <chrono>
+#include <iomanip>
 #include <stack>
 #include <unordered_set>
 
@@ -169,6 +171,222 @@ void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_be
 
 class ModulePass;
 
+/*! \brief PassProfile stores profiling information for a given pass and its sub-passes. */
+struct PassProfile {
+  // TODO(@altanh): expose PassProfile through TVM Object API
+  using Clock = std::chrono::steady_clock;
+  /*! \brief Fractional microseconds; the renderer prints the raw count with a "us" suffix. */
+  using Duration = std::chrono::duration<double, std::micro>;
+  using Time = std::chrono::time_point<Clock>;
+
+  /*! \brief The name of the pass being profiled. */
+  String name;
+  /*! \brief The time when the pass was entered. */
+  Time start;
+  /*! \brief The time when the pass completed. */
+  Time end;
+  /*! \brief The total duration of the pass, i.e. end - start. */
+  Duration duration;
+  /*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */
+  std::vector<PassProfile> children;
+
+  /*!
+   * \brief Creates a profile whose timer starts now.
+   * \param name The name of the pass being profiled.
+   * \note duration starts at zero rather than being left uninitialized, so a profile that
+   *       has not been stopped yet never exposes an indeterminate value.
+   */
+  explicit PassProfile(String name)
+      : name(std::move(name)), start(Clock::now()), end(start), duration(Duration::zero()) {}
+
+  /*! \brief Stamps the end time and recomputes duration as end - start. */
+  void Stop() {
+    end = Clock::now();
+    duration = std::chrono::duration_cast<Duration>(end - start);
+  }
+
+  /*! \brief Gets the PassProfile of the currently executing pass. */
+  static PassProfile* Current();
+  /*! \brief Pushes a new PassProfile with the given pass name. */
+  static void EnterPass(String name);
+  /*! \brief Pops the current PassProfile. */
+  static void ExitPass();
+};
+
+struct PassProfileThreadLocalEntry {
+  /*! \brief The placeholder top-level PassProfile. */
+  PassProfile root;
+  /*! \brief The stack of PassProfiles for nested passes currently running. */
+  std::stack<PassProfile*> profile_stack;
+  /*! \brief Whether or not pass profiling is active. */
+  bool active;
+
+  PassProfileThreadLocalEntry() : root("root"), active(false) {}
+};
+
+/*! \brief Thread local store to hold the pass profiling data. */
+typedef dmlc::ThreadLocalStore<PassProfileThreadLocalEntry> PassProfileThreadLocalStore;
+
+void PassProfile::EnterPass(String name) {
+  PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
+  if (!entry->active) return;
+  // The pointer pushed below stays valid while the new profile is on the stack: its parent's
+  // children vector is only appended to when the parent itself is the current profile.
+  PassProfile* cur = PassProfile::Current();
+  cur->children.emplace_back(std::move(name));
+  entry->profile_stack.push(&cur->children.back());
+}
+
+void PassProfile::ExitPass() {
+  PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
+  if (!entry->active) return;
+  // Test the stack itself instead of the profile name, so a pass that happens to be called
+  // "root" is not mistaken for the placeholder.
+  ICHECK(!entry->profile_stack.empty()) << "mismatched enter/exit for pass profiling";
+  entry->profile_stack.top()->Stop();
+  entry->profile_stack.pop();
+}
+
+PassProfile* PassProfile::Current() {
+  PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
+  // With no pass running, new profiles attach to the placeholder root.
+  return entry->profile_stack.empty() ? &entry->root : entry->profile_stack.top();
+}
+
+/*!
+ * \brief RAII scope that profiles one pass invocation.
+ *
+ * The scope remembers whether it actually opened a profile, so the profile is closed exactly
+ * once even when the pass throws or when profiling is switched on or off while it runs.
+ */
+class PassProfileScope {
+ public:
+  explicit PassProfileScope(const String& name) {
+    PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
+    if (!entry->active) return;
+    PassProfile::EnterPass(name);
+    profile_ = entry->profile_stack.top();
+  }
+
+  ~PassProfileScope() {
+    if (profile_ == nullptr) return;
+    // Must not throw since it can run during stack unwinding, hence a silent guard instead of
+    // a check macro. Scopes nest strictly, so the stack top is normally this scope's profile.
+    std::stack<PassProfile*>& stack = PassProfileThreadLocalStore::Get()->profile_stack;
+    if (stack.empty() || stack.top() != profile_) return;
+    profile_->Stop();
+    stack.pop();
+  }
+
+  PassProfileScope(const PassProfileScope&) = delete;
+  PassProfileScope& operator=(const PassProfileScope&) = delete;
+
+ private:
+  /*! \brief The profile opened by this scope, or nullptr if profiling was inactive. */
+  PassProfile* profile_{nullptr};
+};
+
+// Defined here rather than inline in the header because it uses the file-local profiler.
+IRModule Pass::operator()(IRModule mod) const {
+  const PassNode* node = operator->();
+  ICHECK(node != nullptr);
+  PassProfileScope profile_scope(node->Info()->name);
+  return node->operator()(std::move(mod));
+}
+
+IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const {
+  const PassNode* node = operator->();
+  ICHECK(node != nullptr);
+  PassProfileScope profile_scope(node->Info()->name);
+  return node->operator()(std::move(mod), pass_ctx);
+}
+
+/*!
+ * \brief Renders the recorded pass profiles as an indented tree.
+ *
+ * Each line has the form
+ * `{name}: {time}us [{time excluding sub-passes}us] ({% of total}%; {% of parent}%)`
+ * and is indented with one tab per nesting level.
+ *
+ * \return The rendered text, or an empty string when nothing has been profiled.
+ */
+String RenderPassProfiles() {
+  PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
+  CHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!";
+
+  if (entry->root.children.empty()) {
+    LOG(WARNING) << "no passes have been profiled, did you enable pass profiling?";
+    return String();
+  }
+
+  // Percentage of `whole` taken by `part`; 0 when `whole` is not positive, so a zero
+  // denominator never turns into NaN or inf in the output.
+  auto percent = [](const PassProfile::Duration& part, const PassProfile::Duration& whole) {
+    return whole.count() > 0 ? part.count() / whole.count() * 100.0 : 0.0;
+  };
+
+  // The total is the sum over top-level passes only; nested time is already part of it.
+  PassProfile::Duration top_dur(0);
+  for (const PassProfile& top : entry->root.children) {
+    top_dur += top.duration;
+  }
+
+  // (depth, parent_duration, pass); siblings are pushed in reverse so that they pop in
+  // execution order, which makes the loop below a pre-order walk of the tree.
+  std::stack<std::tuple<size_t, PassProfile::Duration, PassProfile*>> profiles;
+  for (auto it = entry->root.children.rbegin(); it != entry->root.children.rend(); ++it) {
+    profiles.push(std::make_tuple(0, top_dur, &*it));
+  }
+
+  std::ostringstream os;
+  os << std::fixed;
+
+  while (!profiles.empty()) {
+    size_t depth;
+    PassProfile::Duration parent_duration;
+    PassProfile* profile;
+    std::tie(depth, parent_duration, profile) = profiles.top();
+    profiles.pop();
+
+    // indent depth
+    for (size_t i = 0; i < depth; ++i) {
+      os << '\t';
+    }
+
+    // calculate time spent in pass itself (excluding sub-passes), and push children
+    PassProfile::Duration self_duration = profile->duration;
+    for (auto it = profile->children.rbegin(); it != profile->children.rend(); ++it) {
+      self_duration -= it->duration;
+      profiles.push(std::make_tuple(depth + 1, profile->duration, &*it));
+    }
+
+    os << profile->name << ": ";
+    os << std::setprecision(0);
+    os << profile->duration.count() << "us [" << self_duration.count() << "us] ";
+    os << std::setprecision(2) << "(" << percent(profile->duration, top_dur) << "%; "
+       << percent(profile->duration, parent_duration) << "%)\n";
+  }
+
+  return os.str();
+}
+
+TVM_REGISTER_GLOBAL("transform.render_pass_profiles").set_body_typed(RenderPassProfiles);
+
+TVM_REGISTER_GLOBAL("transform.clear_pass_profiles").set_body_typed([]() {
+  PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
+  // The stack holds raw pointers into the tree, so clearing mid-pass would leave them dangling.
+  CHECK(entry->profile_stack.empty()) << "cannot clear pass profiles while still in a pass!";
+  entry->root.children.clear();
+});
+
+TVM_REGISTER_GLOBAL("transform.enable_pass_profiling").set_body_typed([]() {
+  PassProfileThreadLocalStore::Get()->active = true;
+});
+
+TVM_REGISTER_GLOBAL("transform.disable_pass_profiling").set_body_typed([]() {
+  PassProfileThreadLocalStore::Get()->active = false;
+});
+
 /*!
  * \brief Module-level passes are designed to implement global
  * analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes
diff --git a/tests/python/relay/test_pass_profiler.py b/tests/python/relay/test_pass_profiler.py
new file mode 100644
index 000000000000..acf6c8c50aff
--- /dev/null
+++ b/tests/python/relay/test_pass_profiler.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.relay
+from tvm.relay import op
+
+
+def test_pass_profiler():
+    """Runs three passes with profiling enabled and checks each one is in the render."""
+    x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"]
+    e1 = op.add(x, y)
+    e2 = op.subtract(x, z)
+    e3 = op.multiply(e1, e1 / e2)
+    mod = tvm.IRModule.from_expr(e3 + e2)
+
+    tvm.transform.enable_pass_profiling()
+    try:
+        mod = tvm.relay.transform.AnnotateSpans()(mod)
+        mod = tvm.relay.transform.ToANormalForm()(mod)
+        mod = tvm.relay.transform.InferType()(mod)
+
+        profiles = tvm.transform.render_pass_profiles()
+        assert "AnnotateSpans" in profiles
+        assert "ToANormalForm" in profiles
+        assert "InferType" in profiles
+    finally:
+        # Profiling state outlives this test, so reset it even when an assertion fails.
+        tvm.transform.clear_pass_profiles()
+        tvm.transform.disable_pass_profiling()