Merge pull request #136 from acon96/release/v0.2.15
Release v0.2.15
acon96 authored May 4, 2024
2 parents 3e30ac9 + 26be7d7 commit 687f49f
Showing 10 changed files with 66 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/create-release.yml
@@ -19,7 +19,7 @@ jobs:
fail-fast: false
matrix:
home_assistant_version: ["2023.12.4", "2024.2.1"]
arch: [aarch64, amd64, i386] # armhf
arch: [aarch64, armhf, amd64, i386]
suffix: [""]
include:
- home_assistant_version: "2024.2.1"
1 change: 1 addition & 0 deletions README.md
@@ -126,6 +126,7 @@ In order to facilitate running the project entirely on the system where Home Ass
## Version History
| Version | Description |
|---------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| v0.2.15 | Fix startup error when using llama.cpp backend and add flash attention to llama.cpp backend |
| v0.2.14 | Fix llama.cpp wheels + AVX detection |
| v0.2.13 | Add support for Llama 3, build llama.cpp wheels that are compatible with non-AVX systems, fix an error with exposing script entities, fix multiple small Ollama backend issues, and add basic multi-language support |
| v0.2.12 | Fix cover ICL examples, allow setting number of ICL examples, add min P and typical P sampler options, recommend models during setup, add JSON mode for Ollama backend, fix missing default options |
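As background for the v0.2.15 entry above: flash attention is exposed by llama-cpp-python (the embedded backend, pinned to 0.2.69 later in this commit) as a constructor flag on Llama. A minimal sketch of what enabling it looks like in isolation, with a placeholder model path; the integration's actual wiring is in the agent.py diff below.

# Minimal sketch: loading a GGUF model with flash attention enabled via
# llama-cpp-python. The model path and sampling call are placeholders, not
# part of this commit.
from llama_cpp import Llama

llm = Llama(
    model_path="/config/models/example.gguf",  # hypothetical path
    n_ctx=2048,        # context length
    n_batch=512,       # prompt batch size
    flash_attn=True,   # the flag this release wires up for the llama.cpp backend
)

result = llm.create_completion("The kitchen light is", max_tokens=8)
print(result["choices"][0]["text"])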
46 changes: 29 additions & 17 deletions custom_components/llama_conversation/agent.py
@@ -41,6 +41,7 @@
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_PROMPT_TEMPLATE,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -75,6 +76,7 @@
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_ENABLE_FLASH_ATTENTION,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -548,6 +550,7 @@ def _load_model(self, entry: ConfigEntry) -> None:
if not install_result == True:
raise ConfigEntryError("llama-cpp-python was not installed on startup and re-installing it led to an error!")

validate_llama_cpp_python_installation()
self.llama_cpp_module = importlib.import_module("llama_cpp")

Llama = getattr(self.llama_cpp_module, "Llama")
@@ -558,13 +561,15 @@
self.loaded_model_settings[CONF_BATCH_SIZE] = entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE)
self.loaded_model_settings[CONF_THREAD_COUNT] = entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT)
self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] = entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT)
self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] = entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION)

self.llm = Llama(
model_path=self.model_path,
n_ctx=int(self.loaded_model_settings[CONF_CONTEXT_LENGTH]),
n_batch=int(self.loaded_model_settings[CONF_BATCH_SIZE]),
n_threads=int(self.loaded_model_settings[CONF_THREAD_COUNT]),
n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT])
n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT]),
flash_attn=self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION],
)
_LOGGER.debug("Model loaded")

@@ -613,21 +618,24 @@ def _update_options(self):
if self.loaded_model_settings[CONF_CONTEXT_LENGTH] != self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) or \
self.loaded_model_settings[CONF_BATCH_SIZE] != self.entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE) or \
self.loaded_model_settings[CONF_THREAD_COUNT] != self.entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT) or \
self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] != self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT):
self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] != self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT) or \
self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] != self.entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION):

_LOGGER.debug(f"Reloading model '{self.model_path}'...")
self.loaded_model_settings[CONF_CONTEXT_LENGTH] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
self.loaded_model_settings[CONF_BATCH_SIZE] = self.entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE)
self.loaded_model_settings[CONF_THREAD_COUNT] = self.entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT)
self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] = self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT)
self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] = self.entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION)

Llama = getattr(self.llama_cpp_module, "Llama")
self.llm = Llama(
model_path=self.model_path,
n_ctx=int(self.loaded_model_settings[CONF_CONTEXT_LENGTH]),
n_batch=int(self.loaded_model_settings[CONF_BATCH_SIZE]),
n_threads=int(self.loaded_model_settings[CONF_THREAD_COUNT]),
n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT])
n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT]),
flash_attn=self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION],
)
_LOGGER.debug("Model loaded")
model_reloaded = True
@@ -894,15 +902,17 @@ def _generate(self, conversation: dict) -> str:
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"

result = requests.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=timeout,
headers=headers,
)

try:
result = requests.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=timeout,
headers=headers,
)

result.raise_for_status()
except requests.exceptions.Timeout:
return f"The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
except requests.RequestException as err:
_LOGGER.debug(f"Err was: {err}")
_LOGGER.debug(f"Request was: {request_params}")
@@ -1141,15 +1151,17 @@ def _generate(self, conversation: dict) -> str:
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"

result = requests.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=timeout,
headers=headers,
)

try:
result = requests.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=timeout,
headers=headers,
)

result.raise_for_status()
except requests.exceptions.Timeout:
return f"The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
except requests.RequestException as err:
_LOGGER.debug(f"Err was: {err}")
_LOGGER.debug(f"Request was: {request_params}")
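The two _generate hunks above make the same change in both remote-backend code paths: the requests.post call is moved inside the try block and a dedicated Timeout handler is added, so a slow or unreachable server produces a readable chat response instead of an unhandled exception. A stripped-down sketch of that pattern, outside the integration's classes; the endpoint, payload, and response shape are placeholders:

# Sketch of the error-handling pattern adopted above: the POST itself sits
# inside the try block so both connection errors and timeouts are caught,
# not only raise_for_status(). Endpoint and response shape are illustrative.
import logging
import requests

_LOGGER = logging.getLogger(__name__)

def generate(api_host: str, request_params: dict, timeout: float, api_key: str | None = None) -> str:
    headers = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    try:
        result = requests.post(
            f"{api_host}/v1/completions",  # hypothetical endpoint
            json=request_params,
            timeout=timeout,
            headers=headers,
        )
        result.raise_for_status()
    except requests.exceptions.Timeout:
        return "The generation request timed out! Please check your connection settings."
    except requests.RequestException as err:
        _LOGGER.debug(f"Err was: {err}")
        _LOGGER.debug(f"Request was: {request_params}")
        return "Sorry, there was a problem talking to the backend."

    return result.json()["choices"][0]["text"]  # response shape depends on the backend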
7 changes: 7 additions & 0 deletions custom_components/llama_conversation/config_flow.py
@@ -54,6 +54,7 @@
CONF_DOWNLOADED_MODEL_QUANTIZATION,
CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS,
CONF_PROMPT_TEMPLATE,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
@@ -93,6 +94,7 @@
DEFAULT_BACKEND_TYPE,
DEFAULT_DOWNLOADED_MODEL_QUANTIZATION,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_ENABLE_FLASH_ATTENTION,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
@@ -811,6 +813,11 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
description={"suggested_value": options.get(CONF_BATCH_THREAD_COUNT)},
default=DEFAULT_BATCH_THREAD_COUNT,
): NumberSelector(NumberSelectorConfig(min=1, max=(os.cpu_count() * 2), step=1)),
vol.Required(
CONF_ENABLE_FLASH_ATTENTION,
description={"suggested_value": options.get(CONF_ENABLE_FLASH_ATTENTION)},
default=DEFAULT_ENABLE_FLASH_ATTENTION,
): BooleanSelector(BooleanSelectorConfig()),
vol.Required(
CONF_USE_GBNF_GRAMMAR,
description={"suggested_value": options.get(CONF_USE_GBNF_GRAMMAR)},
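The options schema above registers the new toggle as a required boolean with a default. As a rough illustration of how that defaulting behaves, the sketch below uses plain voluptuous with bool standing in for Home Assistant's BooleanSelector; it is a simplification, not the integration's actual selector handling.

# Simplified stand-in for the schema entry above: plain voluptuous, with bool
# in place of Home Assistant's BooleanSelector, to show how the default applies
# when the option is absent from user input.
import voluptuous as vol

CONF_ENABLE_FLASH_ATTENTION = "enable_flash_attention"
DEFAULT_ENABLE_FLASH_ATTENTION = False

schema = vol.Schema({
    vol.Required(CONF_ENABLE_FLASH_ATTENTION, default=DEFAULT_ENABLE_FLASH_ATTENTION): bool,
})

print(schema({}))                                   # {'enable_flash_attention': False}
print(schema({CONF_ENABLE_FLASH_ATTENTION: True}))  # {'enable_flash_attention': True}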
5 changes: 4 additions & 1 deletion custom_components/llama_conversation/const.py
@@ -120,6 +120,8 @@
"generation_prompt": "<|start_header_id|>assistant<|end_header_id|>\n\n"
}
}
CONF_ENABLE_FLASH_ATTENTION = "enable_flash_attention"
DEFAULT_ENABLE_FLASH_ATTENTION = False
CONF_USE_GBNF_GRAMMAR = "gbnf_grammar"
DEFAULT_USE_GBNF_GRAMMAR = False
CONF_GBNF_GRAMMAR_FILE = "gbnf_grammar_file"
@@ -178,6 +180,7 @@
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
CONF_ENABLE_FLASH_ATTENTION: DEFAULT_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR: DEFAULT_USE_GBNF_GRAMMAR,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
@@ -271,5 +274,5 @@
}
}

INTEGRATION_VERSION = "0.2.14"
INTEGRATION_VERSION = "0.2.15"
EMBEDDED_LLAMA_CPP_PYTHON_VERSION = "0.2.69"
2 changes: 1 addition & 1 deletion custom_components/llama_conversation/manifest.json
@@ -1,7 +1,7 @@
{
"domain": "llama_conversation",
"name": "LLaMA Conversation",
"version": "0.2.14",
"version": "0.2.15",
"codeowners": ["@acon96"],
"config_flow": true,
"dependencies": ["conversation"],
2 changes: 2 additions & 0 deletions custom_components/llama_conversation/translations/en.json
@@ -63,6 +63,7 @@
"ollama_json_mode": "JSON Output Mode",
"extra_attributes_to_expose": "Additional attribute to expose in the context",
"allowed_service_call_arguments": "Arguments allowed to be pass to service calls",
"enable_flash_attention": "Enable Flash Attention",
"gbnf_grammar": "Enable GBNF Grammar",
"gbnf_grammar_file": "GBNF Grammar Filename",
"openai_api_key": "API Key",
@@ -115,6 +116,7 @@
"ollama_json_mode": "JSON Output Mode",
"extra_attributes_to_expose": "Additional attribute to expose in the context",
"allowed_service_call_arguments": "Arguments allowed to be pass to service calls",
"enable_flash_attention": "Enable Flash Attention",
"gbnf_grammar": "Enable GBNF Grammar",
"gbnf_grammar_file": "GBNF Grammar Filename",
"openai_api_key": "API Key",
17 changes: 12 additions & 5 deletions custom_components/llama_conversation/utils.py
@@ -3,6 +3,7 @@
import sys
import platform
import logging
import multiprocessing
import voluptuous as vol
import webcolors
from importlib.metadata import version
@@ -68,17 +69,23 @@ def download_model_from_hf(model_name: str, quantization_type: str, storage_fold
)

def _load_extension():
"""This needs to be at the root file level because we are using the 'spawn' start method"""
"""
Makes sure it is possible to load llama-cpp-python without crashing Home Assistant.
This needs to be at the root file level because we are using the 'spawn' start method.
ModuleNotFoundError is ignored because it only means the package is not installed, not that importing it would crash HA.
"""
import importlib
importlib.import_module("llama_cpp")
try:
importlib.import_module("llama_cpp")
except ModuleNotFoundError:
pass

def validate_llama_cpp_python_installation():
"""
Spawns another process and tries to import llama.cpp to avoid crashing the main process
"""
import multiprocessing
multiprocessing.set_start_method('spawn') # required because of aio
process = multiprocessing.Process(target=_load_extension)
mp_ctx = multiprocessing.get_context('spawn') # required because of aio
process = mp_ctx.Process(target=_load_extension)
process.start()
process.join()

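The utils.py change above switches from multiprocessing.set_start_method('spawn') (which raises if a start method has already been set in the Home Assistant process) to a private 'spawn' context, and makes the probe tolerate a missing llama_cpp package. A self-contained sketch of the pattern follows; the exit-code check at the end is an assumption about what the collapsed remainder of validate_llama_cpp_python_installation does, since it is not shown in this diff.

# Sketch: probe an import in a spawned child process so that a crashing native
# extension (e.g. an illegal-instruction fault) kills the child, not the main
# Home Assistant process.
import multiprocessing

def _probe_import():
    # Must be defined at module level so the 'spawn' start method can pickle it.
    import importlib
    try:
        importlib.import_module("llama_cpp")
    except ModuleNotFoundError:
        pass  # not installed is fine; only a hard crash is a problem

def import_is_safe() -> bool:
    mp_ctx = multiprocessing.get_context("spawn")  # does not touch the global start method
    process = mp_ctx.Process(target=_probe_import)
    process.start()
    process.join()
    return process.exitcode == 0  # negative exit code means the child died from a signal

if __name__ == "__main__":
    print(import_is_safe())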
5 changes: 5 additions & 0 deletions tests/llama_conversation/test_agent.py
@@ -20,6 +20,7 @@
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_PROMPT_TEMPLATE,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -55,6 +56,7 @@
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_ENABLE_FLASH_ATTENTION,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -208,6 +210,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
)

all_mocks["tokenize"].assert_called_once()
@@ -231,6 +234,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
local_llama_agent.entry.options[CONF_THREAD_COUNT] = 24
local_llama_agent.entry.options[CONF_BATCH_THREAD_COUNT] = 24
local_llama_agent.entry.options[CONF_TEMPERATURE] = 2.0
local_llama_agent.entry.options[CONF_ENABLE_FLASH_ATTENTION] = True
local_llama_agent.entry.options[CONF_TOP_K] = 20
local_llama_agent.entry.options[CONF_TOP_P] = 0.9
local_llama_agent.entry.options[CONF_MIN_P] = 0.2
@@ -244,6 +248,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
)

# do another turn of the same conversation
6 changes: 4 additions & 2 deletions tests/llama_conversation/test_config_flow.py
@@ -26,6 +26,7 @@
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_PROMPT_TEMPLATE,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -67,6 +68,7 @@
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_ENABLE_FLASH_ATTENTION,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -304,7 +306,7 @@ def test_validate_options_schema():
options_llama_hf = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_HF)
assert set(options_llama_hf.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
@@ -313,7 +315,7 @@ def test_validate_options_schema():
options_llama_existing = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_EXISTING)
assert set(options_llama_existing.keys()) == set(universal_options + [
CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
CONF_CONTEXT_LENGTH, # supports context length
CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
