Skip to content

Commit

Permalink
Fix the buffer pointer order of GenericSearchWrw (#2509)
Browse files Browse the repository at this point in the history
* fixed buf order for igemm wrw
* fixed datatype of RunAndMeasureSolutionBase
  • Loading branch information
zjing14 authored Apr 5, 2020
1 parent 10e3035 commit 481d6b9
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 47 deletions.
12 changes: 6 additions & 6 deletions src/include/miopen/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,8 +837,8 @@ struct ConvHipImplicitGemmV4R4WrWXdlops : SolverBase<ConvolutionContext>
PerformanceImplicitGemmXdlops Search(const ConvolutionContext&) const;
int RunAndMeasureSolution(miopen::Handle& profile_h,
ConstData_t bot_buf,
Data_t top_buf,
ConstData_t wei_buf,
ConstData_t top_buf,
Data_t wei_buf,
ConstData_t bias_buf,
const ConvolutionContext& ctx,
const ConvSolution& solution,
Expand Down Expand Up @@ -922,8 +922,8 @@ struct ConvHipImplicitGemmV4R1WrW : SolverBase<ConvolutionContext>
PerformanceImplicitGemmV4R1 Search(const ConvolutionContext&) const;
int RunAndMeasureSolution(miopen::Handle& profile_h,
ConstData_t bot_buf,
Data_t top_buf,
ConstData_t wei_buf,
ConstData_t top_buf,
Data_t wei_buf,
ConstData_t bias_buf,
const ConvolutionContext& ctx,
const ConvSolution& solution,
Expand All @@ -943,8 +943,8 @@ struct ConvHipImplicitGemmV4WrW : SolverBase<ConvolutionContext>
PerformanceImplicitGemm Search(const ConvolutionContext&) const;
int RunAndMeasureSolution(miopen::Handle& profile_h,
ConstData_t bot_buf,
Data_t top_buf,
ConstData_t wei_buf,
ConstData_t top_buf,
Data_t wei_buf,
ConstData_t bias_buf,
const ConvolutionContext& ctx,
const ConvSolution& solution,
Expand Down
6 changes: 3 additions & 3 deletions src/solver/conv_hip_implicit_gemm_v4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,8 @@ int ConvHipImplicitGemmV4Fwd::RunAndMeasureSolution(miopen::Handle& profile_h,

int ConvHipImplicitGemmV4WrW::RunAndMeasureSolution(miopen::Handle& profile_h,
ConstData_t bot_buf,
Data_t top_buf,
ConstData_t wei_buf,
ConstData_t top_buf,
Data_t wei_buf,
ConstData_t bias_buf,
const ConvolutionContext& ctx,
const ConvSolution& solution,
Expand Down Expand Up @@ -451,7 +451,7 @@ PerformanceImplicitGemm ConvHipImplicitGemmV4Fwd::Search(const ConvolutionContex
}
PerformanceImplicitGemm ConvHipImplicitGemmV4WrW::Search(const ConvolutionContext& context) const
{
return GenericSearchFwd(*this, context);
return GenericSearchWrW(*this, context);
}

PerformanceImplicitGemm ConvHipImplicitGemmV4_1x1::Search(const ConvolutionContext& context) const
Expand Down
6 changes: 3 additions & 3 deletions src/solver/conv_hip_implicit_gemm_v4r1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ int ConvHipImplicitGemmV4R1Fwd::RunAndMeasureSolution(miopen::Handle& profile_h,

int ConvHipImplicitGemmV4R1WrW::RunAndMeasureSolution(miopen::Handle& profile_h,
ConstData_t bot_buf,
Data_t top_buf,
ConstData_t wei_buf,
ConstData_t top_buf,
Data_t wei_buf,
ConstData_t bias_buf,
const ConvolutionContext& ctx,
const ConvSolution& solution,
Expand All @@ -166,7 +166,7 @@ ConvHipImplicitGemmV4R1Fwd::Search(const ConvolutionContext& context) const
PerformanceImplicitGemmV4R1
ConvHipImplicitGemmV4R1WrW::Search(const ConvolutionContext& context) const
{
return GenericSearchFwd(*this, context);
return GenericSearchWrW(*this, context);
}

ConvSolution ConvHipImplicitGemmV4R1Fwd::GetSolution(const ConvolutionContext& ctx,
Expand Down
30 changes: 2 additions & 28 deletions src/solver/conv_hip_implicit_gemm_v4r4_gen_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,35 +341,9 @@ int ConvHipImplicitGemmV4R4GenWrWXdlops::RunAndMeasureSolution(miopen::Handle& p
{
assert(bias_buf == nullptr);
(void)bias_buf;
(void)ctx;

KernelInfo k_info = solution.construction_params[0];

#ifdef NDEBUG
try
#endif
{
elapsed_time = std::numeric_limits<float>::max();
auto kernel = profile_h.AddKernel("",
"",
k_info.kernel_file,
k_info.kernel_name,
k_info.l_wk,
k_info.g_wk,
k_info.comp_options);

kernel(bot_buf, top_buf, wei_buf);

elapsed_time = profile_h.GetKernelTime();
}
#ifdef NDEBUG
catch(miopen::Exception& ex)
{
MIOPEN_LOG_WE(ex.what());
return -1;
}
#endif
return 0;
return RunAndMeasureSolutionBase(
profile_h, bot_buf, top_buf, wei_buf, ctx, solution, elapsed_time);
}

bool ConvHipImplicitGemmV4R4GenFwdXdlops::IsApplicable(const ConvolutionContext& ctx) const
Expand Down
6 changes: 3 additions & 3 deletions src/solver/conv_hip_implicit_gemm_v4r4_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,8 @@ int ConvHipImplicitGemmV4R4Xdlops_1x1::RunAndMeasureSolution(miopen::Handle& pro

int ConvHipImplicitGemmV4R4WrWXdlops::RunAndMeasureSolution(miopen::Handle& profile_h,
ConstData_t bot_buf,
Data_t top_buf,
ConstData_t wei_buf,
ConstData_t top_buf,
Data_t wei_buf,
ConstData_t bias_buf,
const ConvolutionContext& ctx,
const ConvSolution& solution,
Expand Down Expand Up @@ -384,7 +384,7 @@ ConvHipImplicitGemmV4R4FwdXdlops::Search(const ConvolutionContext& ctx) const
PerformanceImplicitGemmXdlops
ConvHipImplicitGemmV4R4WrWXdlops::Search(const ConvolutionContext& ctx) const
{
return GenericSearchFwd(*this, ctx);
return GenericSearchWrW(*this, ctx);
}

} // namespace solver
Expand Down
9 changes: 5 additions & 4 deletions src/solver/implicitgemm_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,10 +517,11 @@ static inline size_t ComputeLDSRequiredSize(const ConvolutionContext& ctx,
return lds_size;
}

template <typename BotBufType, typename TopBufType, typename WeiBufType>
static inline int RunAndMeasureSolutionBase(miopen::Handle& profile_h,
ConstData_t bot_buf,
Data_t top_buf,
ConstData_t wei_buf,
BotBufType bot_buf,
TopBufType top_buf,
WeiBufType wei_buf,
const ConvolutionContext& ctx,
const ConvSolution& solution,
float& elapsed_time)
Expand All @@ -543,7 +544,7 @@ static inline int RunAndMeasureSolutionBase(miopen::Handle& profile_h,

if(ctx.direction.IsBackwardWrW())
{
kernel(bot_buf, top_buf, wei_buf);
kernel(top_buf, bot_buf, wei_buf);
}
if(ctx.direction.IsBackwardData())
{
Expand Down

0 comments on commit 481d6b9

Please sign in to comment.