Update CUDNN Frontend API to v0.9.1 (#54949)
* Update CUDNN Frontend API to v0.9.1
- Remove old patches
- Remove workarounds that are no longer needed

* Fix test_switch_autotune
Tom-Zheng authored Jul 14, 2023
1 parent f1bffda commit 76b77d8
Showing 6 changed files with 347 additions and 389 deletions.
18 changes: 8 additions & 10 deletions cmake/external/cudnn-frontend.cmake
@@ -28,24 +28,24 @@ endif()

if((NOT DEFINED CUDNN_FRONTEND_NAME) OR (NOT DEFINED CUDNN_FRONTEND_URL))
set(CUDNN_FRONTEND_VER
"1.23.2"
"v0.9.1"
CACHE STRING "" FORCE)
set(CUDNN_FRONTEND_NAME
"cudnn-frontend"
CACHE STRING "" FORCE)
set(CUDNN_FRONTEND_URL
"https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v0.7.1.tar.gz"
"https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/${CUDNN_FRONTEND_VER}.tar.gz"
CACHE STRING "" FORCE)
set(CUDNN_FRONTEND_CACHE_FILENAME "v0.7.1.tar.gz")
endif()
set(CUDNN_FRONTEND_URL_MD5 "d8f911df571f8b0d40226efa9c0150c8")
set(CUDNN_FRONTEND_CACHE_FILENAME "${CUDNN_FRONTEND_VER}.tar.gz")
set(CUDNN_FRONTEND_URL_MD5 "da7cbad1305427f687dd4fd737178f80")

message(
STATUS
"CUDNN_FRONTEND_NAME: ${CUDNN_FRONTEND_NAME}, CUDNN_FRONTEND_URL: ${CUDNN_FRONTEND_URL}"
)
set(DIRENT_DOWNLOAD_DIR "${PADDLE_SOURCE_DIR}/third_party/cudnn-frontend")
# Version: v0.7.1
set(CUDNN_FRONTEND_DOWNLOAD_DIR
"${PADDLE_SOURCE_DIR}/third_party/cudnn-frontend")
set(CUDNN_FRONTEND_PREFIX_DIR ${THIRD_PARTY_PATH}/cudnn-frontend)
set(CUDNN_FRONTEND_SOURCE_DIR
${THIRD_PARTY_PATH}/cudnn-frontend/src/extern_cudnn_frontend/include)
@@ -55,7 +55,7 @@ include_directories(${CUDNN_FRONTEND_INCLUDE_DIR})

message(
STATUS
"Adding cudnn-frontend. Version: ${CUDNN_FRONTEND_VER}. Directory: ${DIRENT_DOWNLOAD_DIR}"
"Adding cudnn-frontend. Version: ${CUDNN_FRONTEND_VER}. Directory: ${CUDNN_FRONTEND_DOWNLOAD_DIR}"
)

function(download_cudnn_frontend)
@@ -99,9 +99,7 @@ ExternalProject_Add(
DOWNLOAD_DIR ${CUDNN_FRONTEND_DOWNLOAD_DIR}
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
PATCH_COMMAND
patch -d ${CUDNN_FRONTEND_SOURCE_DIR} -p2 <
${PADDLE_SOURCE_DIR}/patches/cudnn-frontend/0001-patch-for-paddle.patch
PATCH_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
114 changes: 81 additions & 33 deletions paddle/phi/kernels/autotune/cache_cudnn_frontend.h
@@ -17,6 +17,7 @@
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

#include "paddle/phi/backends/dynload/cudnn_frontend.h"
@@ -34,7 +34,13 @@ class CudnnFrontendPlanCache {
saturation_count_ = FLAGS_cudnn_cache_saturation_count;
}

int64_t Size() const { return map_.size(); }
int64_t Size() const {
int64_t total_size = 0;
for (auto it = map_.begin(); it != map_.end(); it++) {
total_size += (it->second).size();
}
return total_size;
}

int64_t CacheHits() const { return cache_hits_; }

@@ -58,11 +65,12 @@ class CudnnFrontendPlanCache {
cache_misses_ = 0;
}

bool FindPlan(const cudnn_frontend::OperationGraph& op_graph,
bool use_addto = false) {
bool FindPlan(const cudnn_frontend::feature_vector_t &feature,
cudnnHandle_t handle) {
bool ret = false;
std::lock_guard<std::mutex> lock(*cache_mutex_);
if (map_.count(MakeKey(op_graph, use_addto)) > 0) {
auto &local_map = map_[hasher(std::this_thread::get_id())];
if (local_map.count(GetExtendedFeature(feature, handle)) > 0) {
cache_hits_++;
ret = true;
} else {
@@ -71,58 +79,98 @@ class CudnnFrontendPlanCache {
return ret;
}

cudnn_frontend::ManagedOpaqueDescriptor GetConfig(
const cudnn_frontend::OperationGraph& op_graph,
cudnnHandle_t handle,
bool use_addto = false) {
void GetPlan(const cudnn_frontend::feature_vector_t &feature,
const cudnn_frontend::ExecutionPlan **plan,
int64_t *workspace_size,
cudnnHandle_t handle) {
// Note(tizheng): CUDNNv8 execution plan is not thread-safe.
// A shared plan being executed by different threads is
// generally not safe (for now).
std::lock_guard<std::mutex> lock(*cache_mutex_);
auto engine_config = map_[MakeKey(op_graph, use_addto)];
return engine_config;
auto &local_map = map_[hasher(std::this_thread::get_id())];

auto it = local_map.find(GetExtendedFeature(feature, handle));
if (it == local_map.end()) {
PADDLE_THROW(phi::errors::InvalidArgument(
"[cudnn_frontend] Cached Plan Not Found."));
return;
}
*plan = &(it->second);
*workspace_size = (*plan)->getWorkspaceSize();
VLOG(4) << "Cached execution plan found." << (*plan)->getTag()
<< "; Require workspace: " << *workspace_size;
}

void InsertPlan(const cudnn_frontend::OperationGraph& op_graph,
const cudnn_frontend::ExecutionPlan& plan,
bool use_addto = false) {
VLOG(4) << "[cudnn_frontend] cache: Insert graph tag: "
<< op_graph.getTag();
void InsertPlan(const cudnn_frontend::feature_vector_t &feature,
const cudnn_frontend::ExecutionPlan &plan,
cudnnHandle_t handle) {
VLOG(4) << "[cudnn_frontend] cache: Insert plan: " << plan.getTag();
std::lock_guard<std::mutex> lock(*cache_mutex_);
map_.insert(
std::make_pair(MakeKey(op_graph, use_addto), plan.GetEngineConfig()));
auto &local_map = map_[hasher(std::this_thread::get_id())];
local_map.insert(std::make_pair(GetExtendedFeature(feature, handle), plan));
}

bool IsStable(const cudnn_frontend::OperationGraph& op_graph,
const std::string& tag,
bool use_addto = false) {
bool IsStable(const cudnn_frontend::feature_vector_t &feature,
const std::string &tag,
cudnnHandle_t handle) {
if (saturation_count_ == 1) {
return true;
}
std::lock_guard<std::mutex> lock(*cache_mutex_);
if (map_.count(MakeKey(op_graph, use_addto))) {
auto &local_map = map_[hasher(std::this_thread::get_id())];
auto &local_tracker = tracker_[hasher(std::this_thread::get_id())];
auto ext_feature = GetExtendedFeature(feature, handle);
if (local_map.count(ext_feature)) {
return false;
}
int cnt = tracker_[std::make_pair(MakeKey(op_graph, use_addto), tag)] += 1;
VLOG(4) << "[cudnn_frontend] SaturationTracker: " << op_graph.getTag()
<< " " << tag << " " << cnt;
int cnt = local_tracker[std::make_pair(ext_feature, tag)] += 1;
VLOG(4) << "[cudnn_frontend] SaturationTracker: " << tag << " " << cnt;
return cnt >= saturation_count_;
}

bool FindPlan(const cudnn_frontend::OperationGraph &op_graph,
cudnnHandle_t handle) {
return FindPlan(op_graph.getFeatureVector(), handle);
}

void GetPlan(const cudnn_frontend::OperationGraph &op_graph,
const cudnn_frontend::ExecutionPlan **plan,
int64_t *workspace_size,
cudnnHandle_t handle) {
GetPlan(op_graph.getFeatureVector(), plan, workspace_size, handle);
}

void InsertPlan(const cudnn_frontend::OperationGraph &op_graph,
const cudnn_frontend::ExecutionPlan &plan,
cudnnHandle_t handle) {
InsertPlan(op_graph.getFeatureVector(), plan, handle);
}

bool IsStable(const cudnn_frontend::OperationGraph &op_graph,
const std::string &tag,
cudnnHandle_t handle) {
return IsStable(op_graph.getFeatureVector(), tag, handle);
}

private:
static cudnn_frontend::feature_vector_t MakeKey(
const cudnn_frontend::OperationGraph& op_graph, bool use_addto) {
auto key = op_graph.getFeatureVector();
key.push_back(static_cast<uint64_t>(use_addto));
return key;
cudnn_frontend::feature_vector_t GetExtendedFeature(
cudnn_frontend::feature_vector_t feat, cudnnHandle_t handle) {
int64_t val = 0;
memcpy(&val, &handle, sizeof(int64_t));
feat.push_back(val);
return feat;
}
using FeatureVectorToPlanMap =
std::map<cudnn_frontend::feature_vector_t, cudnn_frontend::ExecutionPlan>;
std::map<std::size_t, FeatureVectorToPlanMap> map_;
std::hash<std::thread::id> hasher;

std::map<cudnn_frontend::feature_vector_t,
cudnn_frontend::ManagedOpaqueDescriptor>
map_;
std::shared_ptr<std::mutex> cache_mutex_;
int saturation_count_;

using SaturationTracker =
std::map<std::pair<cudnn_frontend::feature_vector_t, std::string>, int>;
SaturationTracker tracker_;
std::map<std::size_t, SaturationTracker> tracker_;

int64_t cache_hits_{0};
int64_t cache_misses_{0};
(Diffs for the remaining 4 changed files are not shown here.)
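For readers of this commit, the sketch below shows how a caller might drive the reworked plan cache, using only the methods visible in the diff above (FindPlan, GetPlan, InsertPlan, IsStable and their OperationGraph overloads). It is a minimal illustration, not code from this commit: the plan_cache argument and the BuildPlanByHeuristics/RunPlan helpers are hypothetical placeholders, and namespace qualification is omitted.

// Sketch only (not part of this commit). CudnnFrontendPlanCache and the
// cudnn_frontend types come from the header diffed above; the two helpers
// declared below are hypothetical stand-ins for plan construction and launch.
#include "paddle/phi/kernels/autotune/cache_cudnn_frontend.h"

cudnn_frontend::ExecutionPlan BuildPlanByHeuristics(
    cudnn_frontend::OperationGraph &op_graph, cudnnHandle_t handle);
void RunPlan(cudnnHandle_t handle,
             const cudnn_frontend::ExecutionPlan &plan,
             int64_t workspace_size);

void RunWithPlanCache(CudnnFrontendPlanCache &plan_cache,
                      cudnn_frontend::OperationGraph &op_graph,
                      cudnnHandle_t handle) {
  if (plan_cache.FindPlan(op_graph, handle)) {
    // Cache hit: fetch the plan stored for this thread and this cudnnHandle_t,
    // along with its workspace requirement.
    const cudnn_frontend::ExecutionPlan *plan = nullptr;
    int64_t workspace_size = 0;
    plan_cache.GetPlan(op_graph, &plan, &workspace_size, handle);
    RunPlan(handle, *plan, workspace_size);
    return;
  }
  // Cache miss: build and run a fresh plan, and only insert it once the same
  // graph has been seen FLAGS_cudnn_cache_saturation_count times.
  auto plan = BuildPlanByHeuristics(op_graph, handle);
  RunPlan(handle, plan, plan.getWorkspaceSize());
  if (plan_cache.IsStable(op_graph, plan.getTag(), handle)) {
    plan_cache.InsertPlan(op_graph, plan, handle);
  }
}

Because each cache entry is keyed by the hashed std::thread::id plus the cudnnHandle_t (see GetExtendedFeature and the note in GetPlan), a cached ExecutionPlan is never shared across threads, which is the thread-safety concern this commit works around.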
