Skip to content

Commit

Permalink
Update llama.cpp
Browse files Browse the repository at this point in the history
Fix build examples

Exclude examples directory

Revert cmake changes

Try actions/checkout@v4

Try to update submodules

Revert
  • Loading branch information
abetlen committed Nov 2, 2023
1 parent d7ae8b5 commit 54f040a
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
submodules: "true"
- name: Set up Python ${{ matrix.python-version }}
Expand Down
58 changes: 47 additions & 11 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,14 @@ def __init__(
n_batch: int = 512,
n_threads: Optional[int] = None,
n_threads_batch: Optional[int] = None,
rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED,
rope_freq_base: float = 0.0,
rope_freq_scale: float = 0.0,
yarn_ext_factor: float = float("nan"),
yarn_attn_factor: float = 1.0,
yarn_beta_fast: float = 32.0,
yarn_beta_slow: float = 1.0,
yarn_orig_ctx: int = 0,
mul_mat_q: bool = True,
f16_kv: bool = True,
logits_all: bool = False,
Expand All @@ -255,30 +261,30 @@ def __init__(
Args:
model_path: Path to the model.
seed: Random seed. -1 for random.
n_ctx: Maximum context size.
n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
main_gpu: Main GPU to use.
tensor_split: Optional list of floats to split the model across multiple GPUs. If None, the model is not split.
main_gpu: The GPU that is used for scratch and small tensors.
tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split.
vocab_only: Only load the vocabulary, no weights.
use_mmap: Use mmap if possible.
use_mlock: Force the system to keep the model in RAM.
seed: Random seed. -1 for random.
n_ctx: Context size.
n_batch: Batch size for prompt processing (must be >= 32 to use BLAS)
n_threads: Number of threads to use. If None, the number of threads is automatically determined.
n_threads_batch: Number of threads to use for batch processing. If None, use n_threads.
rope_scaling_type: Type of rope scaling to use.
rope_freq_base: RoPE base frequency, 0 = from model.
rope_freq_scale: RoPE frequency scaling factor, 0 = from model.
low_vram: Use low VRAM mode.
mul_mat_q: if true, use experimental mul_mat_q kernels
f16_kv: Use half-precision for key/value cache.
logits_all: Return logits for all tokens, not just the last token.
vocab_only: Only load the vocabulary, no weights.
use_mmap: Use mmap if possible.
use_mlock: Force the system to keep the model in RAM.
embedding: Embedding mode only.
n_threads: Number of threads to use. If None, the number of threads is automatically determined.
last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
lora_path: Path to a LoRA file to apply to the model.
numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
chat_format: String specifying the chat format to use when calling create_chat_completion.
verbose: Print verbose output to stderr.
kwargs: Unused keyword arguments (for additional backwards compatibility).
Raises:
ValueError: If the model path does not exist.
Expand Down Expand Up @@ -332,12 +338,30 @@ def __init__(
self.context_params.n_batch = self.n_batch
self.context_params.n_threads = self.n_threads
self.context_params.n_threads_batch = self.n_threads_batch
self.context_params.rope_scaling_type = (
rope_scaling_type if rope_scaling_type is not None else llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED
)
self.context_params.rope_freq_base = (
rope_freq_base if rope_freq_base != 0.0 else 0
)
self.context_params.rope_freq_scale = (
rope_freq_scale if rope_freq_scale != 0.0 else 0
)
self.context_params.yarn_ext_factor = (
yarn_ext_factor if yarn_ext_factor != 0.0 else 0
)
self.context_params.yarn_attn_factor = (
yarn_attn_factor if yarn_attn_factor != 0.0 else 0
)
self.context_params.yarn_beta_fast = (
yarn_beta_fast if yarn_beta_fast != 0.0 else 0
)
self.context_params.yarn_beta_slow = (
yarn_beta_slow if yarn_beta_slow != 0.0 else 0
)
self.context_params.yarn_orig_ctx = (
yarn_orig_ctx if yarn_orig_ctx != 0 else 0
)
self.context_params.mul_mat_q = mul_mat_q
self.context_params.f16_kv = f16_kv
self.context_params.logits_all = logits_all
Expand Down Expand Up @@ -1671,8 +1695,14 @@ def __getstate__(self):
n_batch=self.n_batch,
n_threads=self.context_params.n_threads,
n_threads_batch=self.context_params.n_threads_batch,
rope_scaling_type=self.context_params.rope_scaling_type,
rope_freq_base=self.context_params.rope_freq_base,
rope_freq_scale=self.context_params.rope_freq_scale,
yarn_ext_factor=self.context_params.yarn_ext_factor,
yarn_attn_factor=self.context_params.yarn_attn_factor,
yarn_beta_fast=self.context_params.yarn_beta_fast,
yarn_beta_slow=self.context_params.yarn_beta_slow,
yarn_orig_ctx=self.context_params.yarn_orig_ctx,
mul_mat_q=self.context_params.mul_mat_q,
f16_kv=self.context_params.f16_kv,
logits_all=self.context_params.logits_all,
Expand Down Expand Up @@ -1709,6 +1739,12 @@ def __setstate__(self, state):
n_threads_batch=state["n_threads_batch"],
rope_freq_base=state["rope_freq_base"],
rope_freq_scale=state["rope_freq_scale"],
rope_scaling_type=state["rope_scaling_type"],
yarn_ext_factor=state["yarn_ext_factor"],
yarn_attn_factor=state["yarn_attn_factor"],
yarn_beta_fast=state["yarn_beta_fast"],
yarn_beta_slow=state["yarn_beta_slow"],
yarn_orig_ctx=state["yarn_orig_ctx"],
mul_mat_q=state["mul_mat_q"],
f16_kv=state["f16_kv"],
logits_all=state["logits_all"],
Expand Down
28 changes: 26 additions & 2 deletions llama_cpp/llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,18 @@ def _load_shared_library(lib_base_name: str):
LLAMA_FTYPE_MOSTLY_Q6_K = 18
LLAMA_FTYPE_GUESSED = 1024

# enum llama_rope_scaling_type {
# LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
# LLAMA_ROPE_SCALING_NONE = 0,
# LLAMA_ROPE_SCALING_LINEAR = 1,
# LLAMA_ROPE_SCALING_YARN = 2,
# LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
# };
# Python-side copies of the C enum values quoted above, kept as plain ints.
# They must stay numerically identical to the C header.
LLAMA_ROPE_SCALING_UNSPECIFIED = -1
LLAMA_ROPE_SCALING_NONE = 0
LLAMA_ROPE_SCALING_LINEAR = 1
LLAMA_ROPE_SCALING_YARN = 2
# Alias of the highest enumerator (currently YARN), mirroring the C definition.
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN

# typedef struct llama_token_data {
# llama_token id; // token id
Expand Down Expand Up @@ -308,10 +320,16 @@ class llama_model_params(Structure):
# uint32_t n_batch; // prompt processing maximum batch size
# uint32_t n_threads; // number of threads to use for generation
# uint32_t n_threads_batch; // number of threads to use for batch processing
# int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`

# // ref: https://github.com/ggerganov/llama.cpp/pull/2054
# float rope_freq_base; // RoPE base frequency, 0 = from model
# float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
# float rope_freq_base; // RoPE base frequency, 0 = from model
# float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
# float yarn_ext_factor; // YaRN extrapolation mix factor, NaN = from model
# float yarn_attn_factor; // YaRN magnitude scaling factor
# float yarn_beta_fast; // YaRN low correction dim
# float yarn_beta_slow; // YaRN high correction dim
# uint32_t yarn_orig_ctx; // YaRN original context size


# // Keep the booleans together to avoid misalignment during copy-by-value.
Expand All @@ -327,8 +345,14 @@ class llama_context_params(Structure):
("n_batch", c_uint32),
("n_threads", c_uint32),
("n_threads_batch", c_uint32),
("rope_scaling_type", c_int8),
("rope_freq_base", c_float),
("rope_freq_scale", c_float),
("yarn_ext_factor", c_float),
("yarn_attn_factor", c_float),
("yarn_beta_fast", c_float),
("yarn_beta_slow", c_float),
("yarn_orig_ctx", c_uint32),
("mul_mat_q", c_bool),
("f16_kv", c_bool),
("logits_all", c_bool),
Expand Down
94 changes: 70 additions & 24 deletions llama_cpp/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@ class Settings(BaseSettings):
default=None,
description="The alias of the model to use for generating completions.",
)
seed: int = Field(default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random.")
n_ctx: int = Field(default=2048, ge=1, description="The context size.")
n_batch: int = Field(
default=512, ge=1, description="The batch size to use per eval."
)
# Model Params
n_gpu_layers: int = Field(
default=0,
ge=-1,
Expand All @@ -60,17 +56,6 @@ class Settings(BaseSettings):
default=None,
description="Split layers across multiple GPUs in proportion.",
)
rope_freq_base: float = Field(
default=0.0, description="RoPE base frequency"
)
rope_freq_scale: float = Field(
default=0.0, description="RoPE frequency scaling factor"
)
mul_mat_q: bool = Field(
default=True, description="if true, use experimental mul_mat_q kernels"
)
f16_kv: bool = Field(default=True, description="Whether to use f16 key/value.")
logits_all: bool = Field(default=True, description="Whether to return logits.")
vocab_only: bool = Field(
default=False, description="Whether to only return the vocabulary."
)
Expand All @@ -82,17 +67,59 @@ class Settings(BaseSettings):
default=llama_cpp.llama_mlock_supported(),
description="Use mlock.",
)
embedding: bool = Field(default=True, description="Whether to use embeddings.")
# Context Params
seed: int = Field(default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random.")
n_ctx: int = Field(default=2048, ge=1, description="The context size.")
n_batch: int = Field(
default=512, ge=1, description="The batch size to use per eval."
)
n_threads: int = Field(
default=max(multiprocessing.cpu_count() // 2, 1),
ge=1,
description="The number of threads to use.",
)
n_threads_batch: int = Field(
default=max(multiprocessing.cpu_count() // 2, 1),
ge=0,
description="The number of threads to use when batch processing.",
)
rope_scaling_type: int = Field(
default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED
)
rope_freq_base: float = Field(
default=0.0, description="RoPE base frequency"
)
rope_freq_scale: float = Field(
default=0.0, description="RoPE frequency scaling factor"
)
yarn_ext_factor: float = Field(
default=float("nan")
)
yarn_attn_factor: float = Field(
default=1.0
)
yarn_beta_fast: float = Field(
default=32.0
)
yarn_beta_slow: float = Field(
default=1.0
)
yarn_orig_ctx: int = Field(
default=0
)
mul_mat_q: bool = Field(
default=True, description="if true, use experimental mul_mat_q kernels"
)
f16_kv: bool = Field(default=True, description="Whether to use f16 key/value.")
logits_all: bool = Field(default=True, description="Whether to return logits.")
embedding: bool = Field(default=True, description="Whether to use embeddings.")
# Sampling Params
last_n_tokens_size: int = Field(
default=64,
ge=0,
description="Last n tokens to keep for repeat penalty calculation.",
)
# LoRA Params
lora_base: Optional[str] = Field(
default=None,
description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model."
Expand All @@ -101,14 +128,17 @@ class Settings(BaseSettings):
default=None,
description="Path to a LoRA file to apply to the model.",
)
# Backend Params
numa: bool = Field(
default=False,
description="Enable NUMA support.",
)
# Chat Format Params
chat_format: str = Field(
default="llama-2",
description="Chat format to use.",
)
# Cache Params
cache: bool = Field(
default=False,
description="Use a cache to reduce processing times for evaluated prompts.",
Expand All @@ -121,9 +151,11 @@ class Settings(BaseSettings):
default=2 << 30,
description="The size of the cache in bytes. Only used if cache is True.",
)
# Misc
verbose: bool = Field(
default=True, description="Whether to print debug information."
)
# Server Params
host: str = Field(default="localhost", description="Listen address")
port: int = Field(default=8000, description="Listen port")
interrupt_requests: bool = Field(
Expand Down Expand Up @@ -345,27 +377,41 @@ def create_app(settings: Optional[Settings] = None):
global llama
llama = llama_cpp.Llama(
model_path=settings.model,
seed=settings.seed,
n_ctx=settings.n_ctx,
n_batch=settings.n_batch,
# Model Params
n_gpu_layers=settings.n_gpu_layers,
main_gpu=settings.main_gpu,
tensor_split=settings.tensor_split,
vocab_only=settings.vocab_only,
use_mmap=settings.use_mmap,
use_mlock=settings.use_mlock,
# Context Params
seed=settings.seed,
n_ctx=settings.n_ctx,
n_batch=settings.n_batch,
n_threads=settings.n_threads,
n_threads_batch=settings.n_threads_batch,
rope_scaling_type=settings.rope_scaling_type,
rope_freq_base=settings.rope_freq_base,
rope_freq_scale=settings.rope_freq_scale,
yarn_ext_factor=settings.yarn_ext_factor,
yarn_attn_factor=settings.yarn_attn_factor,
yarn_beta_fast=settings.yarn_beta_fast,
yarn_beta_slow=settings.yarn_beta_slow,
yarn_orig_ctx=settings.yarn_orig_ctx,
mul_mat_q=settings.mul_mat_q,
f16_kv=settings.f16_kv,
logits_all=settings.logits_all,
vocab_only=settings.vocab_only,
use_mmap=settings.use_mmap,
use_mlock=settings.use_mlock,
embedding=settings.embedding,
n_threads=settings.n_threads,
# Sampling Params
last_n_tokens_size=settings.last_n_tokens_size,
# LoRA Params
lora_base=settings.lora_base,
lora_path=settings.lora_path,
# Backend Params
numa=settings.numa,
# Chat Format Params
chat_format=settings.chat_format,
# Misc
verbose=settings.verbose,
)
if settings.cache:
Expand Down
2 changes: 1 addition & 1 deletion vendor/llama.cpp
Submodule llama.cpp updated 49 files
+1 −1 .gitignore
+1 −34 CMakeLists.txt
+37 −34 Makefile
+17 −21 build.zig
+40 −2 common/CMakeLists.txt
+4 −0 common/build-info.cpp.in
+70 −18 common/common.cpp
+16 −3 common/common.h
+2 −1 convert-baichuan-hf-to-gguf.py
+48 −49 convert.py
+1 −4 examples/benchmark/CMakeLists.txt
+0 −1 examples/benchmark/benchmark-matmult.cpp
+0 −3 examples/embedding/CMakeLists.txt
+0 −1 examples/embedding/embedding.cpp
+3 −2 examples/finetune/finetune.cpp
+0 −3 examples/infill/CMakeLists.txt
+2 −3 examples/infill/infill.cpp
+0 −3 examples/llama-bench/CMakeLists.txt
+2 −3 examples/llama-bench/llama-bench.cpp
+0 −6 examples/llava/CMakeLists.txt
+0 −3 examples/main/CMakeLists.txt
+2 −3 examples/main/main.cpp
+0 −3 examples/parallel/CMakeLists.txt
+0 −2 examples/parallel/parallel.cpp
+0 −3 examples/perplexity/CMakeLists.txt
+0 −1 examples/perplexity/perplexity.cpp
+1 −1 examples/quantize-stats/CMakeLists.txt
+0 −1 examples/quantize-stats/quantize-stats.cpp
+1 −4 examples/quantize/CMakeLists.txt
+0 −1 examples/quantize/quantize.cpp
+0 −3 examples/save-load-state/CMakeLists.txt
+0 −1 examples/save-load-state/save-load-state.cpp
+0 −3 examples/server/CMakeLists.txt
+57 −7 examples/server/server.cpp
+0 −3 examples/speculative/CMakeLists.txt
+0 −2 examples/speculative/speculative.cpp
+3 −3 examples/train-text-from-scratch/train-text-from-scratch.cpp
+158 −77 ggml-cuda.cu
+47 −33 ggml-metal.m
+162 −34 ggml-metal.metal
+184 −116 ggml.c
+17 −3 ggml.h
+23 −6 gguf-py/gguf/gguf.py
+195 −75 llama.cpp
+16 −2 llama.h
+ models/ggml-vocab-llama.gguf
+18 −12 scripts/build-info.cmake
+0 −9 scripts/build-info.h.in
+4 −9 scripts/build-info.sh

0 comments on commit 54f040a

Please sign in to comment.