Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed all perf-db creations to be lazy #3212

Merged
merged 13 commits into from
Jan 20, 2025
Merged
11 changes: 6 additions & 5 deletions src/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -785,12 +785,13 @@ solver::ConvSolution MakeFusedSolution(const FusionContext& ctx,
const FusionDescription& problem,
const AnyInvokeParams& invoke_params)
{
decltype(auto) db = GetDb(ctx);
auto db_getter = MakeConvDbGetter(ctx);

solver::ConvSolution solution{miopenStatusInternalError};

GetAllFusionSolvers().FindById(id, [&](auto solver) {
solution = miopen::solver::FindSolution(
solver, ctx, problem, db, invoke_params, perf_cfg_override.value_or(""));
solver, ctx, problem, db_getter, invoke_params, perf_cfg_override.value_or(""));
});

return solution;
Expand Down Expand Up @@ -829,7 +830,7 @@ class FusionSolverFinder : public SolversFinderMixin<FusionDescription, FusionFi
const auto fusion_ctx = FusionContext(ctx);
return solvers.SearchForAllSolutions(fusion_ctx,
problem,
miopen::GetDb(ctx),
MakeConvDbGetter(ctx),
invoke_ctx,
std::numeric_limits<std::size_t>::max(),
options);
Expand Down Expand Up @@ -1038,9 +1039,9 @@ miopenStatus_t FusionPlanDescriptor::Compile(const Handle& handle)

GetAllFusionSolvers().FindById(id, [&](auto solver) {
const auto ctx = FusionContext{handle};
auto db = GetDb(ctx);
auto db_getter = MakeConvDbGetter(ctx);
const auto solution = solver::FindSolution(
solver, ctx, fusion_problem, db, {}); // auto tune is not expected here
solver, ctx, fusion_problem, db_getter, {}); // auto tune is not expected here
auto invoker =
handle.PrepareInvoker(*solution.invoker_factory, solution.construction_params);
// We register the invoker below
Expand Down
60 changes: 58 additions & 2 deletions src/include/miopen/any_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
#include <memory>
#include <typeinfo>

#define FIN_ANY_SOLVER_FIND_SOLUTION_COMPAT

namespace miopen {
namespace solver {

Expand Down Expand Up @@ -87,6 +89,8 @@ struct AnySolver
return ptr_value->Type();
};
bool IsEmpty() const { return ptr_value == nullptr; };

#if defined(FIN_ANY_SOLVER_FIND_SOLUTION_COMPAT)
ConvSolution FindSolution(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
PerformanceDb& db,
Expand All @@ -95,7 +99,29 @@ struct AnySolver
{
assert(ptr_value != nullptr);
return ptr_value->FindSolution(ctx, problem, db, invoke_ctx, perf_cfg);
};
}
#endif

ConvSolution FindSolution(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
std::function<PerformanceDb&()>& db_getter,
const miopen::AnyInvokeParams& invoke_ctx,
const std::string& perf_cfg = "") const
{
assert(ptr_value != nullptr);
return ptr_value->FindSolution(ctx, problem, db_getter, invoke_ctx, perf_cfg);
}

ConvSolution FindSolution(const ExecutionContext& ctx,
DrizztDoUrden marked this conversation as resolved.
Show resolved Hide resolved
const miopen::conv::ProblemDescription& problem,
DbGetter& db_getter,
const miopen::AnyInvokeParams& invoke_ctx,
const std::string& perf_cfg = "") const
{
assert(ptr_value != nullptr);
return ptr_value->FindSolution(ctx, problem, db_getter, invoke_ctx, perf_cfg);
}

InvokerFactory GetInvokeFactory(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const std::string& perf_cfg = "") const
Expand Down Expand Up @@ -154,6 +180,16 @@ struct AnySolver
PerformanceDb& db,
const miopen::AnyInvokeParams& invoke_ctx,
const std::string& perf_cfg) const = 0;
virtual ConvSolution FindSolution(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
std::function<PerformanceDb&()>& db_getter,
const miopen::AnyInvokeParams& invoke_ctx,
const std::string& perf_cfg) const = 0;
virtual ConvSolution FindSolution(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
DbGetter& db_getter,
const miopen::AnyInvokeParams& invoke_ctx,
const std::string& perf_cfg) const = 0;
virtual InvokerFactory GetInvokeFactory(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const std::string& perf_cfg) const = 0;
Expand Down Expand Up @@ -308,7 +344,27 @@ struct AnySolver
const std::string& perf_cfg) const override
{
return miopen::solver::FindSolution(value, ctx, problem, db, invoke_ctx, perf_cfg);
};
}

