diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 9a43f2f1ad95e..67ebda4b5b7a1 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -24,6 +24,7 @@ import logging import tvm +from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext from .task import create from .topi_integration import TaskExtractEnv @@ -140,6 +141,10 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No build_thread.start() build_thread.join() relay.backend.compile_engine.get().clear() + # Clear the warning message cache in FallbackContext + if isinstance(DispatchContext.current, FallbackContext): + DispatchContext.current.memory = {} + DispatchContext.warning_messages = set() logger.disabled = old_state diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index b57c0eb8cbdb4..93ef56d41b3fa 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -32,6 +32,7 @@ #include #include "../../target/source/codegen_source_base.h" +#include "compile_engine.h" #include "utils.h" namespace tvm { @@ -225,6 +226,8 @@ class RelayBuildModule : public runtime::ModuleNode { targets_ = targets; target_host_ = target_host; BuildRelay(mod, params_); + // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096. + CompileEngine::Global()->Clear(); } protected: diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 2aae8546248fa..3c4faf79147c1 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -770,7 +770,7 @@ class CompileEngineImpl : public CompileEngineNode { }; /*! \brief The global compile engine */ -const CompileEngine& CompileEngine::Global() { +CompileEngine& CompileEngine::Global() { // intentionally allocate raw pointer to avoid // free during destructuion. static CompileEngine* inst = new CompileEngine(make_object()); diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index a5f3f6359f893..e392c79d8cdef 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -238,7 +238,7 @@ class CompileEngine : public ObjectRef { CompileEngineNode* operator->() { return static_cast(get_mutable()); } using ContainerType = CompileEngineNode; /*! \brief The global compile engine. */ - TVM_DLL static const CompileEngine& Global(); + TVM_DLL static CompileEngine& Global(); }; /*!