Skip to content

Commit

Permalink
[Cherry-pick] Simplify conv codes and fix cache and autotune bugs. (#…
Browse files Browse the repository at this point in the history
…47197)

* Simplify the codes of conv. (#45966)

* Enable to record whether the conv algo is got by exhaustive search to fix autotune cache bug. (#47065)
  • Loading branch information
Xreki authored Oct 20, 2022
1 parent 50d4fa5 commit c0ed872
Show file tree
Hide file tree
Showing 11 changed files with 222 additions and 260 deletions.
20 changes: 11 additions & 9 deletions paddle/fluid/operators/conv_base_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,10 @@ using framework::ConvSearchCache;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;

// As the basic for SearchAlgorithm struct.
template <typename PerfT>
struct SearchAlgorithm {};

// As the container of searchAlgorithm::Find() result.
template <typename AlgoT>
struct SearchResult {
SearchResult() {}
explicit SearchResult(const phi::autotune::DnnNode& node)
: algo(static_cast<AlgoT>(node.algo)),
workspace_size(node.workspace_size) {}

explicit SearchResult(AlgoT a) : algo(a) {}
explicit SearchResult(AlgoT a, float t, size_t size)
Expand All @@ -55,12 +48,21 @@ struct SearchResult {
AlgoT algo = static_cast<AlgoT>(0);
float time = -1.f;
size_t workspace_size = 0;
bool exhaustive_search = false;
};

template <typename T>
static std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
out << "[";
for (auto const& tmp : v) out << tmp << ",";
bool is_first = true;
for (auto const& tmp : v) {
if (is_first) {
out << tmp;
is_first = false;
} else {
out << ", " << tmp;
}
}
out << "]";
return out;
}
Expand Down Expand Up @@ -113,7 +115,7 @@ struct ConvArgsBase {
auto w_shape = phi::vectorize(w->dims());
VLOG(10) << "[ConvArgs] x_dims=" << x_shape << ", w_dims=" << w_shape
<< ", strides=" << s << ", paddings=" << p << ", dilations=" << d
<< ",data= " << paddle::experimental::CppTypeToDataType<T>::Type()
<< ", data=" << paddle::experimental::CppTypeToDataType<T>::Type()
<< ", group=" << group
<< ", data layout=" << static_cast<int64_t>(data_layout);

Expand Down
Loading

0 comments on commit c0ed872

Please sign in to comment.