Commit
reused workspace gpu memory
humingqing authored and humingqing committed Dec 7, 2023
1 parent dcd67aa commit fa9e8e8
Showing 4 changed files with 82 additions and 53 deletions.
15 changes: 15 additions & 0 deletions paddle/phi/backends/gpu/gpu_context.cc
@@ -722,6 +722,14 @@ struct GPUContext::Impl {
}
}
}
// get workspace ptr
void* GetWorkSpacePtr(const size_t& len) {
if (workspace_ptr_ == nullptr || len > workspace_ptr_->size()) {
workspace_ptr_.reset();
workspace_ptr_ = allocator_->Allocate(len);
}
return workspace_ptr_->ptr();
}

// use one flag for all handles?
// they should be accessed consistently
@@ -786,6 +794,8 @@ struct GPUContext::Impl {
Allocator* allocator_{nullptr}; // external resource.
// An internal resource to initialize eigen_device.
std::unique_ptr<internal::EigenGpuStreamDevice> eigen_stream_{nullptr};
// work space
phi::Allocator::AllocationPtr workspace_ptr_{nullptr};
};

GPUContext::GPUContext(GPUContext&&) = default;
@@ -1006,4 +1016,9 @@ void GPUContext::SetDriverVersion(int val) { impl_->driver_version_ = val; }

void GPUContext::SetRuntimeVersion(int val) { impl_->runtime_version_ = val; }

// Get Work Space
void* GPUContext::GetWorkSpacePtr(const size_t& len) const {
return impl_->GetWorkSpacePtr(len);
}

} // namespace phi
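The new method is a grow-only cache: the context keeps a single allocation alive and only goes back to the allocator when a caller requests more bytes than the cached block holds. The sketch below restates that pattern in isolation, with a made-up Allocation/WorkspaceCache pair standing in for phi::Allocator::AllocationPtr and GPUContext::Impl, and host-side std::malloc standing in for the device allocator; it illustrates the reuse logic only and is not the Paddle implementation.

// ---- illustrative sketch, not part of the commit ----
#include <cstddef>
#include <cstdlib>
#include <memory>

// Hypothetical stand-ins for phi::Allocator::AllocationPtr and allocator_.
struct Allocation {
  void* data = nullptr;
  std::size_t bytes = 0;
  ~Allocation() { std::free(data); }
  void* ptr() const { return data; }
  std::size_t size() const { return bytes; }
};

class WorkspaceCache {
 public:
  // Mirrors GetWorkSpacePtr(): reuse the cached block unless the request grows.
  void* Get(std::size_t len) {
    if (workspace_ == nullptr || len > workspace_->size()) {
      workspace_ = std::make_unique<Allocation>();
      workspace_->data = std::malloc(len);  // device allocation in the real code
      workspace_->bytes = len;
    }
    return workspace_->ptr();
  }

 private:
  std::unique_ptr<Allocation> workspace_;  // lives as long as the owning context
};

The trade-off is that the context holds on to its high-water-mark allocation for its whole lifetime in exchange for skipping an allocator round trip on every call, and since the buffer is shared per context, callers should treat it as scratch that is valid only for the duration of their own call.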
3 changes: 3 additions & 0 deletions paddle/phi/backends/gpu/gpu_context.h
@@ -199,6 +199,9 @@ class PADDLE_API GPUContext : public DeviceContext {
// clear: whether clear the original CUDAStream or not
void SetCUDAStream(CUDAStream*, bool clear = true);

// Get Work Space
void* GetWorkSpacePtr(const size_t& len) const;

protected:
// NOTE: External users manage resources. Used in inference scenarios.
// The Set interface is for inference only, DeviceContext will mark the
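From a kernel's point of view, the new declaration means scratch memory can be borrowed from the context instead of allocated per call. A minimal usage sketch, assuming a hypothetical caller LaunchMyOp that is not part of this commit:

// ---- illustrative sketch, not part of the commit ----
#include <cstddef>

#include "paddle/phi/backends/gpu/gpu_context.h"

void LaunchMyOp(const phi::GPUContext& ctx) {
  // Ask for 4 MiB of scratch; the context returns its cached buffer,
  // growing it only when the cached block is smaller than the request.
  const std::size_t kScratchBytes = static_cast<std::size_t>(4) * 1024 * 1024;
  void* scratch = ctx.GetWorkSpacePtr(kScratchBytes);
  // ... launch device work that uses `scratch` as uninitialized temp storage ...
  (void)scratch;
}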
91 changes: 49 additions & 42 deletions paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h
@@ -23,9 +23,9 @@ limitations under the License. */
#include "paddle/phi/backends/dynload/cublasLt.h"
#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/flags.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/autotune/cache.h"
#include "paddle/phi/kernels/autotune/gpu_timer.h"
#include "paddle/phi/kernels/autotune/switch_autotune.h"
@@ -63,23 +63,23 @@ enum MatmulFusedType {

static cublasLtEpilogue_t ConvertFusedType(MatmulFusedType fused_type) {
static std::map<MatmulFusedType, cublasLtEpilogue_t> fused_type_map = {
{MatmulFusedType::kMatmul, CUBLASLT_EPILOGUE_DEFAULT},
{MatmulFusedType::kMatmulGrad, CUBLASLT_EPILOGUE_DEFAULT},
{MatmulFusedType::kMatmulGradWithoutBias, CUBLASLT_EPILOGUE_DEFAULT},
{MatmulFusedType::kMatmulBias, CUBLASLT_EPILOGUE_BIAS},
{MatmulFusedType::kMatmulRelu, CUBLASLT_EPILOGUE_RELU},
{MatmulFusedType::kMatmulGelu, CUBLASLT_EPILOGUE_GELU},
{MatmulFusedType::kMatmulBiasRelu, CUBLASLT_EPILOGUE_RELU_BIAS},
{MatmulFusedType::kMatmulBiasGelu, CUBLASLT_EPILOGUE_GELU_BIAS},
{MatmulFusedType::kMatmulBiasReluWithReservedData,
CUBLASLT_EPILOGUE_RELU_AUX_BIAS},
{MatmulFusedType::kMatmulBiasGeluWithReservedData,
CUBLASLT_EPILOGUE_GELU_AUX_BIAS},
{MatmulFusedType::kMatmul, CUBLASLT_EPILOGUE_DEFAULT},
{MatmulFusedType::kMatmulGrad, CUBLASLT_EPILOGUE_DEFAULT},
{MatmulFusedType::kMatmulGradWithoutBias, CUBLASLT_EPILOGUE_DEFAULT},
{MatmulFusedType::kMatmulBias, CUBLASLT_EPILOGUE_BIAS},
{MatmulFusedType::kMatmulRelu, CUBLASLT_EPILOGUE_RELU},
{MatmulFusedType::kMatmulGelu, CUBLASLT_EPILOGUE_GELU},
{MatmulFusedType::kMatmulBiasRelu, CUBLASLT_EPILOGUE_RELU_BIAS},
{MatmulFusedType::kMatmulBiasGelu, CUBLASLT_EPILOGUE_GELU_BIAS},
{MatmulFusedType::kMatmulBiasReluWithReservedData,
CUBLASLT_EPILOGUE_RELU_AUX_BIAS},
{MatmulFusedType::kMatmulBiasGeluWithReservedData,
CUBLASLT_EPILOGUE_GELU_AUX_BIAS},
#if CUDA_VERSION >= 11060
{MatmulFusedType::kMatmulReluGrad, CUBLASLT_EPILOGUE_DRELU},
{MatmulFusedType::kMatmulGeluGrad, CUBLASLT_EPILOGUE_DGELU},
{MatmulFusedType::kMatmulBiasGradToA, CUBLASLT_EPILOGUE_BGRADA},
{MatmulFusedType::kMatmulBiasGradToB, CUBLASLT_EPILOGUE_BGRADB}
{MatmulFusedType::kMatmulReluGrad, CUBLASLT_EPILOGUE_DRELU},
{MatmulFusedType::kMatmulGeluGrad, CUBLASLT_EPILOGUE_DGELU},
{MatmulFusedType::kMatmulBiasGradToA, CUBLASLT_EPILOGUE_BGRADA},
{MatmulFusedType::kMatmulBiasGradToB, CUBLASLT_EPILOGUE_BGRADB}
#endif
};

@@ -230,7 +230,8 @@ struct MatmulDescriptor {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescDestroy(op_desc));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(y_desc));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(x_desc));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(out_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cublasLtMatrixLayoutDestroy(out_desc));
delete algo;

op_desc = nullptr;
@@ -483,7 +484,9 @@ struct CublasLtBase {
// Is there a smarter way to choose the workspace size? Currently we just
// follow the setting suggested by our NVIDIA colleague.
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size);
// phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx,
// workspace_size);
void* workspace_ptr = ctx.GetWorkSpacePtr(workspace_size);

if (planner != nullptr) {
if (phi::autotune::AutoTuneStatus::Instance().UseAutoTune() &&
@@ -496,7 +499,7 @@ struct CublasLtBase {
y_ptr,
x_ptr,
out_ptr,
workspace->ptr(),
workspace_ptr,
workspace_size);
MatmulDescT* best_desc = new MatmulDescT(*desc);
VLOG(6) << best_desc->GetDescResultString(
@@ -522,7 +525,7 @@ struct CublasLtBase {
out_ptr,
desc->out_desc,
desc->algo,
workspace->ptr(),
workspace_ptr,
workspace_size,
ctx.stream()));
}
@@ -674,7 +677,9 @@ struct CublasLtBase<int8_t, int32_t, MatmulDescriptor> {
cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle();

size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size);
// phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx,
// workspace_size);
void* workspace_ptr = ctx.GetWorkSpacePtr(workspace_size);

if (planner != nullptr) {
if (phi::autotune::AutoTuneStatus::Instance().UseAutoTune() &&
@@ -687,7 +692,7 @@ struct CublasLtBase<int8_t, int32_t, MatmulDescriptor> {
y_ptr,
x_ptr,
out_ptr,
workspace->ptr(),
workspace_ptr,
workspace_size);
MatmulDescriptor* best_desc = new MatmulDescriptor(*desc);
VLOG(6) << best_desc->GetDescResultString(
@@ -713,7 +718,7 @@ struct CublasLtBase<int8_t, int32_t, MatmulDescriptor> {
out_ptr,
desc->out_desc,
desc->algo,
workspace->ptr(),
workspace_ptr,
workspace_size,
ctx.stream()));
}
@@ -1048,14 +1053,15 @@ struct LinearWithCublasLt : public CublasLtBase<T> {
const bool trans_x,
const bool trans_y,
const MatmulFusedType fused_type) {
auto planner = phi::funcs::MatmulPlanner(vectorize(x->dims()),
vectorize(y->dims()),
trans_x,
trans_y,
paddle::experimental::CppTypeToDataType<T>::Type(),
fused_type,
bias_data,
reserve_data);
auto planner = phi::funcs::MatmulPlanner(
vectorize(x->dims()),
vectorize(y->dims()),
trans_x,
trans_y,
paddle::experimental::CppTypeToDataType<T>::Type(),
fused_type,
bias_data,
reserve_data);
auto setter = DescriptorSetter<MatmulDescriptor, T>(
&planner, M, N, K, trans_x, trans_y);
CublasLtBase<T>::RunImpl(ctx,
@@ -1086,16 +1092,17 @@ struct LinearGradWithCublasLt : public CublasLtBase<T> {
const bool use_addto,
const bool no_exchange, // exchange x_desc and y_desc for grad.
bool grad_for_dx = true) {
auto planner = phi::funcs::MatmulPlanner(vectorize(x->dims()),
vectorize(y->dims()),
trans_x,
trans_y,
paddle::experimental::CppTypeToDataType<T>::Type(),
fused_type,
bias_data,
reserve_data,
use_addto,
no_exchange);
auto planner = phi::funcs::MatmulPlanner(
vectorize(x->dims()),
vectorize(y->dims()),
trans_x,
trans_y,
paddle::experimental::CppTypeToDataType<T>::Type(),
fused_type,
bias_data,
reserve_data,
use_addto,
no_exchange);
auto setter =
DescriptorSetter<MatmulGradDescriptor, T, DXT, DYT, TransX, TransY>(
&planner,
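In the cuBLASLt path the change is mechanical: the fixed 4 MiB workspace that used to be allocated through GetWorkspace() on every RunImpl call is now borrowed from the context, and that raw pointer is what gets handed to cublasLtMatmul. A compressed before/after sketch, assuming stand-in GetWorkSpacePtr and RunMatmul helpers (the real call is cublasLtMatmul with the full descriptor set):

// ---- illustrative sketch, not part of the commit ----
#include <cstddef>
#include <vector>

namespace sketch {

// Stand-in for the context-owned, grow-only workspace (see gpu_context.cc above);
// the real buffer is device memory, this one is host memory for illustration.
void* GetWorkSpacePtr(std::size_t len) {
  static std::vector<char> cache;
  if (cache.size() < len) cache.resize(len);
  return cache.data();
}

// Placeholder for cublasLtMatmul; only the workspace arguments matter here.
void RunMatmul(void* workspace, std::size_t workspace_size) {
  (void)workspace;
  (void)workspace_size;
}

void MatmulWithReusedWorkspace() {
  // Same fixed 4 MiB budget the kernel used before this change.
  std::size_t workspace_size = static_cast<std::size_t>(4) * 1024 * 1024;

  // Before: phi::Allocator::AllocationPtr ws = GetWorkspace(ctx, workspace_size);
  //         cublasLtMatmul(..., ws->ptr(), workspace_size, stream);
  // After:  borrow the context buffer; no device allocation/free per matmul.
  void* workspace_ptr = GetWorkSpacePtr(workspace_size);
  RunMatmul(workspace_ptr, workspace_size);
}

}  // namespace sketch

Since the autotune branch and the normal branch both go through the same workspace pointer, the only behavioral difference is where the 4 MiB lives: previously in a fresh allocation per call, now in the context's cached block.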
26 changes: 15 additions & 11 deletions paddle/phi/kernels/gpu/weight_only_linear_kernel.cu
@@ -65,14 +65,16 @@ we have not supported sm70 weightonly gemv, because sm70 weight layout is RowMajor.
CutlassFpAIntBGemmRunner<typename PDDataTypeTraits<T>::DataType,
uint8_t>();
int mixgemm_max_size = std::max(m, k);
DenseTensor mixgemm_workspace;
int64_t mixgemm_workspace_size_bytes = mixed_gemm_runner.getWorkspaceSize(
m, mixgemm_max_size, mixgemm_max_size);

mixgemm_workspace.Resize({mixgemm_workspace_size_bytes});
dev_ctx.template Alloc<uint8_t>(&mixgemm_workspace);
char* mixgemm_workspace_data =
reinterpret_cast<char*>(mixgemm_workspace.data<uint8_t>());
// DenseTensor mixgemm_workspace;
// mixgemm_workspace.Resize({mixgemm_workspace_size_bytes});
// dev_ctx.template Alloc<uint8_t>(&mixgemm_workspace);
// char* mixgemm_workspace_data =
// reinterpret_cast<char*>(mixgemm_workspace.data<uint8_t>());
char* mixgemm_workspace_data = reinterpret_cast<char*>(
dev_ctx.template GetWorkSpacePtr(mixgemm_workspace_size_bytes));
if (bias_data) {
mixed_gemm_runner.gemm_bias_act(
reinterpret_cast<const typename PDDataTypeTraits<T>::DataType*>(
@@ -108,14 +110,16 @@ we have not supported sm70 weightonly gemv, because sm70 weight layout is RowMajor.
CutlassFpAIntBGemmRunner<typename PDDataTypeTraits<T>::DataType,
cutlass::uint4b_t>();
int mixgemm_max_size = std::max(m, k);
DenseTensor mixgemm_workspace;

int64_t mixgemm_workspace_size_bytes = mixed_gemm_runner.getWorkspaceSize(
m, mixgemm_max_size, mixgemm_max_size);

mixgemm_workspace.Resize({mixgemm_workspace_size_bytes});
dev_ctx.template Alloc<uint8_t>(&mixgemm_workspace);
char* mixgemm_workspace_data =
reinterpret_cast<char*>(mixgemm_workspace.data<uint8_t>());
// DenseTensor mixgemm_workspace;
// mixgemm_workspace.Resize({mixgemm_workspace_size_bytes});
// dev_ctx.template Alloc<uint8_t>(&mixgemm_workspace);
// char* mixgemm_workspace_data =
// reinterpret_cast<char*>(mixgemm_workspace.data<uint8_t>());
char* mixgemm_workspace_data = reinterpret_cast<char*>(
dev_ctx.template GetWorkSpacePtr(mixgemm_workspace_size_bytes));
if (bias_data) {
mixed_gemm_runner.gemm_bias_act(
reinterpret_cast<const typename PDDataTypeTraits<T>::DataType*>(
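The weight-only kernel differs from the cuBLASLt path in that its workspace requirement is not a fixed constant: it is queried from the CUTLASS runner via getWorkspaceSize(m, mixgemm_max_size, mixgemm_max_size), so the byte count can change from call to call. The grow-only cache copes with that by reallocating only when a call needs more than any previous one. A short sketch of the new flow, with QueryWorkspaceBytes and GetWorkSpacePtr as made-up stand-ins for the real runner and context:

// ---- illustrative sketch, not part of the commit ----
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

namespace sketch {

// Stand-in for dev_ctx.GetWorkSpacePtr(): one cached buffer that only grows.
void* GetWorkSpacePtr(std::size_t len) {
  static std::vector<char> cache;
  if (cache.size() < len) cache.resize(len);
  return cache.data();
}

// Mimics mixed_gemm_runner.getWorkspaceSize(m, max, max); the size formula
// here is invented purely for illustration.
std::int64_t QueryWorkspaceBytes(int m, int n, int k) {
  return static_cast<std::int64_t>(m) * 1024 + n + k;
}

void RunWeightOnlyGemm(int m, int k) {
  int mixgemm_max_size = std::max(m, k);
  std::int64_t bytes =
      QueryWorkspaceBytes(m, mixgemm_max_size, mixgemm_max_size);
  // Before: a DenseTensor was Resize()d and Alloc()ed on every call just to
  // hold these scratch bytes. Now the kernel borrows the context workspace.
  char* mixgemm_workspace_data =
      reinterpret_cast<char*>(GetWorkSpacePtr(static_cast<std::size_t>(bytes)));
  (void)mixgemm_workspace_data;  // handed to the CUTLASS gemm runner
}

}  // namespace sketch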
