Skip to content

Commit

Permalink
add version support for external lib
Browse files Browse the repository at this point in the history
  • Loading branch information
zhekunz2 committed Mar 26, 2024
1 parent 78c61c6 commit 10c957b
Show file tree
Hide file tree
Showing 14 changed files with 48 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class Operation;

struct ByreCustomConfig {
std::function<llvm::StringRef(llvm::StringRef)> getCustomLibPath;
std::function<llvm::StringRef(llvm::StringRef)> getCustomLibVersion;
std::function<llvm::StringRef(llvm::StringRef)> getApiName;
std::function<ArrayAttr(mhlo::CustomCallOp)> getExtraArgs;
};
Expand Down
4 changes: 3 additions & 1 deletion compiler/include/byteir/Dialect/Byre/ByreOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,15 @@ def Byre_CustomOp : Byre_Op<"custom",
let description = [{
Example:
```mlir
%2 = byre.custom(%0, %1) { lib_path = "xxx.so", api_name = "add", extra_args = [0 : i64, 1 : i64, 2.0 : f32] } : (f32, f32) -> f32
%2 = byre.custom(%0, %1) { lib_path = "xxx.so", api_name = "add", version = "1.0.0", extra_args = [0 : i64, 1 : i64, 2.0 : f32] } : (f32, f32) -> f32
```
During execution, "xxx.so" will be loaded, and "add" function will be called.
}];

