From a4e90aae0e0b7a0626a28650f5d192ddc0a509f8 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 23 Jul 2020 15:52:32 -0700 Subject: [PATCH] [FIX] Fixes #6096 Clear the compile cache between module builds so that schedule changes will have an effect. Also, clear the warning cache so that schedule changes properly list untuned ops. --- python/tvm/autotvm/task/relay_integration.py | 5 +++++ src/relay/backend/build_module.cc | 3 +++ src/relay/backend/compile_engine.cc | 2 +- src/relay/backend/compile_engine.h | 2 +- 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 9a43f2f1ad95e..67ebda4b5b7a1 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -24,6 +24,7 @@ import logging import tvm +from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext from .task import create from .topi_integration import TaskExtractEnv @@ -140,6 +141,10 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No build_thread.start() build_thread.join() relay.backend.compile_engine.get().clear() + # Clear the warning message cache in FallbackContext + if isinstance(DispatchContext.current, FallbackContext): + DispatchContext.current.memory = {} + DispatchContext.warning_messages = set() logger.disabled = old_state diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index b57c0eb8cbdb4..93ef56d41b3fa 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -32,6 +32,7 @@ #include #include "../../target/source/codegen_source_base.h" +#include "compile_engine.h" #include "utils.h" namespace tvm { @@ -225,6 +226,8 @@ class RelayBuildModule : public runtime::ModuleNode { targets_ = targets; target_host_ = target_host; BuildRelay(mod, params_); + // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096. + CompileEngine::Global()->Clear(); } protected: diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 2aae8546248fa..3c4faf79147c1 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -770,7 +770,7 @@ class CompileEngineImpl : public CompileEngineNode { }; /*! \brief The global compile engine */ -const CompileEngine& CompileEngine::Global() { +CompileEngine& CompileEngine::Global() { // intentionally allocate raw pointer to avoid // free during destructuion. static CompileEngine* inst = new CompileEngine(make_object()); diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index a5f3f6359f893..e392c79d8cdef 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -238,7 +238,7 @@ class CompileEngine : public ObjectRef { CompileEngineNode* operator->() { return static_cast(get_mutable()); } using ContainerType = CompileEngineNode; /*! \brief The global compile engine. */ - TVM_DLL static const CompileEngine& Global(); + TVM_DLL static CompileEngine& Global(); }; /*!