ConvSolution FindSolution(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
std::function<PerformanceDb&()>& db_getter,
const miopen::AnyInvokeParams& invoke_ctx,
const std::string& perf_cfg) const override
{
return miopen::solver::FindSolution(
value, ctx, problem, db_getter, invoke_ctx, perf_cfg);
}

ConvSolution FindSolution(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
DbGetter& db_getter,
const miopen::AnyInvokeParams& invoke_ctx,
const std::string& perf_cfg) const override
{
return miopen::solver::FindSolution(
value, ctx, problem, db_getter, invoke_ctx, perf_cfg);
}

InvokerFactory GetInvokeFactory(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
Expand Down
65 changes: 65 additions & 0 deletions src/include/miopen/conv/db_getter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2024 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/

#pragma once

#if MIOPEN_ENABLE_SQLITE && MIOPEN_USE_SQLITE_PERFDB
#include <miopen/sqlite_db.hpp>
#else
#include <miopen/readonlyramdb.hpp>
#include <miopen/ramdb.hpp>
#endif

#include <functional>
#include <optional>

namespace miopen {
struct ExecutionContext;

#if MIOPEN_ENABLE_SQLITE && MIOPEN_USE_SQLITE_PERFDB
using PerformanceDb = DbTimer<MultiFileDb<SQLitePerfDb, SQLitePerfDb, true>>;
#else
using PerformanceDb = DbTimer<MultiFileDb<ReadonlyRamDb, RamDb, true>>;
#endif

class [[nodiscard]] DbGetter final
{
public:
explicit DbGetter(std::function<PerformanceDb()>&& init_);

DbGetter(const DbGetter&) = delete;
auto operator=(const DbGetter&) -> DbGetter& = delete;

[[nodiscard]] auto operator()() -> PerformanceDb&;

private:
std::function<PerformanceDb()> init;
std::optional<PerformanceDb> db;
};

[[nodiscard]] MIOPEN_INTERNALS_EXPORT auto MakeConvDbGetter(const ExecutionContext& ctx)
-> DbGetter;
} // namespace miopen
11 changes: 1 addition & 10 deletions src/include/miopen/mlo_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,7 @@ POSSIBILITY OF SUCH DAMAGE.
#include <miopen/ocldeviceinfo.hpp>
#endif
#include <miopen/db_path.hpp>
#if MIOPEN_ENABLE_SQLITE && MIOPEN_USE_SQLITE_PERFDB
#include <miopen/sqlite_db.hpp>
#else
#include <miopen/readonlyramdb.hpp>
#endif
#include <miopen/conv/db_getter.hpp>
#include <miopen/execution_context.hpp>
#include <miopen/handle.hpp>
#include <miopen/problem_description.hpp>
Expand Down Expand Up @@ -136,11 +132,6 @@ class DbTimer;

struct AnyInvokeParams;

#if MIOPEN_ENABLE_SQLITE && MIOPEN_USE_SQLITE_PERFDB
using PerformanceDb = DbTimer<MultiFileDb<SQLitePerfDb, SQLitePerfDb, true>>;
#else
using PerformanceDb = DbTimer<MultiFileDb<ReadonlyRamDb, RamDb, true>>;
#endif
MIOPEN_INTERNALS_EXPORT miopen::PerformanceDb GetDb(const miopen::ExecutionContext& ctx);

template <class TTo>
Expand Down
39 changes: 31 additions & 8 deletions src/mlo_dir_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,28 @@ static auto GetBwdWrW2DSolvers()

static auto GetFFTSolvers() { return miopen::solver::SolverContainer<miopen::solver::conv::fft>{}; }

miopen::DbGetter::DbGetter(std::function<PerformanceDb()>&& init_) : init(init_) {}

auto miopen::DbGetter::operator()() -> PerformanceDb&
{
if(!db)
db.emplace(init());

return *db;
}

auto miopen::MakeConvDbGetter(const ExecutionContext& ctx) -> DbGetter
{
return DbGetter{[&]() { return GetDb(ctx); }};
}

std::vector<miopen::solver::ConvSolution>
FindAllGemmSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetGemmSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetGemmSolvers().SearchForAllSolutions(
ctx, problem, miopen::MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<std::pair<std::string, size_t>>
Expand All @@ -221,7 +237,8 @@ FindAllDirectSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetDirectSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetDirectSolvers().SearchForAllSolutions(
ctx, problem, miopen::MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<std::pair<std::string, size_t>>
Expand Down Expand Up @@ -257,23 +274,26 @@ FindAllImplicitGemmSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetImplicitGemmSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetImplicitGemmSolvers().SearchForAllSolutions(
ctx, problem, miopen::MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<miopen::solver::ConvSolution>
FindAllWinogradSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetWindogradSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetWindogradSolvers().SearchForAllSolutions(
ctx, problem, miopen::MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<miopen::solver::ConvSolution>
FindWinogradWrWAllSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetWindogradWrWSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetWindogradWrWSolvers().SearchForAllSolutions(
ctx, problem, miopen::MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<std::pair<std::string, size_t>>
Expand All @@ -295,23 +315,26 @@ FindImplicitGemmWrWAllSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetImplicitGemmWrWSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetImplicitGemmWrWSolvers().SearchForAllSolutions(
ctx, problem, miopen::MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<miopen::solver::ConvSolution>
FindAllBwdWrW2DSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetBwdWrW2DSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetBwdWrW2DSolvers().SearchForAllSolutions(
ctx, problem, miopen::MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<miopen::solver::ConvSolution>
FindAllFFTSolutions(const miopen::ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem,
const miopen::AnyInvokeParams& invoke_ctx)
{
return GetFFTSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx);
return GetFFTSolvers().SearchForAllSolutions(
ctx, problem, miopen::MakeConvDbGetter(ctx), invoke_ctx);
}

std::vector<std::pair<std::string, size_t>>
Expand Down
2 changes: 1 addition & 1 deletion src/ocl/convolutionocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ static Invoker PrepareInvoker(ExecutionContext ctx,
ctx.disable_search_enforce = true;

const auto solver = solver_id.GetSolver();
auto db = GetDb(ctx);
auto db = MakeConvDbGetter(ctx);
auto solution = solver.FindSolution(ctx, problem, db, {}); // auto tune is not expected here
auto& handle = ctx.GetStream();
auto invoker = handle.PrepareInvoker(*solution.invoker_factory, solution.construction_params);
Expand Down
2 changes: 1 addition & 1 deletion src/problem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ std::vector<Solution> Problem::FindSolutionsImpl(const Handle& handle,

auto results =
FindConvolution(ctx, conv_problem, invoke_ctx, max_solutions, options.attach_binaries);
auto db = MakeConvDbGetter(ctx);

for(auto& result : results)
{
Expand All @@ -529,7 +530,6 @@ std::vector<Solution> Problem::FindSolutionsImpl(const Handle& handle,
// This would make binaries not serialized and invoker not cached.
// So we prepare them here.

auto db = GetDb(ctx);
const auto conv_solution =
result.GetSolver().GetSolver().FindSolution(ctx, conv_problem, db, invoke_ctx);

Expand Down
2 changes: 1 addition & 1 deletion src/solution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ void Solution::RunImpl(const Handle& handle,
auto conv_ctx = ExecutionContext{&handle};
conv_problem.SetupFloats(conv_ctx);

decltype(auto) db = GetDb(conv_ctx);
decltype(auto) db = MakeConvDbGetter(conv_ctx);
const auto conv_solution = GetSolver().GetSolver().FindSolution(
conv_ctx, conv_problem, db, invoke_ctx, perf_cfg.value_or(""));

Expand Down
Loading