Commit 5a3059d

Merge branch 'layla-build' into dry-sampler

l3utterfly authored Apr 25, 2024
2 parents 7e08885 + a69169f
Showing 28 changed files with 2,672 additions and 294 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -39,6 +39,7 @@ cmake-build-*
out/
tmp/

loras/*
models/*
models-mnt

100 changes: 22 additions & 78 deletions CMakeLists.txt
@@ -295,6 +295,8 @@ if (LLAMA_METAL)
endif()

if (LLAMA_BLAS)
message(STATUS "Building with OpenBLAS")

if (LLAMA_STATIC)
set(BLA_STATIC ON)
endif()
@@ -303,77 +305,14 @@ if (LLAMA_BLAS)
endif()

set(BLA_VENDOR ${LLAMA_BLAS_VENDOR})
find_package(BLAS)

if (BLAS_FOUND)
message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}")

if ("${BLAS_INCLUDE_DIRS}" STREQUAL "")
# BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake.
# see https://gitlab.kitware.com/cmake/cmake/-/issues/20268
find_package(PkgConfig REQUIRED)
if (${LLAMA_BLAS_VENDOR} MATCHES "Generic")
pkg_check_modules(DepBLAS REQUIRED blas)
elseif (${LLAMA_BLAS_VENDOR} MATCHES "OpenBLAS")
# As of openblas v0.3.22, the 64-bit is named openblas64.pc
pkg_check_modules(DepBLAS openblas64)
if (NOT DepBLAS_FOUND)
pkg_check_modules(DepBLAS REQUIRED openblas)
endif()
elseif (${LLAMA_BLAS_VENDOR} MATCHES "FLAME")
pkg_check_modules(DepBLAS REQUIRED blis)
elseif (${LLAMA_BLAS_VENDOR} MATCHES "ATLAS")
pkg_check_modules(DepBLAS REQUIRED blas-atlas)
elseif (${LLAMA_BLAS_VENDOR} MATCHES "FlexiBLAS")
pkg_check_modules(DepBLAS REQUIRED flexiblas_api)
elseif (${LLAMA_BLAS_VENDOR} MATCHES "Intel")
# all Intel* libraries share the same include path
pkg_check_modules(DepBLAS REQUIRED mkl-sdl)
elseif (${LLAMA_BLAS_VENDOR} MATCHES "NVHPC")
# this doesn't provide pkg-config
# suggest to assign BLAS_INCLUDE_DIRS on your own
if ("${NVHPC_VERSION}" STREQUAL "")
message(WARNING "Better to set NVHPC_VERSION")
else()
set(DepBLAS_FOUND ON)
set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include")
endif()
endif()
if (DepBLAS_FOUND)
set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS})
else()
message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically"
" detected by pkgconfig, trying to find cblas.h from possible paths...")
find_path(BLAS_INCLUDE_DIRS
NAMES cblas.h
HINTS
/usr/include
/usr/local/include
/usr/include/openblas
/opt/homebrew/opt/openblas/include
/usr/local/opt/openblas/include
/usr/include/x86_64-linux-gnu/openblas/include
)
endif()
endif()
add_compile_options(${BLAS_LINKER_FLAGS})

message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}")
add_compile_definitions(GGML_USE_OPENBLAS)

add_compile_options(${BLAS_LINKER_FLAGS})
add_subdirectory(../OpenBLAS ${CMAKE_CURRENT_BINARY_DIR}/OpenBLAS)

add_compile_definitions(GGML_USE_OPENBLAS)

if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${LLAMA_BLAS_VENDOR} MATCHES "Generic" OR ${LLAMA_BLAS_VENDOR} MATCHES "Intel"))
add_compile_definitions(GGML_BLAS_USE_MKL)
endif()

set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES})
set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS})
else()
message(WARNING "BLAS not found, please refer to "
"https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
" to set correct LLAMA_BLAS_VENDOR")
endif()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas_shared)
set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ../OpenBLAS ${CMAKE_CURRENT_BINARY_DIR}/OpenBLAS)
endif()

if (LLAMA_LLAMAFILE)
@@ -489,19 +428,24 @@ if (LLAMA_MPI)
endif()

if (LLAMA_CLBLAST)
find_package(CLBlast)
if (CLBlast_FOUND)
message(STATUS "CLBlast found")
message(STATUS "Building with CLBlast")

set(GGML_HEADERS_OPENCL ggml-opencl.h)
set(GGML_SOURCES_OPENCL ggml-opencl.cpp)
set(GGML_HEADERS_OPENCL ggml-opencl.h)
set(GGML_SOURCES_OPENCL ggml-opencl.cpp)

add_compile_definitions(GGML_USE_CLBLAST)
add_compile_definitions(GGML_USE_CLBLAST)

set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast)
else()
message(WARNING "CLBlast not found")
endif()
# link our libOpenCL.so (this is only used during compile time)
add_library(OpenCL SHARED IMPORTED)
set_target_properties(OpenCL PROPERTIES IMPORTED_LOCATION ${PROJECT_SOURCE_DIR}/../OpenCL/lib/libOpenCL.so)

# add our prebuilt clblast library
add_library(clblast SHARED IMPORTED)
set_target_properties(clblast PROPERTIES IMPORTED_LOCATION ${PROJECT_SOURCE_DIR}/../../android/app/src/main/jniLibs/${ANDROID_ABI}/libclblast.so)

set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast OpenCL)
set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ../CLBlast/include)
set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ../OpenCL/include)
endif()

if (LLAMA_VULKAN)
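Note (editor's summary, not part of the diff): relative to upstream llama.cpp, this fork's CMakeLists.txt no longer probes the system with find_package(BLAS) or find_package(CLBlast). The BLAS path builds a sibling ../OpenBLAS checkout via add_subdirectory and links its openblas_shared target, while the CLBlast path declares IMPORTED shared libraries pointing at a prebuilt libOpenCL.so under ../OpenCL and the per-ABI libclblast.so shipped in the Android app's jniLibs tree; the matching headers are added through LLAMA_EXTRA_INCLUDES.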
6 changes: 6 additions & 0 deletions common/common.cpp
@@ -924,6 +924,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.cont_batching = true;
return true;
}
if (arg == "-fa" || arg == "--flash-attn") {
params.flash_attn = true;
return true;
}
if (arg == "--color") {
params.use_color = true;
return true;
@@ -1864,6 +1868,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;

cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -2701,6 +2706,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);

const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
1 change: 1 addition & 0 deletions common/common.h
@@ -148,6 +148,7 @@ struct gpt_params {
bool multiline_input = false; // reverse the usage of `\`
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention

bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool ignore_eos = false; // ignore generated EOS tokens
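A minimal usage sketch (not part of the diff) of the new flash-attention option. It only shows how params.flash_attn flows into the context parameters; the helper name make_ctx_with_flash_attn is hypothetical, the rest uses the common-library API visible in the hunks above.

    // Sketch only: how the new "-fa" / "--flash-attn" flag reaches the llama context.
    #include "common.h"
    #include "llama.h"

    static llama_context * make_ctx_with_flash_attn(llama_model * model) {
        gpt_params params;                 // flash_attn defaults to false (see common.h above)
        params.flash_attn = true;          // what gpt_params_find_arg() sets for "-fa" / "--flash-attn"

        // llama_context_params_from_gpt_params() copies the flag into cparams.flash_attn (see common.cpp above)
        llama_context_params cparams = llama_context_params_from_gpt_params(params);
        return llama_new_context_with_model(model, cparams);
    }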
22 changes: 18 additions & 4 deletions common/sampling.cpp
@@ -44,10 +44,10 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
delete ctx;
}

void llama_sampling_reset(llama_sampling_context * ctx) {
void llama_sampling_reset_grammar(struct llama_sampling_context * ctx) {
if (ctx->grammar != NULL) {
llama_grammar_free(ctx->grammar);
ctx->grammar = NULL;
ctx->grammar = nullptr;
}

if (!ctx->parsed_grammar.rules.empty()) {
@@ -57,6 +57,10 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
grammar_rules.data(),
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
}
}

void llama_sampling_reset(llama_sampling_context * ctx) {
llama_sampling_reset_grammar(ctx);

std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
@@ -310,13 +314,12 @@ static llama_token_data_array llama_sampling_prepare_impl(

// DRY penalties (multiplier > 0 means enabled)
if(dry_multiplier > 0.0f) {
llama_sample_dry(&cur_p,
llama_sample_dry(&cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
}


if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
@@ -366,3 +369,14 @@ void llama_sampling_accept(
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
}
}


void llama_sampling_rollback(
struct llama_sampling_context * ctx_sampling,
int rollback_num) {
if(rollback_num > ctx_sampling->prev.size()) {
rollback_num = ctx_sampling->prev.size();
}

ctx_sampling->prev.erase(ctx_sampling->prev.end() - rollback_num, ctx_sampling->prev.end());
}
7 changes: 7 additions & 0 deletions common/sampling.h
@@ -92,6 +92,9 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_

void llama_sampling_free(struct llama_sampling_context * ctx);

// Reset the sampler grammar without resetting the context
void llama_sampling_reset_grammar(struct llama_sampling_context * ctx);

// Reset the sampler context
// - clear prev tokens
// - reset grammar
@@ -149,3 +152,7 @@ void llama_sampling_accept(
struct llama_context * ctx_main,
llama_token id,
bool apply_grammar);

void llama_sampling_rollback(
struct llama_sampling_context * ctx_sampling,
int rollback_num);
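
A minimal usage sketch (not part of the diff) of the two new sampling entry points, assuming a context created with llama_sampling_init() and the common/ directory on the include path; the helper name rollback_and_restart is hypothetical. Note that llama_sampling_reset_grammar() restarts the grammar from its root rule rather than rewinding it to the rolled-back position.

    #include "sampling.h"

    // Sketch only: discard the last n sampled tokens from the sampler's history,
    // then restart grammar-constrained sampling from the already-parsed rules.
    static void rollback_and_restart(llama_sampling_context * ctx_sampling, int n) {
        llama_sampling_rollback(ctx_sampling, n);   // erases up to n tokens from the end of ctx_sampling->prev
        llama_sampling_reset_grammar(ctx_sampling); // frees and re-creates ctx_sampling->grammar
    }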
28 changes: 17 additions & 11 deletions examples/batched-bench/batched-bench.cpp
@@ -32,7 +32,7 @@ int main(int argc, char ** argv) {
gpt_params params;

if (argc == 1 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
printf(" <PP>, <TG> and PL are comma-separated lists of numbers without spaces\n\n");
printf(" example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
return 1 ;
@@ -41,6 +41,7 @@ int main(int argc, char ** argv) {
int n_kv_max = 2048;
int n_batch = 2048;
int n_ubatch = 512;
bool flash_attn = false;
int is_pp_shared = 0;
int n_gpu_layers = 0;

@@ -66,23 +67,27 @@
}

if (argc >= 6) {
is_pp_shared = std::atoi(argv[5]);
flash_attn = std::atoi(argv[5]);
}

if (argc >= 7) {
n_gpu_layers = std::atoi(argv[6]);
is_pp_shared = std::atoi(argv[6]);
}

if (argc >= 8) {
n_pp = parse_list(argv[7]);
n_gpu_layers = std::atoi(argv[7]);
}

if (argc >= 9) {
n_tg = parse_list(argv[8]);
n_pp = parse_list(argv[8]);
}

if (argc >= 10) {
n_pl = parse_list(argv[9]);
n_tg = parse_list(argv[9]);
}

if (argc >= 11) {
n_pl = parse_list(argv[10]);
}

// init LLM
@@ -108,10 +113,11 @@ int main(int argc, char ** argv) {

llama_context_params ctx_params = llama_context_default_params();

ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_max;
ctx_params.n_batch = n_batch;
ctx_params.n_ubatch = n_ubatch;
ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_max;
ctx_params.n_batch = n_batch;
ctx_params.n_ubatch = n_ubatch;
ctx_params.flash_attn = flash_attn;

ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
@@ -169,7 +175,7 @@ int main(int argc, char ** argv) {
}

LOG_TEE("\n");
LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, flash_attn, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
LOG_TEE("\n");

LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
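Note (not part of the diff): with this change the positional arguments of batched-bench shift by one, to MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>. A hypothetical invocation enabling flash attention would therefore be: batched-bench ggml-model-f16.gguf 2048 2048 512 1 0 999 128,256,512 128,256 1,2,4,8,16,32 (the example string in the usage message above still shows the old, shorter argument list).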
(Diff truncated: the remaining changed files are not shown here.)