let arguments = (ins
StrAttr:$lib_path,
StrAttr:$api_name,
StrAttr:$version,
Variadic<AnyType>:$operands,
ArrayAttr:$extra_args,
OptionalAttr<ArrayAttr>:$memory_effects
Expand All @@ -205,6 +206,7 @@ def Byre_CustomOp : Byre_Op<"custom",
let builders = [
OpBuilder<(ins "StringRef":$lib_path,
"StringRef":$api_name,
"StringRef":$version,
"ValueRange":$inputs,
"ValueRange":$outputs,
"ArrayAttr":$extra_args)>
Expand Down
11 changes: 10 additions & 1 deletion compiler/lib/Conversion/HloToByreTensor/HloToByreCustom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace {
constexpr StringRef getFlashAttnLibPath() {
return "external_libs/libs/libflash_attn.so";
}
constexpr StringRef getFlashAttnLibVersion() { return "2.5.3"; }
constexpr StringRef getFlashAttnFwdAPI() { return "run_flash_attn_fwd"; }
constexpr StringRef getFlashAttnBwdAPI() { return "run_flash_attn_bwd"; }
constexpr StringRef getFlashAttnKVCacheAPI() {
Expand All @@ -49,6 +50,12 @@ ByreCustomConfig mlir::getCudaByreCustomConfig() {
}
return StringRef("");
};
config.getCustomLibVersion = [=](StringRef callee) {
if (callee == getFlashAttnFwdName() || callee == getFlashAttnBwdName()) {
return getFlashAttnLibVersion();
}
return StringRef("");
};
config.getApiName = [=](StringRef callee) {
if (callee == getFlashAttnFwdName()) {
return getFlashAttnFwdAPI();
Expand Down Expand Up @@ -184,12 +191,14 @@ struct ConvertCustomCallOpToByreCustom : public RewritePattern {
auto libPath = converter.getCustomLibPath(callee);
if (libPath == "")
return failure();
auto version = converter.getCustomLibVersion(callee);
auto apiName = converter.getApiName(callee);
auto extraArgs = converter.getExtraArgs(customCallOp);

auto newOp = rewriter.create<byre::CustomOp>(
customCallOp.getLoc(), customCallOp.getResultTypes(), libPath, apiName,
customCallOp.getOperands(), extraArgs, /*memEffects*/ ArrayAttr{});
version, customCallOp.getOperands(), extraArgs,
/*memEffects*/ ArrayAttr{});
rewriter.replaceOp(op, newOp.getResults());
return success();
}
Expand Down
7 changes: 4 additions & 3 deletions compiler/lib/Dialect/Byre/IR/ByreDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,14 +490,15 @@ Value AliasOp::getViewSource() { return getSource(); }
//===----------------------------------------------------------------------===/

void CustomOp::build(OpBuilder &builder, OperationState &result,
StringRef lib_path, StringRef api_name, ValueRange inputs,
ValueRange outputs, ArrayAttr extra_args) {
StringRef lib_path, StringRef api_name, StringRef version,
ValueRange inputs, ValueRange outputs,
ArrayAttr extra_args) {
SmallVector<Attribute> memoryEffectAttrs;
memoryEffectAttrs.append(
inputs.size(), builder.getAttr<MemoryEffectAttr>(MemoryEffect::Read));
memoryEffectAttrs.append(
outputs.size(), builder.getAttr<MemoryEffectAttr>(MemoryEffect::Write));
build(builder, result, TypeRange{}, lib_path, api_name,
build(builder, result, TypeRange{}, lib_path, api_name, version,
llvm::to_vector(llvm::concat<Value>(llvm::to_vector(inputs),
llvm::to_vector(outputs))),
extra_args, builder.getArrayAttr(memoryEffectAttrs));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ struct ByreCustomOpBufferization

auto newOp = rewriter.create<byre::CustomOp>(
op->getLoc(), cast<byre::CustomOp>(op).getLibPath(),
cast<byre::CustomOp>(op).getApiName(), bufferOperands, bufferResults,
cast<byre::CustomOp>(op).getApiName(),
cast<byre::CustomOp>(op).getVersion(), bufferOperands, bufferResults,
cast<byre::CustomOp>(op).getExtraArgs());

for (auto &&namedAttr : op->getAttrs()) {
Expand Down
4 changes: 2 additions & 2 deletions external_libs/libs/libflash_attn.so
Git LFS file not shown
16 changes: 5 additions & 11 deletions external_libs/runtime/flash_attn/include/flash_api.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once

#include <cuda_runtime.h>
#include "cutlass/numeric_types.h"
#include "cutlass/half.h"
#include "cutlass/numeric_types.h"
#include <cuda_runtime.h>
#include <cutlass/cutlass.h>

void run_mha(void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr,
Expand Down Expand Up @@ -72,17 +72,11 @@ void run_mha_fwd_with_kvcache(

int window_size_left, int window_size_right, cudaStream_t stream);

#ifdef __cplusplus
extern "C" {
#endif
void run_flash_attn_fwd(void **tensors, void *extra_args,
cudaStream_t stream);
void run_flash_attn_fwd(void **tensors, void *extra_args, cudaStream_t stream);

void run_flash_attn_bwd(void **tensors, void *extra_args,
cudaStream_t stream);
void run_flash_attn_bwd(void **tensors, void *extra_args, cudaStream_t stream);

void run_flash_attn_kvcache(void **tensors, void *extra_args,
cudaStream_t stream);
#ifdef __cplusplus
cudaStream_t stream);
}
#endif
4 changes: 4 additions & 0 deletions external_libs/runtime/flash_attn/lib/flash_api.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
#include <iostream>
#include <algorithm>

extern "C" {
const char *VERSION = "2.5.3";
}

// for debug
void print_Qkv_params(Qkv_params &params) {
std::cout << "q_batch_stride: " << params.q_batch_stride << std::endl;
Expand Down
10 changes: 10 additions & 0 deletions runtime/lib/backends/cuda/providers/default/custom/custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "brt/core/framework/op_accessor.h"
#include "brt/core/ir/util.h"
#include "byteir/Dialect/Byre/ByreDialect.h"
#include <cstring>
#include <dlfcn.h>
#include <filesystem>
#include <vector>
Expand All @@ -38,13 +39,18 @@ CustomOpKernel::CustomOpKernel(const OpKernelInfo &info) : OpKernel(info) {
OpAccessor accessor(info_);
std::string lib_path = accessor.GetAttrAsString("lib_path");
std::string api_name = accessor.GetAttrAsString("api_name");
std::string version = accessor.GetAttrAsString("version");
custom_lib_hdl = dlopen(lib_path.c_str(), RTLD_LAZY | RTLD_GLOBAL);
std::string msg = std::string("Custom lib ") + lib_path + " load failed";
BRT_ENFORCE(custom_lib_hdl != nullptr, msg);
run_func_ = reinterpret_cast<decltype(run_func_)>(
dlsym(custom_lib_hdl, api_name.c_str()));
std::string api_msg = std::string("Couldn't find function: ") + api_name;
BRT_ENFORCE(run_func_ != NULL, api_msg);
void *lib_version = dlsym(custom_lib_hdl, "VERSION");
BRT_ENFORCE(lib_version != NULL, "Version doesn't exist in custom library!");
BRT_ENFORCE(strcmp((char *)lib_version, version.c_str()),
"Version doesn't match!");
}

int64_t getIntFromVoidPtr(void *data, size_t &pos) {
Expand Down Expand Up @@ -79,5 +85,9 @@ common::Status CustomOpKernel::RunImpl(const ExecutionContext &ctx) {
return common::Status::OK();
}

common::Status CustomOpKernel::EpiloguePerFrame(const ExecutionContext &ctx) {
dlclose(custom_lib_hdl);
return common::Status::OK();
}
} // namespace cuda
} // namespace brt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class CustomOpKernel final : public OpKernel {
typedef void (*CustomLibApiRun)(void **, void *, cudaStream_t);
explicit CustomOpKernel(const OpKernelInfo &info);
common::Status RunImpl(const ExecutionContext &) override;
common::Status EpiloguePerFrame(const ExecutionContext &) override;

private:
void *custom_lib_hdl;
Expand Down
4 changes: 2 additions & 2 deletions runtime/test/test_files/external_libs/libflash_attn.so
Git LFS file not shown
4 changes: 2 additions & 2 deletions runtime/test/test_files/flash_attn_bwd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ module attributes {byre.container_module} {
%arg10 : memref<1x3x128xf32, "cuda"> {byre.argname = "d_SoftmaxLse", byre.argtype = 2: i32},
%arg11 : memref<1x3x128x32xf32, "cuda"> {byre.argname = "d_Q_accum", byre.argtype = 2: i32},
%arg12 : memref<1x3x128x128xf32, "cuda"> {byre.argname = "SoftmaxPtr", byre.argtype = 2: i32}) attributes {byre.entry_point} {
"byre.custom"(%arg1, %arg2, %arg3, %arg9, %arg4, %arg5, %arg12) {callee = "custom", lib_path = "test/test_files/external_libs/libflash_attn.so", api_name = "run_flash_attn_fwd", extra_args = [12288 : i64, 12288 : i64, 12288 : i64, 12288 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 32 : i64, 32 : i64, 32 : i64, 32 : i64, 1 : i64, 3 : i64, 3 : i64, 32 : i64, 32 : i64, 0.5 : f32, 128 : i64, 128 : i64, 128 : i64, 128 : i64, 0.0 : f32, -1 : i64, 0 : i64]} : (memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<2xi64, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x3x128xf32, "cuda">, memref<1x3x128x128xf32, "cuda">) -> ()
"byre.custom"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg9, %arg6, %arg7, %arg8, %arg10, %arg11) {callee = "custom", lib_path = "test/test_files/external_libs/libflash_attn.so", api_name = "run_flash_attn_bwd", extra_args = [12288 : i64, 12288 : i64, 12288 : i64, 12288 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 32 : i64, 32 : i64, 32 : i64, 32 : i64, 1 : i64, 3 : i64, 3 : i64, 32 : i64, 32 : i64, 0.5 : f32, 128 : i64, 128 : i64, 128 : i64, 128 : i64, 0.0 : f32, -1 : i64, 0 : i64]} : (memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x3x128xf32, "cuda">, memref<2xi64, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x3x128xf32, "cuda">, memref<1x3x128x32xf32, "cuda">) -> ()
"byre.custom"(%arg1, %arg2, %arg3, %arg9, %arg4, %arg5, %arg12) {callee = "custom", lib_path = "test/test_files/external_libs/libflash_attn.so", api_name = "run_flash_attn_fwd", version = "2.5.3", extra_args = [12288 : i64, 12288 : i64, 12288 : i64, 12288 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 32 : i64, 32 : i64, 32 : i64, 32 : i64, 1 : i64, 3 : i64, 3 : i64, 32 : i64, 32 : i64, 0.5 : f32, 128 : i64, 128 : i64, 128 : i64, 128 : i64, 0.0 : f32, -1 : i64, 0 : i64]} : (memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<2xi64, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x3x128xf32, "cuda">, memref<1x3x128x128xf32, "cuda">) -> ()
"byre.custom"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg9, %arg6, %arg7, %arg8, %arg10, %arg11) {callee = "custom", lib_path = "test/test_files/external_libs/libflash_attn.so", api_name = "run_flash_attn_bwd", version = "2.5.3", extra_args = [12288 : i64, 12288 : i64, 12288 : i64, 12288 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 32 : i64, 32 : i64, 32 : i64, 32 : i64, 1 : i64, 3 : i64, 3 : i64, 32 : i64, 32 : i64, 0.5 : f32, 128 : i64, 128 : i64, 128 : i64, 128 : i64, 0.0 : f32, -1 : i64, 0 : i64]} : (memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x3x128xf32, "cuda">, memref<2xi64, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x3x128xf32, "cuda">, memref<1x3x128x32xf32, "cuda">) -> ()
return
}
}
2 changes: 1 addition & 1 deletion runtime/test/test_files/flash_attn_fwd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module attributes {byre.container_module} {
%arg4 : memref<1x3x128xf32, "cuda"> {byre.argname = "SoftmaxLse", byre.argtype = 2: i32},
%arg5 : memref<1x3x128x128xf32, "cuda"> {byre.argname = "SoftmaxPtr", byre.argtype = 2: i32},
%arg6 : memref<2xi64, "cuda"> {byre.argname = "RngState", byre.argtype = 2: i32}) attributes {byre.entry_point} {
"byre.custom"(%arg0, %arg1, %arg2, %arg6, %arg3, %arg4, %arg5) {callee = "custom", lib_path = "test/test_files/external_libs/libflash_attn.so", api_name = "run_flash_attn_fwd", extra_args = [12288 : i64, 12288 : i64, 12288 : i64, 12288 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 32 : i64, 32 : i64, 32 : i64, 32 : i64, 1 : i64, 3 : i64, 3 : i64, 32 : i64, 32 : i64, 0.5 : f32, 128 : i64, 128 : i64, 128 : i64, 128 : i64, 0.0 : f32, -1 : i64, 0 : i64]} : (memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<2xi64, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x3x128xf32, "cuda">, memref<1x3x128x128xf32, "cuda">) -> ()
"byre.custom"(%arg0, %arg1, %arg2, %arg6, %arg3, %arg4, %arg5) {callee = "custom", lib_path = "test/test_files/external_libs/libflash_attn.so", api_name = "run_flash_attn_fwd", version = "2.5.3", extra_args = [12288 : i64, 12288 : i64, 12288 : i64, 12288 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 32 : i64, 32 : i64, 32 : i64, 32 : i64, 1 : i64, 3 : i64, 3 : i64, 32 : i64, 32 : i64, 0.5 : f32, 128 : i64, 128 : i64, 128 : i64, 128 : i64, 0.0 : f32, -1 : i64, 0 : i64]} : (memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<2xi64, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x3x128xf32, "cuda">, memref<1x3x128x128xf32, "cuda">) -> ()
return
}
}
2 changes: 1 addition & 1 deletion runtime/test/test_files/flash_attn_kvcache.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module attributes {byre.container_module} {
%arg5: memref<2xi32, "cuda"> {byre.argname = "SeqLenK", byre.argtype = 2: i32},
%arg6: memref<2x1x3x32xf16, "cuda"> {byre.argname = "Output", byre.argtype = 2: i32},
%arg7 : memref<2x3x1xf32, "cuda"> {byre.argname = "SoftmaxLse", byre.argtype = 2: i32}) attributes {byre.entry_point} {
"byre.custom"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) {callee = "custom", lib_path = "test/test_files/external_libs/libflash_attn.so", api_name = "run_flash_attn_kvcache", extra_args = [96 : i64, 12288 : i64, 12288 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 32 : i64, 32 : i64, 32 : i64, 32 : i64, 32 : i64, 32 : i64, 2 : i64, 3 : i64, 3 : i64, 32 : i64, 32 : i64, 1 : i64, 0.5 : f32, 1 : i64, 128 : i64, 128 : i64, 128 : i64, -1 : i64, -1 : i64]} : (memref<2x1x3x32xf16, "cuda">, memref<2x128x3x32xf16, "cuda">, memref<2x128x3x32xf16, "cuda">, memref<2x1x3x32xf16, "cuda">, memref<2x1x3x32xf16, "cuda">, memref<2xi32, "cuda">, memref<2x1x3x32xf16, "cuda">, memref<2x3x1xf32, "cuda">) -> ()
"byre.custom"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) {callee = "custom", lib_path = "test/test_files/external_libs/libflash_attn.so", api_name = "run_flash_attn_kvcache", version = "2.5.3", extra_args = [96 : i64, 12288 : i64, 12288 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 96 : i64, 32 : i64, 32 : i64, 32 : i64, 32 : i64, 32 : i64, 32 : i64, 2 : i64, 3 : i64, 3 : i64, 32 : i64, 32 : i64, 1 : i64, 0.5 : f32, 1 : i64, 128 : i64, 128 : i64, 128 : i64, -1 : i64, -1 : i64]} : (memref<2x1x3x32xf16, "cuda">, memref<2x128x3x32xf16, "cuda">, memref<2x128x3x32xf16, "cuda">, memref<2x1x3x32xf16, "cuda">, memref<2x1x3x32xf16, "cuda">, memref<2xi32, "cuda">, memref<2x1x3x32xf16, "cuda">, memref<2x3x1xf32, "cuda">) -> ()
return
}
}

0 comments on commit 10c957b

Please sign in to comment.