diff --git a/csrc/include/aiter_hip_common.h b/csrc/include/aiter_hip_common.h
index 846f375326..7c40ea6cbd 100644
--- a/csrc/include/aiter_hip_common.h
+++ b/csrc/include/aiter_hip_common.h
@@ -5,6 +5,10 @@
 #include
 #include
 #include
+#include <stdexcept>
+#ifdef AITER_EMBEDDED_HSA_HEADER
+#include AITER_EMBEDDED_HSA_HEADER
+#endif
 
 enum class GPUArch
 {
@@ -67,6 +71,56 @@ struct AiterAsmKernelArgs
 };
 
 static const std::string get_gpu_arch();
+
+/// Loads an ASM kernel code object into `module` and resolves `name` from it.
+///
+/// If the AITER_ASM_DIR environment variable is set, the code object is loaded
+/// from the file "<AITER_ASM_DIR>/<arch>/<hsaco>", where <arch> is the string
+/// returned by get_gpu_arch(). Otherwise it is looked up in
+/// AITER_EMBEDDED_HSA_MAP under the key "hsa/<arch>/<hsaco>"; that path is only
+/// compiled in when both AITER_EMBEDDED_HSA_HEADER and AITER_EMBEDDED_HSA_MAP
+/// are defined.
+///
+/// @param name        kernel symbol passed to hipModuleGetFunction
+/// @param hsaco       code object name, relative to the arch directory
+/// @param module      receives the loaded module handle
+/// @param kernel_func receives the resolved kernel function handle
+/// @throws std::runtime_error if AITER_ASM_DIR is unset and the embedded HSA
+///         map is not compiled in, because there is nothing to load.
+inline void load_asm_kernel(const char* name,
+                            const char* hsaco,
+                            hipModule_t& module,
+                            hipFunction_t& kernel_func)
+{
+    const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR");
+    std::string arch_name = get_gpu_arch();
+    if(AITER_ASM_DIR != nullptr)
+    {
+        std::string hsa_path = std::string(AITER_ASM_DIR) + "/" + arch_name + "/" + hsaco;
+        std::cout << "[aiter] hipModuleLoad: " << hsa_path << " GetFunction: " << name;
+        HIP_CALL(hipModuleLoad(&module, hsa_path.c_str()));
+    }
+    else
+    {
+#if defined(AITER_EMBEDDED_HSA_HEADER) && defined(AITER_EMBEDDED_HSA_MAP)
+        std::string fname = "hsa/" + arch_name + "/" + hsaco;
+        auto hsaco_obj = AITER_EMBEDDED_HSA_MAP.find(fname);
+        CHECK_COND(hsaco_obj != AITER_EMBEDDED_HSA_MAP.end());
+        CHECK_COND(hsaco_obj->second.data() != nullptr);
+        // Keep the log on one line; " Success" is appended after GetFunction.
+        std::cout << "[aiter] hipModuleLoad: " << fname << " GetFunction: " << name;
+        HIP_CALL(hipModuleLoadData(&module, hsaco_obj->second.data()));
+#else
+        // Neither source is available: fail here rather than calling
+        // hipModuleGetFunction on a module that was never loaded.
+        throw std::runtime_error(
+            "[aiter] AITER_ASM_DIR is not set and no embedded HSA map is available");
+#endif
+    }
+    HIP_CALL(hipModuleGetFunction(&kernel_func, module, name));
+    std::cout << " Success" << std::endl;
+}
+
 class AiterAsmKernel
 {
     private:
@@ -76,14 +130,7 @@ class AiterAsmKernel
     public:
     AiterAsmKernel(const char* name, const char* hsaco)
     {
-        const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR");
-        std::string arch_name = get_gpu_arch();
-        std::string hsa_path = std::string(AITER_ASM_DIR) + "/" + arch_name + "/" + hsaco;
-        std::cout << "[aiter] hipModuleLoad: " << hsa_path
-                  << " GetFunction: " << name;
-        HIP_CALL(hipModuleLoad(&module, hsa_path.c_str()));
-        HIP_CALL(hipModuleGetFunction(&kernel_func, module, name));
-        std::cout << " Success" << std::endl;
+        load_asm_kernel(name, hsaco, module, kernel_func);
     };
 
     ~AiterAsmKernel() { HIP_CALL(hipModuleUnload(module)); }
diff --git a/csrc/py_itfs_cu/asm_fmoe.cu b/csrc/py_itfs_cu/asm_fmoe.cu
index e025bf9fa5..a983f7ca48 100755
--- a/csrc/py_itfs_cu/asm_fmoe.cu
+++ b/csrc/py_itfs_cu/asm_fmoe.cu
@@ -87,13 +87,7 @@ class FMoeKernel
                uint32_t sub_GU = 512,
                uint32_t num_persistent_tgs = 0)
     {
-        const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR");
-        std::string arch_name = get_gpu_arch();
-        std::string hsa_path = std::string(AITER_ASM_DIR) + "/" + arch_name + "/" + hsaco;
-        std::cout << "[aiter] hipModuleLoad: " << hsa_path.c_str() << " GetFunction: " << name;
-        HIP_CALL(hipModuleLoad(&module, hsa_path.c_str()));
-        HIP_CALL(hipModuleGetFunction(&kernel_func, module, name));
-        std::cout << " Success" << std::endl;
+        load_asm_kernel(name, hsaco, module, kernel_func);
         this->sub_GU = sub_GU;
         this->num_persistent_tgs = num_persistent_tgs;
         this->name = name;